{-|
  Copyright   :  (C) 2017, Google Inc.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Call-by-need evaluator based on the evaluator described in:

  Maximilian Bolingbroke, Simon Peyton Jones, "Supercompilation by evaluation",
  Haskell '10, Baltimore, Maryland, USA.

-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Core.Evaluator where

import           Prelude                                 hiding (lookup)

import           Control.Concurrent.Supply               (Supply, freshId)
import           Data.Either                             (lefts,rights)
import           Data.List                               (foldl',mapAccumL)
import           Data.Maybe                              (fromMaybe)
import qualified Data.Primitive.ByteArray                as BA
import qualified Data.Text as Text
import qualified Data.Vector.Primitive                   as PV
import           GHC.Integer.GMP.Internals
  (Integer (..), BigNat (..))

import           Clash.Core.DataCon
import           Clash.Core.Evaluator.Types
import           Clash.Core.FreeVars
import           Clash.Core.Literal
import           Clash.Core.Name
import           Clash.Core.Pretty
import           Clash.Core.Subst
import           Clash.Core.Term
import           Clash.Core.TermInfo
import           Clash.Core.TyCon
import           Clash.Core.Type
import           Clash.Core.Util
import           Clash.Core.Var
import           Clash.Core.VarEnv
import           Clash.Debug
import           Clash.Driver.Types                      (BindingMap, Binding(..))
import           Clash.Pretty
import           Clash.Unique
import           Clash.Util                              (curLoc)

-- | Is the value the @Clash.Transformations.undefined@ primitive?
--
-- Only 'PrimVal's are inspected; every other kind of value yields 'False'.
isUndefinedPrimVal :: Value -> Bool
isUndefinedPrimVal = \case
  PrimVal pInfo _ _ -> primName pInfo == "Clash.Transformations.undefined"
  _                 -> False

-- | Evaluate a term to WHNF starting from a fresh 'Machine'.
--
-- The machine is built with the given primitive handlers, primitive heap,
-- name supply and in-scope set, an empty local heap, an empty stack, and a
-- global heap holding the 'bindingTerm' of every entry in the 'BindingMap'.
-- The 'Bool' is forwarded to 'whnf' unchanged.
--
-- Returns the primitive heap, the local heap and the term of the final
-- machine state.
whnf'
  :: PrimStep
  -> PrimUnwind
  -> BindingMap
  -> TyConMap
  -> PrimHeap
  -> Supply
  -> InScopeSet
  -> Bool
  -> Term
  -> (PrimHeap, PureHeap, Term)
whnf' eval fu bm tcm ph ids is isSubj e =
  (mHeapPrim final, mHeapLocal final, mTerm final)
 where
  globals = mapVarEnv bindingTerm bm
  initial = Machine eval fu ph globals emptyVarEnv [] ids is e
  final   = whnf tcm isSubj initial

-- | Evaluate to WHNF given an existing Heap and Stack
--
-- The machine is stepped until 'step' yields 'Nothing', after which the
-- remaining stack is unwound with 'unwindStack'. When the 'Bool' is 'True' an
-- empty 'Scrutinise' frame is pushed before evaluation starts.
whnf
  :: TyConMap
  -> Bool
  -> Machine
  -> Machine
whnf tcm isSubj m
  -- See [Note: empty case expressions]
  | isSubj    = run (stackPush (Scrutinise (termType tcm (mTerm m)) []) m)
  | otherwise = run m
  where
    -- Keep stepping for as long as the machine can make progress.
    run s = maybe (finish s) run (step s tcm)

    -- No more steps possible: rebuild the term from whatever is on the stack.
    finish s = fromMaybe (error . showDoc . ppr $ mTerm m) (unwindStack s)

-- | Completely unwind the stack to get back the complete term
--
-- Frames are popped one at a time and wrapped around (or applied to) the
-- machine's current term until the stack is empty. Yields 'Nothing' only when
-- a frame cannot be popped from a non-empty stack.
unwindStack :: Machine -> Maybe Machine
unwindStack m
  | stackNull m = Just m
  | otherwise   = stackPop m >>= \(m', kf) -> unwindFrame m' kf
 where
  -- Install the rebuilt term and carry on with the rest of the stack.
  resume term = unwindStack . setTerm term

  unwindFrame m' = \case
    PrimApply p tys vs tms ->
      let primTys  = foldl' TyApp (Prim p) tys
          primVals = foldl' App primTys (fmap valToTerm vs)
       in resume (foldl' App primVals (mTerm m' : tms)) m'

    Instantiate ty ->
      resume (TyApp (getTerm m') ty) m'

    Apply n ->
      case heapLookup LocalId n m' of
        Just e ->
          resume (App (getTerm m') e) m'

        Nothing -> error $ unlines $
          [ "Clash.Core.Evaluator.unwindStack:"
          , "Stack:"
          ] <>
          [ "  " <> showDoc (clashPretty frame) | frame <- mStack m] <>
          [ ""
          , "Expression:"
          , showPpr (mTerm m)
          , ""
          , "Heap:"
          , showDoc (clashPretty $ mHeapLocal m)
          ]

    -- A frame without alternatives is dropped instead of rebuilding a 'Case'.
    Scrutinise _ [] ->
      unwindStack m'

    Scrutinise ty alts ->
      resume (Case (getTerm m') ty alts) m'

    -- Write the current term back to the local heap under the updated name.
    Update LocalId x ->
      unwindStack (heapInsert LocalId x (mTerm m') m')

    Update GlobalId _ ->
      unwindStack m'

    Tickish sp ->
      resume (Tick sp (getTerm m')) m'

-- | A single step in the partial evaluator. The result is the new heap and
-- stack, and the next expression to be reduced.
--
-- A step takes the current 'Machine' and the 'TyConMap', and yields 'Just'
-- the next machine state, or 'Nothing' when no step can be taken.
type Step = Machine -> TyConMap -> Maybe Machine

-- | Force a variable. The local heap is consulted first; the global heap is
-- only used when the variable is a global identifier. Yields 'Nothing' when
-- neither lookup applies.
stepVar :: Id -> Step
stepVar i m _ =
  case heapLookup LocalId i m of
    Just e -> force LocalId e
    Nothing
      | Just e <- heapLookup GlobalId i m
      , isGlobalId i -> force GlobalId e
      | otherwise    -> Nothing
 where
  -- Removing the heap-bound value on a force ensures we do not get stuck on
  -- expressions such as: "let x = x in x"
  force scope e =
    let forced = deShadowTerm (mScopeNames m) (tickExpr e)
        m'     = stackPush (Update scope i) (heapDelete scope i m)
     in Just (setTerm forced m')

  -- Tick the bound term with a name prefix: the part of the variable's
  -- occurrence name (with a trailing underscore) after the last dot.
  tickExpr   = Tick (NameMod PrefixName (LitTy (SymTy prefix)))
  prefix     = Text.unpack (unQualName (Text.snoc (nameOcc (varName i)) '_'))
  unQualName = snd . Text.breakOnEnd "."

-- | A bare data constructor is a value with no arguments; hand it to
-- 'unwind'.
stepData :: DataCon -> Step
stepData dc m tcm = unwind tcm m value
 where
  value = DC dc []

-- | A literal is already a value; hand it to 'unwind'.
stepLiteral :: Literal -> Step
stepLiteral l m tcm = unwind tcm m value
 where
  value = Lit l

-- | Step a primitive that has no arguments applied to it.
--
-- * @GHC.Prim.realWorld#@ is treated as a value.
-- * A primitive whose type takes no arguments is handed to the machine's
--   primitive step function.
-- * Any other primitive is wrapped in binders by 'newBinder'.
stepPrim :: PrimInfo -> Step
stepPrim pInfo m tcm
  | primName pInfo == "GHC.Prim.realWorld#"
  = unwind tcm m (PrimVal pInfo [] [])

  | null argTys
  = mPrimStep m tcm (forcePrims m) pInfo [] [] m

  | otherwise
  = newBinder argTys (Prim pInfo) m tcm
 where
  argTys = fst (splitFunForallTy (primType pInfo))

-- | A lambda is already a value; hand it to 'unwind'.
stepLam :: Id -> Term -> Step
stepLam x e m tcm = unwind tcm m value
 where
  value = Lambda x e

-- | A type lambda is already a value; hand it to 'unwind'.
stepTyLam :: TyVar -> Term -> Step
stepTyLam x e m tcm = unwind tcm m value
 where
  value = TyLambda x e

-- | Step a term application.
--
-- The application is flattened with 'collectArgsTicks' and the number of
-- collected arguments is compared against the number of arguments in the type
-- of the head:
--
-- * a saturated data constructor becomes a value;
-- * a saturated primitive has its first term argument evaluated (remaining
--   ones are kept on a 'PrimApply' frame);
-- * an under-applied constructor or primitive is wrapped by 'newBinder';
-- * in every other case the argument is let-bound on the heap and the
--   function position is evaluated first.
stepApp :: Term -> Term -> Step
stepApp x y m tcm =
  case term of
    Data dc ->
      case saturation (dcType dc) of
        EQ -> unwind tcm m (DC dc args)
        LT -> etaExpand
        GT -> error "Overapplied DC"

    Prim p ->
      case saturation (primType p) of
        EQ -> case lefts args of
          -- We make boolean conjunction and disjunction extra lazy by
          -- deferring the evaluation of the arguments during the evaluation
          -- of the primop rule.
          --
          -- This allows us to implement:
          --
          -- x && True  --> x
          -- x && False --> False
          -- x || True  --> True
          -- x || False --> x
          --
          -- even when that 'x' is _|_. This makes the evaluation
          -- rule lazier than the actual Haskell implementations which
          -- are strict in the first argument and lazy in the second.
          [a0, a1] | primName p `elem` ["GHC.Classes.&&","GHC.Classes.||"] ->
            let (m0,i)    = newLetBinding tcm m  a0
                (m1,j)    = newLetBinding tcm m0 a1
                suspended = [Suspend (Var i), Suspend (Var j)]
             in mPrimStep m tcm (forcePrims m) p [] suspended m1

          (e':es) ->
            Just (setTerm e' (stackPush (PrimApply p (rights args) [] es) m))

          _ -> error "internal error"

        LT -> etaExpand
        GT -> deferArg

    _ -> deferArg
 where
  whole           = App x y
  (term, args, _) = collectArgsTicks whole

  -- Compare the number of collected arguments with the number of arguments
  -- the given type takes.
  saturation ty = compare (length args) (length (fst (splitFunForallTy ty)))

  -- Abstract over the arguments still missing from the application.
  etaExpand = newBinder (fst (splitFunForallTy (termType tcm whole))) whole m tcm

  -- Let-bind the argument and evaluate the function position first.
  deferArg =
    let (m0, n) = newLetBinding tcm m y
     in Just (setTerm x (stackPush (Apply n) m0))

-- | Step a type application.
--
-- Mirrors 'stepApp': saturated constructors become values, saturated
-- primitives are either handed to the primitive step function (no term
-- arguments) or have their first term argument evaluated, under-applied heads
-- are wrapped by 'newBinder', and anything else pushes an 'Instantiate' frame
-- and evaluates the function position.
stepTyApp :: Term -> Type -> Step
stepTyApp x ty m tcm =
  case term of
    Data dc ->
      case saturation (dcType dc) of
        EQ -> unwind tcm m (DC dc args)
        LT -> etaExpand
        GT -> error "Overapplied DC"

    Prim p ->
      case saturation (primType p) of
        EQ -> case lefts args of
          -- These two primitives are kept as values rather than evaluated.
          [] | primName p `elem` [ "Clash.Transformations.removedArg"
                                 , "Clash.Transformations.undefined" ] ->
                 unwind tcm m (PrimVal p tyArgs [])

             | otherwise ->
                 mPrimStep m tcm (forcePrims m) p tyArgs [] m

          (e':es) ->
            Just (setTerm e' (stackPush (PrimApply p tyArgs [] es) m))

        LT -> etaExpand
        GT -> instantiateLater

    _ -> instantiateLater
 where
  whole           = TyApp x ty
  (term, args, _) = collectArgsTicks whole
  tyArgs          = rights args

  -- Compare the number of collected arguments with the number of arguments
  -- the given type takes.
  saturation t = compare (length args) (length (fst (splitFunForallTy t)))

  -- Abstract over the arguments still missing from the application.
  etaExpand = newBinder (fst (splitFunForallTy (termType tcm whole))) whole m tcm

  -- Remember the type argument and evaluate the function position first.
  instantiateLater = Just (setTerm x (stackPush (Instantiate ty) m))

-- | Put the let-bindings on the heap (see 'allocate') and continue with the
-- body.
stepLetRec :: [LetBinding] -> Term -> Step
stepLetRec bs x m _tcm = pure (allocate bs x m)

-- | Evaluate the scrutinee first, remembering the alternatives on the stack.
stepCase :: Term -> Type -> [Alt] -> Step
stepCase scrut ty alts m _tcm = Just (setTerm scrut (stackPush frame m))
 where
  frame = Scrutinise ty alts

-- | Casts are not evaluated: trace a warning and yield 'Nothing'. TODO Support stepwise evaluation of casts.
stepCast :: Term -> Type -> Type -> Step
stepCast _term _from _to _machine _tcm = trace warning Nothing
 where
  warning = unlines
    [ "WARNING: " <> $(curLoc) <> "Clash can't symbolically evaluate casts"
    , "Please file an issue at https://github.com/clash-lang/clash-compiler/issues"
    ]

-- | Evaluate the ticked term, remembering the tick on the stack.
stepTick :: TickInfo -> Term -> Step
stepTick tick x m _tcm = Just (setTerm x (stackPush frame m))
 where
  frame = Tickish tick

-- | Small-step operational semantics.
--
-- Dispatches on the shape of the machine's current term to the matching
-- @step*@ function.
step :: Step
step m tcm = case mTerm m of
  Var i        -> stepVar i m tcm
  Data dc      -> stepData dc m tcm
  Literal l    -> stepLiteral l m tcm
  Prim p       -> stepPrim p m tcm
  Lam v x      -> stepLam v x m tcm
  TyLam v x    -> stepTyLam v x m tcm
  App x y      -> stepApp x y m tcm
  TyApp x ty   -> stepTyApp x ty m tcm
  Letrec bs x  -> stepLetRec bs x m tcm
  Case s ty as -> stepCase s ty as m tcm
  Cast x a b   -> stepCast x a b m tcm
  Tick t x     -> stepTick t x m tcm

-- | Take a list of types or type variables and create a lambda / type lambda
-- for each one around the given term, then take a 'step' with the resulting
-- abstraction as the machine's term.
--
-- A 'Left' type variable becomes a type lambda over a type application; a
-- 'Right' type becomes a lambda over an application to a freshly generated
-- binder. The name supply and the in-scope set of the machine are updated to
-- account for every binder that was generated.
--
-- NOTE(review): the abstractions are built with a right fold, so each new
-- binder wraps the previously built abstraction rather than its body —
-- confirm this nesting is intended.
newBinder :: [Either TyVar Type] -> Term -> Step
newBinder tys x m tcm =
  let (s', iss', x') = mkAbstr (mSupply m, mScopeNames m, x) tys
      m' = m { mSupply = s', mScopeNames = iss', mTerm = x' }
   in step m' tcm
 where
  mkAbstr = foldr go
    where
      go (Left tv) (s', iss', e') =
        (s', iss', TyLam tv (TyApp e' (VarTy tv)))

      -- Thread the in-scope set returned by 'mkUniqSystemId' instead of
      -- discarding it, so later binders (and the resulting machine) know
      -- about the binder generated here; 'newLetBinding' does the same.
      go (Right ty) (s', iss', e') =
        let ((s'', iss''), n) = mkUniqSystemId (s', iss') ("x", ty)
        in  (s'', iss'', Lam n (App e' (Var n)))

-- | Make a term available on the local heap and return the identifier it is
-- bound to.
--
-- A variable that is already bound on the local heap is returned as-is,
-- leaving the machine untouched. Any other term is inserted under a freshly
-- generated identifier, and the machine's name supply and in-scope set are
-- updated accordingly.
newLetBinding
  :: TyConMap
  -> Machine
  -> Term
  -> (Machine, Id)
newLetBinding tcm m e = case e of
  Var v | heapContains LocalId v m -> (m, v)
  _ -> (bound { mSupply = ids', mScopeNames = is1 }, id_)
 where
  ((ids', is1), id_) =
    mkUniqSystemId (mSupply m, mScopeNames m) ("x", termType tcm e)
  bound = heapInsert LocalId id_ e m

-- | Unwind the stack by 1
--
-- Pops the top frame and feeds the given value to it. Yields 'Nothing' when
-- no frame can be popped, or when the machine's primitive unwind function
-- yields 'Nothing' for a 'PrimApply' frame.
unwind
  :: TyConMap
  -> Machine
  -> Value
  -> Maybe Machine
unwind tcm m v =
  stackPop m >>= \(m', kf) -> case kf of
    Update s x             -> Just (update s x v m')
    Apply x                -> Just (apply tcm v x m')
    Instantiate ty         -> Just (instantiate tcm v ty m')
    PrimApply p tys vs tms -> mPrimUnwind m tcm p tys vs v tms m'
    Scrutinise altTy as    -> Just (scrutinise v altTy as m')
    -- The tick itself is dropped; only the value is put back.
    Tickish _              -> Just (setTerm (valToTerm v) m')

-- | Update the Heap with the evaluated term
--
-- The value is stored on the heap selected by the 'IdScope' under the given
-- identifier, and also becomes the machine's current term.
update :: IdScope -> Id -> Value -> Machine -> Machine
update s x v = setTerm term . heapInsert s x term
 where
  term = valToTerm v

-- | Apply a value to a function
--
-- * A lambda has its binder substituted by the argument variable.
-- * The undefined primitive (without value arguments) yields an undefined
--   term of the result type.
-- * Any other value is an error.
apply :: TyConMap -> Value -> Id -> Machine -> Machine
apply tcm v x m = case v of
  Lambda x' e ->
    let inScope = extendInScopeSet (mScopeNames m) x
        subst   = extendIdSubst (mkSubst inScope) x' (Var x)
     in setTerm (substTm "Evaluator.apply" subst e) m

  PrimVal (PrimInfo{primType}) tys []
    | isUndefinedPrimVal v ->
        setTerm (undefinedTm (piResultTys tcm primType (tys++[varType x]))) m

  _ -> error "Evaluator.apply: Not a lambda"

-- | Instantiate a type-abstraction
--
-- * A type lambda has its type variable substituted by the given type.
-- * The undefined primitive (without value arguments) yields an undefined
--   term of the result type.
-- * Any other value is an error.
instantiate :: TyConMap -> Value -> Type -> Machine -> Machine
instantiate tcm v ty m = case v of
  TyLambda x e ->
    let fvs   = localFVsOfTerms [e] `unionUniqSet` tyFVsOfTypes [ty]
        subst = extendTvSubst (mkSubst (mkInScopeSet fvs)) x ty
     in setTerm (substTm "Evaluator.instantiate" subst e) m

  PrimVal (PrimInfo{primType}) tys []
    | isUndefinedPrimVal v ->
        setTerm (undefinedTm (piResultTys tcm primType (tys ++ [ty]))) m

  _ -> error "Evaluator.instantiate: Not a tylambda"

-- | Evaluate a case-expression
--
-- Given the evaluated scrutinee, the type of the alternatives and the
-- alternatives themselves, select the matching alternative and make it the
-- machine's current term. Scrutinee / alternative combinations that are not
-- handled by any equation end in an 'error' that pretty-prints the case
-- expression.
scrutinise :: Value -> Type -> [Alt] -> Machine -> Machine
-- No alternatives: the scrutinee itself becomes the current term.
scrutinise v _altTy [] m = setTerm (valToTerm v) m
-- [Note: empty case expressions]
--
-- Clash does not have empty case-expressions; instead, empty case-expressions
-- are used to indicate that the `whnf` function was called in the context of
-- a case-expression, which means certain special primitives must be forced.
-- See also [Note: forcing special primitives]

-- Literal scrutinee. A leading default alternative is used as the fallback
-- while the remaining alternatives are searched; without one, the fallback is
-- a "no match" error.
scrutinise (Lit l) _altTy alts m = case alts of
  (DefaultPat, altE):alts1 -> setTerm (go altE alts1) m
  _ -> let term = go (error $ "Evaluator.scrutinise: no match "
                    <> showPpr (Case (valToTerm (Lit l)) (ConstTy Arrow) alts)) alts
        in setTerm term m
 where
  -- Search the alternatives in order; 'def' is returned when none matches.
  go def [] = def
  go _ ((LitPat l1,altE):_) | l1 == l = altE
  -- A constructor pattern with a single field can match an Integer or Natural
  -- literal: depending on the constructor tag and the magnitude of the
  -- literal, the field is substituted by a machine-word literal or by the
  -- literal's underlying byte array.
  go _ ((DataPat dc [] [x],altE):_)
    | IntegerLiteral l1 <- l
    , Just patE <- case dcTag dc of
       -- Tag 1: the literal fits in 64 signed bits.
       1 | l1 >= ((-2)^(63::Int)) &&  l1 < 2^(63::Int) ->
          Just (IntLiteral l1)
       -- Tag 2: large positive literal, destructured as 'Jp#'.
       2 | l1 >= (2^(63::Int)) ->
          let !(Jp# !(BN# ba0)) = l1
              ba1 = BA.ByteArray ba0
              bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
          in  Just (ByteArrayLiteral bv)
       -- Tag 3: large negative literal, destructured as 'Jn#'.
       3 | l1 < ((-2)^(63::Int)) ->
          let !(Jn# !(BN# ba0)) = l1
              ba1 = BA.ByteArray ba0
              bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
          in  Just (ByteArrayLiteral bv)
       _ -> Nothing
    = let inScope = localFVsOfTerms [altE]
          subst0  = mkSubst (mkInScopeSet inScope)
          subst1  = extendIdSubst subst0 x (Literal patE)
      in  substTm "Evaluator.scrutinise" subst1 altE
    | NaturalLiteral l1  <- l
    , Just patE <- case dcTag dc of
       -- Tag 1: the literal fits in 64 unsigned bits.
       1 | l1 >= 0 &&  l1 < 2^(64::Int) ->
          Just (WordLiteral l1)
       -- Tag 2: large literal, destructured as 'Jp#'.
       2 | l1 >= (2^(64::Int)) ->
          let !(Jp# !(BN# ba0)) = l1
              ba1 = BA.ByteArray ba0
              bv = PV.Vector 0 (BA.sizeofByteArray ba1) ba1
          in  Just (ByteArrayLiteral bv)
       _ -> Nothing
    = let inScope = localFVsOfTerms [altE]
          subst0  = mkSubst (mkInScopeSet inScope)
          subst1  = extendIdSubst subst0 x (Literal patE)
      in  substTm "Evaluator.scrutinise" subst1 altE
  -- Anything else (including a pattern whose guards failed): keep looking.
  go def (_:alts1) = go def alts1

-- Data constructor scrutinee: alternatives for the same constructor take
-- precedence over default alternatives. When neither exists, the guard fails
-- and the equations below are tried.
scrutinise (DC dc xs) _altTy alts m
  | altE:_ <- [substInAlt altDc tvs pxs xs altE
              | (DataPat altDc tvs pxs,altE) <- alts, altDc == dc ] ++
              [altE | (DefaultPat,altE) <- alts ]
  = setTerm altE m

-- Primitive scrutinee.
scrutinise v@(PrimVal p _ vs) altTy alts m
  -- Scrutinising undefined yields undefined at the type of the alternatives.
  | isUndefinedPrimVal v
  = setTerm (undefinedTm altTy) m

  -- With literal patterns present, the literal is recovered from the
  -- arguments of a @fromInteger#@ primitive and matched like a literal.
  | any (\case {(LitPat {},_) -> True; _ -> False}) alts
  = case alts of
      ((DefaultPat,altE):alts1) -> setTerm (go altE alts1) m
      _ -> let term = go (error $ "Evaluator.scrutinise: no match "
                        <> showPpr (Case (valToTerm v) (ConstTy Arrow) alts)) alts
            in setTerm term m
 where
  -- Search the alternatives in order; 'def' is returned when none matches.
  go def [] = def
  go _   ((LitPat l1,altE):_) | l1 == l = altE
  go def (_:alts1) = go def alts1

  -- The literal carried by the primitive; only the sized-type @fromInteger#@
  -- primitives listed here are supported (for BitVector only when the second
  -- argument is the integer literal 0).
  l = case primName p of
        "Clash.Sized.Internal.BitVector.fromInteger#"
          | [_,Lit (IntegerLiteral 0),Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Index.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Signed.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        "Clash.Sized.Internal.Unsigned.fromInteger#"
          | [_,Lit l0] <- vs -> l0
        _ -> error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))

-- Every remaining combination cannot be evaluated.
scrutinise v _ alts _ =
  error ("scrutinise: " ++ showPpr (Case (valToTerm v) (ConstTy Arrow) alts))

-- | Substitute the pattern variables of a constructor alternative by the
-- arguments of the scrutinised constructor.
--
-- The pattern's identifiers are paired with the term arguments. The pattern's
-- type variables are paired with the type arguments that remain after
-- skipping one for each universal type variable of the constructor.
substInAlt :: DataCon -> [TyVar] -> [Id] -> [Either Term Type] -> Term -> Term
substInAlt dc tvs xs args e = substTm "Evaluator.substInAlt" subst e
 where
  tmArgs   = lefts args
  tyArgs   = rights args
  extraTys = drop (length (dcUnivTyVars dc)) tyArgs

  inScope  = mkInScopeSet
               (tyFVsOfTypes tyArgs `unionVarSet` localFVsOfTerms (e:tmArgs))
  withIds  = extendIdSubstList (mkSubst inScope) (zip xs tmArgs)
  subst    = extendTvSubstList withIds (zip tvs extraTys)

-- | Allocate let-bindings on the heap
--
-- Every binder is given a new name via 'letSubst'. The renaming is applied to
-- all bound terms and to the body, the renamed bindings are added to the
-- local heap, and the body becomes the machine's current term. The name
-- supply and in-scope set are updated to cover both the original and the new
-- binders.
allocate :: [LetBinding] -> Term -> Machine -> Machine
allocate xes e m =
  m { mHeapLocal  = extendVarEnvList (mHeapLocal m) renamedBinds
    , mSupply     = supply'
    , mScopeNames = extendInScopeSetList withOld freshIds
    , mTerm       = substTm "Evaluator.allocate1" subst e
    }
 where
  oldIds  = fmap fst xes
  withOld = extendInScopeSetList (mScopeNames m) oldIds

  -- One fresh identifier, plus a substitution entry, per original binder.
  (supply', renamings) =
    mapAccumL (letSubst (mHeapLocal m)) (mSupply m) oldIds
  (freshIds, idSubst)  = unzip renamings

  subst =
    extendIdSubstList (mkSubst (foldl' extendInScopeSet withOld freshIds)) idSubst

  renamedBinds =
    zip freshIds (fmap (substTm "Evaluator.allocate0" subst . snd) xes)

-- | Create a unique name and substitution for a let-binder
--
-- Fresh uniques are drawn from the supply until the renamed binder is not a
-- key of the given heap. The result is the remaining supply, the renamed
-- binder, and a substitution entry mapping the original binder to a variable
-- referring to the renamed one.
letSubst
  :: PureHeap
  -> Supply
  -> Id
  -> (Supply, (Id, (Id, Term)))
letSubst h acc id0 = (acc', (id1, (id0, Var id1)))
 where
  (acc', id1) = freshHeapId acc

  freshHeapId :: Supply -> (Supply, Id)
  freshHeapId supply =
    case lookupVarEnv candidate h of
      Nothing -> (supply', candidate)
      Just _  -> freshHeapId supply'
   where
    (u, supply') = freshId supply
    candidate    = modifyVarName (`setUnique` u) id0