{-# LANGUAGE OverloadedStrings #-}

{-|
  Copyright     : (C) 2020, QBayLogic B.V.
  License       : BSD2 (see the file LICENSE)
  Maintainer    : Christiaan Baaij <christiaan.baaij@gmail.com>

  Types for the Partial Evaluator
-}
module Clash.Core.Evaluator.Types where

import Control.Concurrent.Supply (Supply)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap (insert, lookup)
import Data.List (foldl')
import Data.Maybe (isJust)
import Data.Text.Prettyprint.Doc (hsep)

import Clash.Core.DataCon (DataCon)
import Clash.Core.Literal (Literal(CharLiteral))
import Clash.Core.Pretty (fromPpr)
import Clash.Core.Term (Term(..), PrimInfo(..), TickInfo, Alt)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type)
import Clash.Core.Var (Id, IdScope(..), TyVar)
import Clash.Core.VarEnv
import Clash.Pretty (ClashPretty(..), fromPretty)

{- [Note: forcing special primitives]
Clash uses the `whnf` function in two places (for now):

  1. The case-of-known-constructor transformation
  2. The reduceConstant transformation

The first transformation is needed to reach the required normal form. The
second transformation is more of a cleanup transformation, and is therefore
non-essential.

Normally, `whnf` would force the evaluation of all primitives, which is needed
in the `case-of-known-constructor` transformation. However, there are some
primitives which we want to leave unevaluated in the `reduceConstant`
transformation. Such primitives are:

  - Primitives such as `Clash.Sized.Vector.transpose`, `Clash.Sized.Vector.map`,
    etc. that do not reduce to an expression in normal form, whereas the
    `reduceConstant` transformation is supposed to preserve normal form.
  - Primitives such as `GHC.Int.I8#`, `GHC.Word.W32#`, etc. which seem like
    wrappers around a 64-bit literal, but actually perform truncation to the
    desired bit-size.

This is why the Primitive Evaluator gets a flag telling whether it should
evaluate these special primitives.
-}

-- | Step function used by the primitive evaluator.
--
-- Arguments, in order: the type-constructor map; a 'Bool' flag (presumably
-- the "force special primitives" flag from [Note: forcing special
-- primitives] -- TODO confirm); the primitive; a list of types and a list of
-- values (presumably its type and value arguments); and the current
-- 'Machine'. Yields an updated 'Machine' or 'Nothing'.
type PrimStep =
  TyConMap -> Bool -> PrimInfo -> [Type] -> [Value] -> Machine -> Maybe Machine

-- | Unwind function used by the primitive evaluator.
--
-- Arguments, in order: the type-constructor map; the primitive; a list of
-- types, a list of values, a single value and a list of terms; and the
-- current 'Machine'. The types/values/terms lists mirror the fields of the
-- 'PrimApply' stack frame. The extra 'Value' is presumably the argument that
-- has just been evaluated (TODO confirm). Yields an updated 'Machine' or
-- 'Nothing'.
type PrimUnwind =
     TyConMap -> PrimInfo -> [Type] -> [Value] -> Value -> [Term]
  -> Machine -> Maybe Machine

-- | A primitive evaluator: its step function paired with its unwind function.
type PrimEvaluator =
  ( PrimStep
  , PrimUnwind
  )

-- | Complete state of the evaluator.
--
-- NB: the field order is significant, because the 'Show' instance matches on
-- the constructor positionally.
data Machine = Machine
  { mPrimStep   :: PrimStep
    -- ^ Step function for primitives
  , mPrimUnwind :: PrimUnwind
    -- ^ Unwind function for primitives
  , mHeapPrim   :: PrimHeap
    -- ^ Int-indexed heap of primitive terms, plus a counter
  , mHeapGlobal :: PureHeap
    -- ^ Heap for bindings looked up with 'GlobalId' scope
  , mHeapLocal  :: PureHeap
    -- ^ Heap for bindings looked up with 'LocalId' scope
  , mStack      :: Stack
    -- ^ Stack of pending frames, top of the stack first
  , mSupply     :: Supply
    -- ^ Supply (presumably of fresh uniques -- TODO confirm)
  , mScopeNames :: InScopeSet
    -- ^ Set of names currently in scope
  , mTerm       :: Term
    -- ^ The term currently under evaluation
  }

-- | Multi-line debug dump of a machine: the three heaps, the stack (rendered
-- through 'clashPretty') and the current term, each under its own header
-- and separated by blank lines.
instance Show Machine where
  -- Matching on 'Machine{}' forces the constructor before any output is
  -- produced.
  show m@Machine{} = unlines ("Machine:" : concatMap section sections)
   where
    section :: (String, String) -> [String]
    section (header, body) = ["", header, body]

    sections :: [(String, String)]
    sections =
      [ ("Heap (Prim):",    show (mHeapPrim m))
      , ("Heap (Globals):", show (mHeapGlobal m))
      , ("Heap (Locals):",  show (mHeapLocal m))
      , ("Stack:",          show (fmap clashPretty (mStack m)))
      , ("Term:",           show (mTerm m))
      ]

-- | Heap of primitive terms keyed by 'Int', paired with a counter.
-- 'primInsert' bumps the counter; 'primUpdate' leaves it unchanged.
type PrimHeap =
  ( IntMap Term
  , Int
  )

-- | Heap mapping variables to the terms bound to them.
type PureHeap =
  VarEnv Term

-- | The evaluator's stack; the head of the list is the top frame.
type Stack =
  [StackFrame]

-- | A single frame on the evaluator's 'Stack'.
data StackFrame
  = Update IdScope Id
    -- ^ Update frame for an identifier in the given heap scope
  | Apply Id
    -- ^ Pending application to the given identifier
  | Instantiate Type
    -- ^ Pending type application
  | PrimApply PrimInfo [Type] [Value] [Term]
    -- ^ Primitive application: the primitive, its type arguments, the value
    -- arguments and the term arguments
  | Scrutinise Type [Alt]
    -- ^ Pending case expression: its type and alternatives
  | Tickish TickInfo
    -- ^ Tick to re-attach
  deriving (Show)

-- | Single-line, human-readable rendering of a stack frame, used for
-- debugging (e.g. the "Stack:" section of the 'Show' instance of 'Machine').
instance ClashPretty StackFrame where
  clashPretty (Update GlobalId i) = hsep ["Update(Global)", fromPpr i]
  clashPretty (Update LocalId i)  = hsep ["Update(Local)", fromPpr i]
  clashPretty (Apply i)           = hsep ["Apply", fromPpr i]
  clashPretty (Instantiate t)     = hsep ["Instantiate", fromPpr t]
  clashPretty (PrimApply p tys vs ts) =
    -- Every argument group gets the same "; " separator, so the three groups
    -- are delimited consistently.
    hsep [ "PrimApply", fromPretty (primName p), "::", fromPpr (primType p)
         , "; type args=", fromPpr tys
         , "; val args=", fromPpr (map valToTerm vs)
         , "; term args=", fromPpr ts
         ]
  clashPretty (Scrutinise ty alts) =
    -- The frame does not store the scrutinee, so a '_' character literal
    -- stands in for it. 'hsep' already inserts single spaces between the
    -- documents, so the label carries no trailing space.
    hsep [ "Scrutinise", fromPpr ty
         , fromPpr (Case (Literal (CharLiteral '_')) ty alts)
         ]
  clashPretty (Tickish sp) = hsep ["Tick", fromPpr sp]

-- | Values produced by the evaluator. 'valToTerm' converts them back into
-- terms.
data Value
  = Lambda Id Term
    -- ^ Term-level function (lambda abstraction)
  | TyLambda TyVar Term
    -- ^ Type abstraction
  | DC DataCon [Either Term Type]
    -- ^ Data constructor with its arguments: 'Left' for term arguments,
    -- 'Right' for type arguments
  | Lit Literal
    -- ^ Literal
  | PrimVal PrimInfo [Type] [Value]
    -- ^ Clash's number types are represented by their "fromInteger#"
    -- primitive function, so some primitive applications are values.
  | Suspend Term
    -- ^ A term kept as-is; used by lazy primitives
  | TickValue TickInfo Value
    -- ^ Preserves ticks from terms in values
  | CastValue Value Type Type
    -- ^ Preserves casts from terms in values; both types are handed to
    -- 'Cast' in the same order
  deriving (Show)

-- | Convert a 'Value' back into the 'Term' it represents.
--
-- Constructor and primitive values are rebuilt as left-nested applications.
-- Type arguments are applied via 'TyApp' and term arguments via 'App'. Ticks
-- and casts are re-attached as 'Tick' and 'Cast' nodes.
valToTerm :: Value -> Term
valToTerm (Lambda x body)          = Lam x body
valToTerm (TyLambda tv body)       = TyLam tv body
valToTerm (DC dc args)             =
  foldl' (\acc -> either (App acc) (TyApp acc)) (Data dc) args
valToTerm (Lit l)                  = Literal l
valToTerm (PrimVal p tys vals)     = foldl' App primWithTys (map valToTerm vals)
 where
  -- All type arguments are applied before any value argument.
  primWithTys = foldl' TyApp (Prim p) tys
valToTerm (Suspend e)              = e
valToTerm (TickValue t inner)      = Tick t (valToTerm inner)
valToTerm (CastValue inner t1 t2)  = Cast (valToTerm inner) t1 t2

-- | Strip every 'TickValue' wrapper off a value. Returns the first non-tick
-- value together with the removed ticks.
--
-- Each tick is consed onto an accumulator as it is peeled off, so the
-- resulting list is innermost tick first and outermost tick last.
collectValueTicks
  :: Value
  -> (Value, [TickInfo])
collectValueTicks = peel []
 where
  peel :: [TickInfo] -> Value -> (Value, [TickInfo])
  peel acc val = case val of
    TickValue t inner -> peel (t : acc) inner
    _                 -> (val, acc)

-- | Are we in a context where special primitives must be forced?
--
-- Tick frames are skipped. The answer is 'True' exactly when the first
-- non-tick frame on the stack is a 'Scrutinise' or a 'PrimApply' frame.
--
-- See [Note: forcing special primitives]
forcePrims :: Machine -> Bool
forcePrims m = case dropWhile isTick (mStack m) of
  Scrutinise{} : _ -> True
  PrimApply{}  : _ -> True
  _                -> False
 where
  isTick :: StackFrame -> Bool
  isTick Tickish{} = True
  isTick _         = False

-- | The counter stored alongside the primitive heap.
primCount :: Machine -> Int
primCount Machine{mHeapPrim = (_, count)} = count

-- | Look up the primitive-heap entry at the given index.
primLookup :: Int -> Machine -> Maybe Term
primLookup i m = IntMap.lookup i prims
 where
  (prims, _) = mHeapPrim m

-- | Store a term at the given index of the primitive heap and increment the
-- heap's counter. The counter is incremented even if the index was already
-- occupied.
primInsert :: Int -> Term -> Machine -> Machine
primInsert i x m = m { mHeapPrim = (IntMap.insert i x prims, count + 1) }
 where
  -- A lazy binding, so the heap tuple is only forced when it is demanded.
  (prims, count) = mHeapPrim m

-- | Overwrite (or add) the term at the given index of the primitive heap.
-- Unlike 'primInsert', the heap's counter is left unchanged.
primUpdate :: Int -> Term -> Machine -> Machine
primUpdate i x m = m { mHeapPrim = (IntMap.insert i x prims, count) }
 where
  (prims, count) = mHeapPrim m

-- | Look up an identifier in the heap selected by the scope:
-- 'mHeapGlobal' for 'GlobalId', 'mHeapLocal' for 'LocalId'.
heapLookup :: IdScope -> Id -> Machine -> Maybe Term
heapLookup scope i m = lookupVarEnv i (heapOf scope)
 where
  heapOf GlobalId = mHeapGlobal m
  heapOf LocalId  = mHeapLocal m

-- | Does the heap selected by the scope hold a binding for the identifier?
heapContains :: IdScope -> Id -> Machine -> Bool
heapContains scope i m = isJust (heapLookup scope i m)

-- | Bind the identifier to the term in the heap selected by the scope,
-- replacing any existing binding via 'extendVarEnv'.
heapInsert :: IdScope -> Id -> Term -> Machine -> Machine
heapInsert scope i x m = case scope of
  GlobalId -> m { mHeapGlobal = extend (mHeapGlobal m) }
  LocalId  -> m { mHeapLocal  = extend (mHeapLocal m) }
 where
  extend env = extendVarEnv i x env

-- | Remove the identifier's binding from the heap selected by the scope.
heapDelete :: IdScope -> Id -> Machine -> Machine
heapDelete scope i m = case scope of
  GlobalId -> m { mHeapGlobal = remove (mHeapGlobal m) }
  LocalId  -> m { mHeapLocal  = remove (mHeapLocal m) }
 where
  remove env = delVarEnv env i

-- | Push a frame onto the top of the stack.
stackPush :: StackFrame -> Machine -> Machine
stackPush frame m@Machine{mStack = frames} = m { mStack = frame : frames }

-- | Pop the top frame. Returns the machine with the remaining stack together
-- with the popped frame, or 'Nothing' when the stack is empty.
stackPop :: Machine -> Maybe (Machine, StackFrame)
stackPop m@Machine{mStack = frames} = case frames of
  top : rest -> Just (m { mStack = rest }, top)
  []         -> Nothing

-- | Drop every frame from the stack.
stackClear :: Machine -> Machine
stackClear machine = machine { mStack = [] }

-- | Is the stack empty?
stackNull :: Machine -> Bool
stackNull Machine{mStack = []} = True
stackNull _                    = False

-- | The term currently held by the machine.
getTerm :: Machine -> Term
getTerm Machine{mTerm = t} = t

-- | Replace the term currently held by the machine.
setTerm :: Term -> Machine -> Machine
setTerm newTerm machine = machine { mTerm = newTerm }