module Ivory.Eval
( runEval
, runEvalStartingFrom
, Error
, Eval
, EvalState(EvalState)
, Value(..)
, initState
, openModule
, evalAssert
, evalBlock
, evalCond
, evalDeref
, evalExpr
, evalInit
, evalLit
, evalOp
, evalRequires
, evalStmt
) where
import Prelude ()
import qualified Prelude.Compat as Prelude
import Prelude.Compat hiding (negate, div, mod, not, and, or)
import Control.Monad (unless,foldM,void)
import Data.Int
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Sequence as Seq
import Data.Word
import MonadLib
(ExceptionM(..),runId,Id,ExceptionT,runExceptionT,sets_
,StateT,runStateT,StateM(..))
import Ivory.Language.Syntax.Concrete.Location
import qualified Ivory.Language.Array as I
import qualified Ivory.Language.Syntax as I
-- | Evaluation failures are reported as plain message strings.
type Error = String
-- | The evaluator monad: an 'EvalState' threaded over an exception layer
-- that carries an 'Error'.
type Eval a = StateT EvalState (ExceptionT Error Id) a
-- | Run an evaluation starting from an empty store and return only its
-- result (the final state is dropped), or the error it raised.
runEval :: Eval a -> Either Error a
runEval doThis = fst <$> runEvalStartingFrom emptyState doThis
  where
    emptyState = initState Map.empty
-- | Run an evaluation from the given state, returning the result paired
-- with the final state, or the error it raised.
runEvalStartingFrom :: EvalState -> Eval a -> Either Error (a, EvalState)
runEvalStartingFrom st = runId . runExceptionT . runStateT st
-- | State threaded through an evaluation.
data EvalState = EvalState
  { store :: Map.Map I.Sym Value
    -- ^ value currently bound to each symbol
  , loc :: SrcLoc
    -- ^ last source location recorded via 'updateLoc'
  , structs :: Map.Map I.Sym I.Struct
    -- ^ struct definitions in scope, keyed by struct name
  } deriving Show
-- | Build an initial state from a store: no source location and no
-- struct definitions.
initState :: Map.Map I.Sym Value -> EvalState
initState st = EvalState
  { store   = st
  , loc     = NoLoc
  , structs = Map.empty
  }
-- | Run an action with the module's public and private struct definitions
-- added to the struct table, then restore the previous table.
--
-- Entries already in the table win over the module's on a name clash
-- ('Map.union' is left-biased).
openModule :: I.Module -> Eval a -> Eval a
openModule (I.Module {..}) doThis = do
  saved <- structs <$> get
  sets_ (\s -> s { structs = structs s `Map.union` declared })
  result <- doThis
  sets_ (\s -> s { structs = saved })
  return result
  where
    declared = Map.fromList
      (mapMaybe keyed (I.public modStructs ++ I.private modStructs))
    -- Only definitions built with the 'I.Struct' constructor are registered.
    keyed def = case def of
      I.Struct sym _ -> Just (sym, def)
      _              -> Nothing
-- | Runtime values manipulated by the evaluator.
data Value
  = Sint8 Int8
  | Sint16 Int16
  | Sint32 Int32
  | Sint64 Int64
  | Uint8 Word8
  | Uint16 Word16
  | Uint32 Word32
  | Uint64 Word64
  | Float Float
  | Double Double
  | Char Char
  | String String
  | Bool Bool
  | Array (Seq.Seq Value) -- ^ array elements in index order
  | Struct (Map.Map I.Sym Value) -- ^ field values keyed by field name
  | Ref I.Sym -- ^ reference to another symbol in the store
  deriving (Show, Eq, Ord)
-- | Structural equality on values, wrapped as a 'Bool' value.
eq :: Value -> Value -> Value
eq lhs rhs = Bool $ lhs == rhs
-- | Structural inequality on values, wrapped as a 'Bool' value.
neq :: Value -> Value -> Value
neq lhs rhs = Bool $ lhs /= rhs
-- | Logical negation; any non-'Bool' operand aborts with 'error'.
not :: Value -> Value
not val = case val of
  Bool b -> Bool (Prelude.not b)
  _      -> error $ "invalid operands to `not`: " ++ show val
-- | Logical conjunction; non-'Bool' operands abort with 'error'.
and :: Value -> Value -> Value
and lhs rhs = case (lhs, rhs) of
  (Bool a, Bool b) -> Bool (a && b)
  _                -> error $ "invalid operands to `and`: " ++ show (lhs, rhs)
-- | Logical disjunction; non-'Bool' operands abort with 'error'.
or :: Value -> Value -> Value
or (Bool x) (Bool y) = Bool (x || y)
-- Message wording matches `not` and `and` ("operands", not "operors").
or x y = error $ "invalid operands to `or`: " ++ show (x,y)
-- | Lift a polymorphic ordering predicate to values.
--
-- Both operands must use the same numeric or 'Char' constructor; any other
-- pairing aborts with 'error'.
ordOp :: (forall a. Ord a => a -> a -> Bool) -> Value -> Value -> Value
ordOp op lhs rhs = case (lhs, rhs) of
  (Sint8 a,  Sint8 b)  -> Bool (op a b)
  (Sint16 a, Sint16 b) -> Bool (op a b)
  (Sint32 a, Sint32 b) -> Bool (op a b)
  (Sint64 a, Sint64 b) -> Bool (op a b)
  (Uint8 a,  Uint8 b)  -> Bool (op a b)
  (Uint16 a, Uint16 b) -> Bool (op a b)
  (Uint32 a, Uint32 b) -> Bool (op a b)
  (Uint64 a, Uint64 b) -> Bool (op a b)
  (Float a,  Float b)  -> Bool (op a b)
  (Double a, Double b) -> Bool (op a b)
  (Char a,   Char b)   -> Bool (op a b)
  _ -> error $ "invalid operands to `ordOp`: " ++ show (lhs, rhs)
-- | Strictly greater than, via 'ordOp'.
gt :: Value -> Value -> Value
gt lhs rhs = ordOp (>) lhs rhs
-- | Greater than or equal, via 'ordOp'.
gte :: Value -> Value -> Value
gte lhs rhs = ordOp (>=) lhs rhs
-- | Strictly less than, via 'ordOp'.
lt :: Value -> Value -> Value
lt lhs rhs = ordOp (<) lhs rhs
-- | Less than or equal, via 'ordOp'.
lte :: Value -> Value -> Value
lte lhs rhs = ordOp (<=) lhs rhs
-- | Lift a polymorphic unary numeric function to values, preserving the
-- constructor. Non-numeric values abort with 'error'.
numUnOp :: (forall a. Num a => a -> a) -> Value -> Value
numUnOp op val = case val of
  Sint8 a  -> Sint8 (op a)
  Sint16 a -> Sint16 (op a)
  Sint32 a -> Sint32 (op a)
  Sint64 a -> Sint64 (op a)
  Uint8 a  -> Uint8 (op a)
  Uint16 a -> Uint16 (op a)
  Uint32 a -> Uint32 (op a)
  Uint64 a -> Uint64 (op a)
  Float a  -> Float (op a)
  Double a -> Double (op a)
  _ -> error $ "invalid operands to `numUnOp`: " ++ show val
-- | Arithmetic negation of a numeric value, via 'numUnOp'.
negate :: Value -> Value
negate val = numUnOp Prelude.negate val
-- | Lift a polymorphic binary numeric function to values.
--
-- Both operands must use the same numeric constructor, which the result
-- keeps; any other pairing aborts with 'error'.
numBinOp :: (forall a. Num a => a -> a -> a) -> Value -> Value -> Value
numBinOp op lhs rhs = case (lhs, rhs) of
  (Sint8 a,  Sint8 b)  -> Sint8 (op a b)
  (Sint16 a, Sint16 b) -> Sint16 (op a b)
  (Sint32 a, Sint32 b) -> Sint32 (op a b)
  (Sint64 a, Sint64 b) -> Sint64 (op a b)
  (Uint8 a,  Uint8 b)  -> Uint8 (op a b)
  (Uint16 a, Uint16 b) -> Uint16 (op a b)
  (Uint32 a, Uint32 b) -> Uint32 (op a b)
  (Uint64 a, Uint64 b) -> Uint64 (op a b)
  (Float a,  Float b)  -> Float (op a b)
  (Double a, Double b) -> Double (op a b)
  _ -> error $ "invalid operands to `numBinOp`: " ++ show (lhs, rhs)
-- | Addition of two same-typed numeric values, via 'numBinOp'.
add :: Value -> Value -> Value
add lhs rhs = numBinOp (+) lhs rhs
-- | Subtraction of two same-typed numeric values, via 'numBinOp'.
sub :: Value -> Value -> Value
-- The operator section must be @(-)@; a bare @()@ is the unit value and
-- does not typecheck as a numeric function.
sub = numBinOp (-)
-- | Multiplication of two same-typed numeric values, via 'numBinOp'.
mul :: Value -> Value -> Value
mul lhs rhs = numBinOp (*) lhs rhs
-- | Division of two same-typed numeric values.
--
-- Integer constructors use 'Prelude.div'; 'Float' and 'Double' use '/'.
-- Mismatched or non-numeric operands abort with 'error'.
--
-- NOTE(review): 'Prelude.div' rounds toward negative infinity; confirm that
-- this is the rounding intended for negative operands.
div :: Value -> Value -> Value
div lhs rhs = case (lhs, rhs) of
  (Sint8 a,  Sint8 b)  -> Sint8 (Prelude.div a b)
  (Sint16 a, Sint16 b) -> Sint16 (Prelude.div a b)
  (Sint32 a, Sint32 b) -> Sint32 (Prelude.div a b)
  (Sint64 a, Sint64 b) -> Sint64 (Prelude.div a b)
  (Uint8 a,  Uint8 b)  -> Uint8 (Prelude.div a b)
  (Uint16 a, Uint16 b) -> Uint16 (Prelude.div a b)
  (Uint32 a, Uint32 b) -> Uint32 (Prelude.div a b)
  (Uint64 a, Uint64 b) -> Uint64 (Prelude.div a b)
  (Float a,  Float b)  -> Float (a / b)
  (Double a, Double b) -> Double (a / b)
  _ -> error $ "invalid operands to `div`: " ++ show (lhs, rhs)
-- | Modulus of two same-typed integer values, using 'Prelude.mod'.
--
-- Floating-point, mismatched or non-numeric operands abort with 'error'.
mod :: Value -> Value -> Value
mod lhs rhs = case (lhs, rhs) of
  (Sint8 a,  Sint8 b)  -> Sint8 (Prelude.mod a b)
  (Sint16 a, Sint16 b) -> Sint16 (Prelude.mod a b)
  (Sint32 a, Sint32 b) -> Sint32 (Prelude.mod a b)
  (Sint64 a, Sint64 b) -> Sint64 (Prelude.mod a b)
  (Uint8 a,  Uint8 b)  -> Uint8 (Prelude.mod a b)
  (Uint16 a, Uint16 b) -> Uint16 (Prelude.mod a b)
  (Uint32 a, Uint32 b) -> Uint32 (Prelude.mod a b)
  (Uint64 a, Uint64 b) -> Uint64 (Prelude.mod a b)
  _ -> error $ "invalid operands to `mod`: " ++ show (lhs, rhs)
-- | Look up the value bound to a symbol, raising an error if it is unbound.
readStore :: I.Sym -> Eval Value
readStore sym = do
  bindings <- store <$> get
  maybe unbound return (Map.lookup sym bindings)
  where
    unbound = raise $ "Unbound variable: `" ++ sym ++ "'!"
-- | Bind a symbol to a value in the store, replacing any existing binding.
writeStore :: I.Sym -> Value -> Eval ()
writeStore sym val = sets_ bind
  where
    bind s = s { store = Map.insert sym val (store s) }
-- | Apply a function to the value bound to a symbol; an unbound symbol
-- leaves the store unchanged.
modifyStore :: I.Sym -> (Value -> Value) -> Eval ()
modifyStore sym f = sets_ adjustBinding
  where
    adjustBinding s = s { store = Map.adjust f sym (store s) }
-- | Record the given source location in the evaluator state.
updateLoc :: SrcLoc -> Eval ()
updateLoc newLoc = sets_ setLoc
  where
    setLoc s = s { loc = newLoc }
-- | Find a struct definition by name in the current struct table, raising
-- an error if it is absent.
lookupStruct :: String -> Eval I.Struct
lookupStruct str = do
  known <- structs <$> get
  maybe missing return (Map.lookup str known)
  where
    missing = raise $ "Couldn't find struct: " ++ str
-- | Evaluate every statement of a block in order.
evalBlock :: I.Block -> Eval ()
evalBlock stmts = mapM_ evalStmt stmts
-- | Evaluate all preconditions; the result is 'True' only if every one holds.
--
-- Every condition is evaluated (no short-circuiting), so an error raised by
-- any of them propagates.
evalRequires :: [I.Require] -> Eval Bool
evalRequires reqs = do
  outcomes <- mapM (evalCond . I.getRequire) reqs
  return (Prelude.and outcomes)
-- | Evaluate a condition to a Haskell 'Bool'.
--
-- A 'I.CondDeref' first performs the dereference as a statement (binding
-- its variable in the store) and then evaluates the nested condition.
evalCond :: I.Cond -> Eval Bool
evalCond (I.CondBool expr) = do
  val <- evalExpr I.TyBool expr
  case val of
    Bool b -> return b
    _      -> raise $ "evalCond: expected boolean, got: " ++ show val
evalCond (I.CondDeref ty expr var inner) =
  evalStmt (I.Deref ty var expr) >> evalCond inner
-- | Evaluate a single statement, updating the store as a side effect.
--
-- Failures are reported through the 'Eval' error channel with 'raise', so
-- that 'runEval' can return them as a 'Left': this covers unsupported
-- statements and a 'I.Store' through a variable that does not hold a 'Ref'.
evalStmt :: I.Stmt -> Eval ()
evalStmt stmt = case stmt of
  I.Comment (I.SourcePos srcLoc)
    -> updateLoc srcLoc
  -- Any other kind of comment has no runtime effect.
  I.Comment _
    -> return ()
  I.Assert expr
    -> evalAssert expr
  I.CompilerAssert expr
    -> evalAssert expr
  I.IfTE cond true false
    -> do b <- evalExpr I.TyBool cond
          case b of
            Bool True  -> evalBlock true
            Bool False -> evalBlock false
            _ -> raise $ "evalStmt: IfTE: expected true or false, got: " ++ show b
  I.Deref ty var expr
    -> do val <- evalDeref ty expr
          case val of
            -- A dereference that yields another reference is followed once more.
            Ref ref -> readStore ref >>= writeStore (varSym var)
            _       -> writeStore (varSym var) val
  I.Assign ty var expr
    -> do val <- evalExpr ty expr
          writeStore (varSym var) val
  I.Local ty var init
    -> do val <- evalInit ty init
          writeStore (varSym var) val
  I.AllocRef _ty var ref
    -> writeStore (varSym var) (Ref $ nameSym ref)
  I.Loop _ var expr incr body
    -> do start <- evalExpr I.ixRep expr
          writeStore (varSym var) start
          -- The loop stops as soon as the counter reaches or passes the bound,
          -- so the body is not run for the bound value itself.
          -- NOTE(review): confirm this exclusive bound is the intended semantics.
          let (checkDone, stepFn, doneExpr) = case incr of
                I.IncrTo bound -> ((>=), (`add` Sint32 1), bound)
                I.DecrTo bound -> ((<=), (`sub` Sint32 1), bound)
          let step = modifyStore (varSym var) stepFn
          -- The bound expression is re-evaluated before every iteration.
          let done = do curVal  <- readStore (varSym var)
                        doneVal <- evalExpr I.ixRep doneExpr
                        return (checkDone curVal doneVal)
          untilM done (evalBlock body >> step)
  -- The returned expression is evaluated and its value discarded.
  I.Return (I.Typed ty expr)
    -> void $ evalExpr ty expr
  I.Store ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr ty expr
          writeStore dst val
  I.Store ty (I.ExpVar dst) expr
    -> do val <- evalExpr ty expr
          target <- readStore (varSym dst)
          case target of
            Ref ref -> writeStore ref val
            _ -> raise $ "evalStmt: Store: expected Ref, got: " ++ show target
  I.RefCopy ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr ty expr
          writeStore dst val
  _ -> raise $ "evalStmt: can't handle: " ++ show stmt
-- | Evaluate the target of a dereference.
--
-- Symbols and variables are read through the reference they hold; globals
-- are read directly; array indexing and struct field selection recurse into
-- the dereferenced aggregate.
--
-- Errors are raised in 'Eval' rather than crashing. This covers a
-- non-array or non-struct aggregate, a non-'Sint32' index, an out-of-range
-- index, a missing field and an unsupported expression form.
evalDeref :: I.Type -> I.Expr -> Eval Value
evalDeref _ty expr = case expr of
  I.ExpSym sym          -> readRef sym
  I.ExpVar var          -> readRef (varSym var)
  I.ExpAddrOfGlobal sym -> readStore sym
  I.ExpIndex tarr arr tidx idx
    -> do arrVal <- evalDeref tarr arr
          case arrVal of
            Array elems -> do
              idxVal <- evalExpr tidx idx
              case idxVal of
                Sint32 i
                  | i >= 0 && fromIntegral i < Seq.length elems
                    -> return (elems `Seq.index` fromIntegral i)
                  | otherwise
                    -> raise $ "evalDeref: index out of bounds: " ++ show i
                _ -> raise $ "evalDeref: expected Sint32 index, got: " ++ show idxVal
            _ -> raise $ "evalDeref: expected Array, got: " ++ show arrVal
  I.ExpLabel tstr str lab
    -> do strVal <- evalDeref tstr str
          case strVal of
            Struct fields ->
              maybe (raise $ "evalDeref: no such field: " ++ show lab)
                    return
                    (Map.lookup lab fields)
            _ -> raise $ "evalDeref: expected Struct, got: " ++ show strVal
  _ -> raise $ "evalDeref: can't handle: " ++ show expr
-- | Read the value a symbol refers to: the symbol must be bound to a 'Ref',
-- whose target is then read from the store.
readRef :: I.Sym -> Eval Value
readRef sym = readStore sym >>= follow
  where
    follow (Ref ref) = readStore ref
    follow val       = raise $ "Expected Ref, got: " ++ show val
-- | Evaluate an assertion, raising an error unless it evaluates to
-- @Bool True@.
evalAssert :: I.Expr -> Eval ()
evalAssert asrt = evalExpr I.TyBool asrt >>= check
  where
    check (Bool True) = return ()
    check b           = raise $ "Assertion failed: " ++ show (b,asrt)
-- | Evaluate an expression at the given result type.
--
-- Unsupported expression forms are reported with 'raise', so that 'runEval'
-- can return them as a 'Left' instead of crashing.
evalExpr :: I.Type -> I.Expr -> Eval Value
evalExpr ty expr = case expr of
  I.ExpSym sym -> readStore sym
  I.ExpVar var -> readStore (varSym var)
  I.ExpLit lit -> evalLit ty lit
  I.ExpOp op exprs
    -> do -- These operators carry their own operand type, which is used for
          -- the operands instead of the result type `ty`.
          let opTy = case op of
                I.ExpEq t    -> t
                I.ExpNeq t   -> t
                I.ExpGt _ t  -> t
                I.ExpLt _ t  -> t
                I.ExpIsNan t -> t
                I.ExpIsInf t -> t
                _            -> ty
          vals <- mapM (evalExpr opTy) exprs
          evalOp op vals
  I.ExpSafeCast fromTy inner
    -> do val <- evalExpr fromTy inner
          return (cast ty fromTy val)
  -- The index value is reduced modulo the given bound.
  I.ExpToIx inner bound
    -> fmap (`mod` Sint32 (fromInteger bound)) (evalExpr I.ixRep inner)
  _ -> raise $ "evalExpr: can't handle: " ++ show expr
-- | Convert an integer value to the target type by going through 'Integer'.
--
-- The source type argument is ignored. Any value other than the eight
-- signed/unsigned integer constructors aborts with 'error'.
cast :: I.Type -> I.Type -> Value -> Value
cast toTy _fromTy val = mkVal toTy (widen val)
  where
    widen v = case v of
      Sint8 i  -> toInteger i
      Sint16 i -> toInteger i
      Sint32 i -> toInteger i
      Sint64 i -> toInteger i
      Uint8 i  -> toInteger i
      Uint16 i -> toInteger i
      Uint32 i -> toInteger i
      Uint64 i -> toInteger i
      _        -> error $ "Expected number, got: " ++ show v
-- | Evaluate a literal. Integer literals are built at the requested type
-- via 'mkVal'; the other supported literals map directly to their 'Value'
-- constructor. Any remaining literal form raises an error.
evalLit :: I.Type -> I.Literal -> Eval Value
evalLit ty (I.LitInteger i) = return (mkVal ty i)
evalLit _  (I.LitFloat c)   = return (Float c)
evalLit _  (I.LitDouble c)  = return (Double c)
evalLit _  (I.LitChar c)    = return (Char c)
evalLit _  (I.LitBool b)    = return (Bool b)
evalLit _  (I.LitString s)  = return (String s)
evalLit _  lit              = raise $ "evalLit: can't handle: " ++ show lit
-- | Build a value of the given type from an 'Integer'.
--
-- Index types are represented as 'Sint32'. Booleans and characters go
-- through 'toEnum'. Types not listed here abort with 'error'.
mkVal :: I.Type -> Integer -> Value
mkVal ty n = case ty of
  I.TyInt I.Int8     -> Sint8 (fromInteger n)
  I.TyInt I.Int16    -> Sint16 (fromInteger n)
  I.TyInt I.Int32    -> Sint32 (fromInteger n)
  I.TyInt I.Int64    -> Sint64 (fromInteger n)
  I.TyWord I.Word8   -> Uint8 (fromInteger n)
  I.TyWord I.Word16  -> Uint16 (fromInteger n)
  I.TyWord I.Word32  -> Uint32 (fromInteger n)
  I.TyWord I.Word64  -> Uint64 (fromInteger n)
  I.TyIndex _        -> Sint32 (fromInteger n)
  I.TyFloat          -> Float (fromInteger n)
  I.TyDouble         -> Double (fromInteger n)
  I.TyBool           -> Bool (toEnum (fromInteger n))
  I.TyChar           -> Char (toEnum (fromInteger n))
  _                  -> error $ "mkVal: " ++ show ty
-- | Apply an operator to its already-evaluated operands.
--
-- For 'I.ExpGt' and 'I.ExpLt' a 'True' flag selects the non-strict
-- comparison (@>=@ / @<=@). An unknown operator or wrong operand count
-- raises an error.
evalOp :: I.ExpOp -> [Value] -> Eval Value
evalOp op args = case (op, args) of
  (I.ExpEq _,       [x, y]) -> return (x `eq` y)
  (I.ExpNeq _,      [x, y]) -> return (x `neq` y)
  (I.ExpGt True _,  [x, y]) -> return (x `gte` y)
  (I.ExpGt False _, [x, y]) -> return (x `gt` y)
  (I.ExpLt True _,  [x, y]) -> return (x `lte` y)
  (I.ExpLt False _, [x, y]) -> return (x `lt` y)
  (I.ExpAdd,        [x, y]) -> return (x `add` y)
  (I.ExpSub,        [x, y]) -> return (x `sub` y)
  (I.ExpMul,        [x, y]) -> return (x `mul` y)
  (I.ExpDiv,        [x, y]) -> return (x `div` y)
  (I.ExpMod,        [x, y]) -> return (x `mod` y)
  (I.ExpNegate,     [x])    -> return (negate x)
  (I.ExpNot,        [x])    -> return (not x)
  (I.ExpAnd,        [x, y]) -> return (x `and` y)
  (I.ExpOr,         [x, y]) -> return (x `or` y)
  (I.ExpCond, [cond, true, false]) -> case cond of
    Bool True  -> return true
    Bool False -> return false
    _ -> raise $ "evalOp: ExpCond: expected true or false, got: " ++ show cond
  _ -> raise $ "evalOp: can't handle: " ++ show (op, args)
-- | Evaluate an initializer at the given type.
--
-- * Zero initializers expand to an empty array or struct initializer for
--   aggregate types, and to @mkVal ty 0@ otherwise.
-- * Array initializers are padded with zero initializers, or truncated, to
--   the declared length.
-- * Struct initializers start from an all-zero struct and then overwrite
--   the explicitly initialized fields.
evalInit :: I.Type -> I.Init -> Eval Value
evalInit ty init = case (init, ty) of
  (I.InitZero, I.TyArr _ _)  -> evalInit ty (I.InitArray [])
  (I.InitZero, I.TyStruct _) -> evalInit ty (I.InitStruct [])
  (I.InitZero, _)            -> return (mkVal ty 0)
  -- An expression initializer is evaluated at the type it carries.
  (I.InitExpr exprTy expr, _) -> evalExpr exprTy expr
  (I.InitArray elemInits, I.TyArr len elemTy) -> do
    let padded = take len (elemInits ++ repeat I.InitZero)
    elems <- mapM (evalInit elemTy) padded
    return (Array (Seq.fromList elems))
  (I.InitArray _, _) ->
    raise $ "evalInit: InitArray: unexpected type: " ++ show ty
  (I.InitStruct fieldInits, I.TyStruct name) -> do
    I.Struct _ fields <- lookupStruct name
    let zeroField acc (I.Typed fldTy fld) = do
          zero <- evalInit fldTy I.InitZero
          return (Map.insert fld zero acc)
        setField acc (fld, fldInit) = do
          val <- evalInit (lookupTyped fld fields) fldInit
          return (Map.insert fld val acc)
    zeroed <- foldM zeroField Map.empty fields
    Struct <$> foldM setField zeroed fieldInits
  (I.InitStruct _, _) ->
    raise $ "evalInit: InitStruct: unexpected type: " ++ show ty
-- | Return the type attached to the first entry whose payload equals the
-- given key; aborts with 'error' if there is no such entry.
lookupTyped :: (Show a, Eq a) => a -> [I.Typed a] -> I.Type
lookupTyped a entries =
  case [ t | I.Typed t x <- entries, x == a ] of
    (t : _) -> t
    []      -> error $ "lookupTyped: couldn't find: " ++ show a
-- | Extract the underlying symbol from a variable, whatever its kind.
varSym :: I.Var -> I.Sym
varSym var = case var of
  I.VarName sym     -> sym
  I.VarInternal sym -> sym
  I.VarLitName sym  -> sym
-- | Extract the underlying symbol from a name, looking through variables.
nameSym :: I.Name -> I.Sym
nameSym name = case name of
  I.NameSym sym -> sym
  I.NameVar var -> varSym var
-- | Run the action repeatedly until the condition yields 'True'.
--
-- The condition is checked before each iteration, so the action may run
-- zero times.
untilM :: Monad m => m Bool -> m () -> m ()
untilM done doThis = loop
  where
    loop = do
      finished <- done
      if finished
        then return ()
        else doThis >> loop