{-# LANGUAGE FlexibleContexts, RecordWildCards, ScopedTypeVariables, ViewPatterns, PatternGuards #-}
module Tip.Simplify where

import Tip.Core
import Tip.Scope
import Tip.Fresh
import Data.Generics.Geniplate
import Data.List
import Data.Maybe
import Data.Monoid
import Control.Applicative
import Control.Monad
import qualified Data.Map as Map
import Tip.Writer

-- | Knobs controlling what the simplifier is allowed to do.
data SimplifyOpts a = SimplifyOpts
  { touch_lets :: Bool
    -- ^ Permit the rewrites that inline or move @let@ bindings.
  , remove_variable_scrutinee_in_branches :: Bool
    -- ^ When the scrutinee of a match is a variable, replace that variable
    -- by the matched pattern inside each branch, i.e. rewrite
    --
    -- @(match x (case (K y) (e x y)))@
    --
    -- into
    --
    -- @(match x (case (K y) (e (K y) y)))@
    --
    -- Doing so can trigger further known-case simplifications, which is
    -- why it is on by default.
  , should_inline :: Occurrences -> Maybe (Scope a) -> Expr a -> Bool
    -- ^ Predicate deciding whether an expression may be inlined, given
    -- the occurrence count of the variable it would replace and the
    -- scope (if any).
  , inline_match :: Bool
    -- ^ Permit function inlining even when the result contains a match.
  }

-- | An occurrence count handed to the 'should_inline' predicate: the number
-- of times the variable being considered for replacement occurs in the
-- expression it would be inlined into.
newtype Occurrences = Occurrences Int

-- | The 'gently' options with the inlining predicate replaced by one that
-- always refuses.
gentlyNoInline :: SimplifyOpts a
gentlyNoInline = gently { should_inline = neverInline }
  where
    neverInline _ _ _ = False

-- | Gentle options: every rewrite is enabled, but an expression is inlined
-- only when the variable occurs at most once or the expression is atomic,
-- so no non-trivial code gets duplicated.
gently :: SimplifyOpts a
gently = SimplifyOpts
  { touch_lets                            = True
  , remove_variable_scrutinee_in_branches = True
  , should_inline                         = cheapToInline
  , inline_match                          = True
  }
  where
    cheapToInline (Occurrences occ) _ e = occ <= 1 || atomic e

-- | Aggressive options: every rewrite is enabled, and an expression is
-- inlined when the variable occurs at most once, or when the expression is
-- a lambda or an application whose head is a constructor (see
-- 'isConstructor') -- shapes that may enable further simplification.
aggressively :: Name a => SimplifyOpts a
aggressively = SimplifyOpts
  { touch_lets                            = True
  , remove_variable_scrutinee_in_branches = True
  , should_inline                         = \(Occurrences occ) mscp e ->
      occ <= 1 || worthDuplicating mscp e
  , inline_match                          = True
  }
  where
    worthDuplicating mscp e = case e of
      Lam{}    -> True
      hd :@: _ -> isConstructor mscp hd
      _        -> False

-- | Simplify every function definition and every assertion of a theory.
--
-- Function definitions are simplified with the given options; assertions
-- are simplified with 'inline_match' switched off. All other parts of the
-- theory are returned unchanged.
simplifyTheory :: Name a => SimplifyOpts a -> Theory a -> Fresh (Theory a)
simplifyTheory opts thy = do
  funcs   <- forM (thy_funcs thy)   (simplifyExprIn (Just thy) opts)
  asserts <- forM (thy_asserts thy) (simplifyExprIn (Just thy) assertOpts)
  return thy { thy_funcs = funcs, thy_asserts = asserts }
  where
    assertOpts = opts { inline_match = False }

-- | Simplify an expression (or anything containing expressions) when no
-- enclosing theory is available; this is 'simplifyExprIn' with 'Nothing'
-- as the theory.
simplifyExpr :: forall f a. (TransformBiM (WriterT Any Fresh) (Expr a) (f a), Name a) => SimplifyOpts a -> f a -> Fresh (f a)
simplifyExpr = simplifyExprIn Nothing

-- | Simplify an expression within a theory
--
-- The work is done by the local function @aux@, which runs in
-- @WriterT Any Fresh@: the 'Fresh' layer supplies new names, and the 'Any'
-- flag is set (through @hooray@) by rewrites that count as a simplification.
-- The final value of the flag is discarded here (@fmap fst@); it is only
-- inspected internally to decide whether a speculative inlining paid off.
--
-- When a theory is supplied, its scope is used to recognise constructors and
-- its non-recursive functions become candidates for inlining. With 'Nothing'
-- neither is available.
simplifyExprIn :: forall f a. (TransformBiM (WriterT Any Fresh) (Expr a) (f a), Name a) => Maybe (Theory a) -> SimplifyOpts a -> f a -> Fresh (f a)
simplifyExprIn mthy opts@SimplifyOpts{..} = fmap fst . runWriterT . aux
  where
    -- Rewrite the expressions inside the argument, using transformBiM to
    -- apply the case analysis below to each expression e0 it visits.
    -- The alternatives are tried in order, so earlier rules take priority.
    {-# SPECIALISE aux :: Expr a -> WriterT Any Fresh (Expr a) #-}
    aux :: forall f. TransformBiM (WriterT Any Fresh) (Expr a) (f a) => f a -> WriterT Any Fresh (f a)
    aux = transformBiM $ \e0 ->
      let
        -- Return the rebuilt expression e1, but hand back the original e0
        -- when the two are equal (presumably to preserve sharing, as the
        -- name suggests).
        -- NOTE(review): neither guard calls hooray, so changes made through
        -- share are not recorded as simplifications -- confirm intended.
        share e1 | e1 /= e0  = return e1
                 | otherwise = return e0 in
      case e0 of
        -- Application of a lambda: bind each parameter to its argument with
        -- a let and simplify the result again.
        Builtin At :@: (Lam vars body:args) ->
          hooray $
          aux (foldr (uncurry Let) body (zip vars args))

        -- Let whose bound expression is atomic or whose variable occurs at
        -- most once: apply (e // x) to the body (presumably substituting e
        -- for x -- defined elsewhere) and simplify the result.
        -- NOTE(review): no hooray here -- confirm intended.
        Let x e body | touch_lets && (atomic e || occurrences x body <= 1) ->
          lift ((e // x) body) >>= aux

        -- Let that the inlining predicate accepts: inline speculatively and
        -- keep the result only if simplifying it reported progress;
        -- otherwise the original let is returned untouched.
        Let x e body | touch_lets && inlineable body x e ->
          do e1 <- lift ((e // x) body)
             (e2, Any simplified) <- lift (runWriterT (aux e1))
             if simplified then hooray $ return e2 else return e0

        -- A match whose only alternative is the default case is its body.
        Match _ [Case Default body] -> hooray $ return body

        -- The default branch is itself a match on the same scrutinee: merge
        -- the inner alternatives into the outer match, dropping those inner
        -- alternatives whose literal or constructor already appears in the
        -- outer alternatives.
        Match e (Case Default (Match e' cases'):cases) | e == e' ->
          hooray $ aux $
          Match e (filter (not . dead . case_pat) cases' ++ cases)
          where
            dead (LitPat l)   = LitPat l `elem` map case_pat cases
            dead (ConPat{..}) =
              gbl_name pat_con `elem`
                [ gbl_name pat_con | ConPat{..} <- map case_pat cases ]
            dead Default = False

        -- The default branch stands for exactly one constructor not covered
        -- by the other alternatives (see missingCase): replace the default
        -- pattern with that constructor applied to fresh variables.
        -- NOTE(review): no hooray here -- confirm intended.
        Match e (Case Default def:cases)
          | TyCon ty args <- exprType e,
            Just (d, c@Constructor{..}) <- missingCase mscp ty cases -> do
              let gbl = constructor d c args
              pat <- lift (fmap (ConPat gbl) (freshArgs gbl))
              aux (Match e (Case pat def:cases))

        -- Two-way match whose second alternative is a boolean literal b and
        -- whose branches are the boolean literals themselves: the match is
        -- the scrutinee (branch for b yields b) or its negation (branch for
        -- b yields not b).
        Match e [Case _ e1,Case (LitPat (Bool b)) e2]
          | e1 == bool (not b) && e2 == bool b -> hooray $ return e
          | e1 == bool b && e2 == bool (not b) -> hooray $ return (neg e)

        -- Match on a let: move the let outside the match.
        -- NOTE(review): no hooray here -- confirm intended.
        Match (Let x e body) alts | touch_lets ->
          aux (Let x e (Match body alts))

        -- Known-case: the scrutinee has a constructor (or literal) head and
        -- tryMatch selects the matching alternative.
        Match e alts
          | Just e' <- tryMatch mscp e alts -> hooray $ aux e'

        -- Match on a local variable: inside each branch replace the variable
        -- by the expression built from the branch's pattern. A branch is
        -- simplified again only if the replacement actually happened (the
        -- flag from (///)); the flag is not propagated to the outer writer.
        Match (Lcl x) alts | remove_variable_scrutinee_in_branches ->
          Match (Lcl x) <$> sequence
          [ Case pat <$> case pat of
              ConPat g bs -> subst ((Gbl g :@: map Lcl bs) /// x) rhs
              LitPat l    -> subst (literal l /// x) rhs
              _           -> return rhs
          | Case pat rhs <- alts
          ]
          where
            subst f e = do
              (e', Any successful) <- lift (runWriterT (f e))
              if successful then aux e' else return e

        -- Equality with a boolean literal on the left: t itself for true,
        -- neg t for false.
        Builtin Equal :@: [Builtin (Lit (Bool x)) :@: [], t]
          | x -> hooray $ return t
          | otherwise -> hooray $ return $ neg t

        -- The same with the boolean literal on the right.
        Builtin Equal :@: [t, Builtin (Lit (Bool x)) :@: []]
          | x -> hooray $ return t
          | otherwise -> hooray $ return $ neg t

        -- Equality of two literals is decided immediately.
        Builtin Equal :@: [litView -> Just s,litView -> Just t] -> hooray $ return (bool (s == t))

        -- Equality / disequality between two constructor applications
        -- (requires a scope to look the constructors up):
        --
        -- cons(x,y) == nil ~> false
        -- cons(x,y) /= nil ~> true
        --
        -- cons(x1,y1) == cons(x2,y2) ~> x1==x2 & y1==y2
        -- cons(x1,y1) /= cons(x2,y2) ~> x1/=x2 | y1/=y2
        --
        -- Any builtin other than Equal/Distinct yields Nothing, so the guard
        -- fails and later alternatives are tried.
        Builtin eq_op :@: [Gbl k :@: kargs,Gbl j :@: jargs]
          | Just scp <- mscp
          , Just (_,Constructor kk _ _) <- lookupConstructor (gbl_name k) scp
          , Just (_,Constructor jj _ _) <- lookupConstructor (gbl_name j) scp
          , Just res <- case (eq_op, kk == jj) of
                          (Equal   ,False) -> Just falseExpr
                          (Distinct,False) -> Just trueExpr
                          (Equal,   True)  -> Just (ands (zipWith (===) kargs jargs))
                          (Distinct,True)  -> Just (ors  (zipWith (=/=) kargs jargs))
                          _                -> Nothing
          -> hooray $ aux res

        -- Disequality of two literals is decided immediately.
        Builtin Distinct :@: [litView -> Just s,litView -> Just t] -> hooray $ return (bool (s /= t))

        -- Rebuild the logical connectives with neg, (/\), (\/) and (==>)
        -- (defined elsewhere; presumably simplifying smart constructors --
        -- TODO confirm). A conjunction or disjunction of two equal operands
        -- collapses to that operand.
        Builtin Not     :@: [e]      -> share (neg e)
        Builtin And     :@: [e1, e2] | e1 == e2  -> return e1
                                     | otherwise -> share (e1 /\ e2)
        Builtin Or      :@: [e1, e2] | e1 == e2  -> return e1
                                     | otherwise -> share (e1 \/ e2)
        Builtin Implies :@: [e1, e2] -> share (e1 ==> e2)

        -- Equality at function type: apply both sides to fresh locals, one
        -- group per (:=>:) layer of the type, and quantify universally over
        -- all of them. Equalities at other types are left as they are.
        Builtin Equal :@: [e1, e2] ->
          case exprType e1 of
            t@(_ :=>: _) -> hooray $ go t e1 e2 []
              where
              go (args :=>: rest) u v lcls =
                do more <- lift (mapM freshLocal args)
                   go rest (apply u (map Lcl more))
                           (apply v (map Lcl more))
                           (lcls ++ more)
              go _ u v lcls = return (mkQuant Forall lcls (u === v))
            _ -> return e0

        -- Application of a global: if it names a function in the inlinings
        -- table and every argument passes the inlining predicate, inline it
        -- speculatively.
        Gbl gbl@Global{..} :@: ts ->
          case Map.lookup gbl_name inlinings of
            Just func@Function{..}
              | and [ inlineable func_body x t | (x, t) <- zip func_args ts ] -> do
                  -- Simplify the function body first; boo resets the flag so
                  -- that this step does not count as progress.
                  func_body <- boo $ aux func_body
                  -- Substitute the arguments for the parameters and apply
                  -- the type arguments of this use of the global.
                  e1 <-
                    transformTypeInExpr (applyType func_tvs gbl_args) <$>
                      lift (substMany (zip func_args ts) func_body)
                  (e2, Any simplified) <- lift (runWriterT (aux e1))
                  -- Keep the inlined version if simplifying it made progress
                  -- (and it contains no match unless inline_match permits
                  -- that), or if the function body is atomic.
                  if (simplified && (inline_match || not (containsMatch e2))) || atomic func_body
                  then hooray $ return e2
                  else return (Gbl gbl :@: ts)
            _ -> return (Gbl gbl :@: ts)

        -- Nothing applies: leave the expression alone.
        _ -> return e0

    -- Ask the user-supplied predicate whether val may replace var in body,
    -- passing the number of occurrences of var in body.
    inlineable body var val = should_inline (Occurrences (occurrences var body)) mscp val
    -- The scope of the theory, when a theory was given.
    mscp = fmap scope mthy

    -- A singleton group is recursive if the function uses what it defines;
    -- every group that is not a singleton is treated as recursive.
    isRecursiveGroup [fun] = defines fun `elem` uses fun
    isRecursiveGroup _     = True

    -- Functions eligible for inlining, keyed by name: the members of the
    -- non-recursive groups produced by topsort on the theory's functions.
    -- Empty when there is no theory.
    inlinings =
      case mthy of
        Nothing -> Map.empty
        Just Theory{..} ->
          Map.fromList . map (\fun -> (func_name fun, fun)) .
          concat . filter (not . isRecursiveGroup) . topsort $ thy_funcs

    -- Does the expression contain a match anywhere?
    containsMatch e = not (null [ e' | e'@Match{} <- universe e ])

    -- Replace every occurrence of the local old by the result of
    -- freshen new (defined elsewhere), setting the flag for each
    -- replacement made.
    new /// old = transformExprM $ \e ->
      if e == Lcl old then hooray $ lift (freshen new) else return e

    -- Record that a simplification happened, then run the action.
    hooray x = do
      tell (Any True)
      x

    -- Run the action but reset the simplification flag it reports to False.
    boo x = censor (const (Any False)) x

-- | Is the given head a constructor?
--
-- Literals always count. A global counts when a scope is available and the
-- global's name is found by 'lookupConstructor' in it. Every other head
-- (and any global when there is no scope) does not.
isConstructor :: Name a => Maybe (Scope a) -> Head a -> Bool
isConstructor mscp hd = case hd of
  Builtin Lit{} -> True
  Gbl gbl       -> isJust (mscp >>= lookupConstructor (gbl_name gbl))
  _             -> False

-- | Find the single constructor of a datatype that no alternative matches.
--
-- Looks up the datatype named @tc@ in the scope and collects its
-- constructors whose names do not appear among the constructor patterns of
-- @cases@. Returns the datatype together with that constructor when exactly
-- one is unmatched; returns 'Nothing' when there is no scope, the datatype
-- is not found, or the number of unmatched constructors is not one.
missingCase :: Name a => Maybe (Scope a) -> a -> [Case a] -> Maybe (Datatype a, Constructor a)
missingCase mscp tc cases = do
  scp <- mscp
  dt  <- lookupDatatype tc scp
  let covered   = [ gbl_name g | ConPat g _ <- map case_pat cases ]
      uncovered = [ con | con <- data_cons dt, con_name con `notElem` covered ]
  case uncovered of
    [con] -> Just (dt, con)
    _     -> Nothing

-- | Resolve a match whose scrutinee has a known constructor or literal head.
--
-- When the scrutinee is an application whose head satisfies 'isConstructor',
-- the alternatives are searched for one whose pattern fits that head. For a
-- constructor pattern the result is the branch body wrapped in lets binding
-- the pattern variables to the scrutinee's arguments; for any other pattern
-- it is the branch body itself. Returns 'Nothing' if the scrutinee is not of
-- that shape or no alternative fits.
tryMatch :: Name a => Maybe (Scope a) -> Expr a -> [Case a] -> Maybe (Expr a)
tryMatch mscp scrutinee alts = case scrutinee of
  hd :@: args
    | isConstructor mscp hd ->
        -- The default case comes first in the list, so search from the back
        -- to prefer a specific pattern over the default.
        select args <$> find (fits hd . case_pat) (reverse alts)
  _ -> Nothing
  where
    select args (Case (ConPat _ lcls) body) =
      foldr (uncurry Let) body (zip lcls args)
    select _ (Case _ body) = body

    fits (Gbl gbl) (ConPat gbl' _)         = gbl == gbl'
    fits (Builtin (Lit lit)) (LitPat lit') = lit == lit'
    fits _ Default                         = True
    fits _ _                               = False