{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Checking of template nodes against the type of the root value.
--
-- 'check' turns a list of nodes into a rendering function, or fails with a
-- 'CompileError' when the nodes do not fit the root type.
module Zinza.Check (check) where

import Control.Monad             ((>=>))
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT (..), evalStateT, get, put)
import Data.Functor.Identity     (Identity (..))
import Data.Proxy                (Proxy (..))
import Data.Traversable          (for)

import qualified Data.Map.Strict as M

import Zinza.Class
import Zinza.Errors
import Zinza.Expr
import Zinza.Indexing
import Zinza.Node
import Zinza.Pos
import Zinza.Type
import Zinza.Value
import Zinza.Var

-------------------------------------------------------------------------------
-- Type
-------------------------------------------------------------------------------

-- | The checking monad: 'StateT' over @'Either' 'CompileError'@.
--
-- The state maps the name of every block defined so far to its
-- already-checked renderer.
type Check v m = StateT (M.Map Var (v Value -> m ShowS)) (Either CompileError)

-------------------------------------------------------------------------------
-- Nodes
-------------------------------------------------------------------------------

-- | Check nodes against the type of @a@ and build a renderer for them.
--
-- The type of @a@ (obtained through 'toType') has to be a 'TyRecord';
-- anything else is reported as 'NotRecord' at 'zeroLoc'. Every variable in
-- the nodes must be a field of that record, otherwise 'UnboundTopLevelVar'
-- is returned. Each variable is rewritten into a field access on the root
-- value before the nodes are checked with an empty block map.
check
    :: forall a m. (Zinza a, ThrowRuntime m)
    => Nodes Var
    -> Either CompileError (a -> m String)
check nodes = case toType (Proxy :: Proxy a) of
    rootTy@(TyRecord env) -> do
        -- Turn a top-level variable into @root.var@, rejecting unknown names.
        let resolve loc var
                | M.member var env =
                    Right (EField (L loc (EVar (L loc (Identity rootTy)))) (L loc var))
                | otherwise =
                    Left (UnboundTopLevelVar loc var)
        rooted <- (traverse . traverseWithLoc) resolve nodes
        render <- evalStateT (checkNodes (map (>>== id) rooted)) M.empty
        -- The renderer yields a 'ShowS'; apply it to "" to get the 'String'.
        return (fmap ($ "") . render . Identity . toValue)
    other -> throwRuntime (NotRecord zeroLoc other)

-- | Check a sequence of nodes, in order, and concatenate their output.
checkNodes
    :: (Indexing v i, ThrowRuntime m)
    => Nodes (i Ty)  -- ^ nodes with root object
    -> Check v m (v Value -> m ShowS)
checkNodes nodes = do
    renderers <- traverse checkNode nodes
    return $ \ctx -> foldr (.) id <$> traverse ($ ctx) renderers

-- | Check a single node and produce its renderer.
--
-- * Comments render nothing; raw text renders itself.
-- * Both branches of a conditional are checked with the block map restored
--   afterwards, so blocks defined in a branch are not visible after it.
-- * A loop body is checked with its own copy of the block map (adapted to
--   skip the loop element); the outer block map is left untouched.
-- * Defining an already-defined block is 'ShadowingBlock'; using an
--   undefined one is 'UnboundUseBlock'.
checkNode
    :: (Indexing v i, ThrowRuntime m)
    => Node (i Ty)
    -> Check v m (v Value -> m ShowS)
checkNode NComment = return (const (return id))
checkNode (NRaw s) = return (const (return (showString s)))
checkNode (NIf expr xs ys) = do
    cond    <- checkBool expr
    onTrue  <- resettingState (checkNodes xs)
    onFalse <- resettingState (checkNodes ys)
    return $ \ctx -> do
        holds <- cond ctx
        (if holds then onTrue else onFalse) ctx
checkNode (NExpr e) = do
    render <- checkString e
    return (fmap showString . render)
checkNode (NFor _v expr nodes) = do
    (items, elemTy) <- checkList expr
    blocks <- get
    -- In the body, 'Nothing' stands for the loop element and 'Just' for a
    -- variable of the enclosing context.
    renderBody <- lift $ evalStateT
        (checkNodes (fmap (fmap (maybe (Here elemTy) There)) nodes))
        (M.map (\f (_ ::: xs) -> f xs) blocks)
    return $ \ctx -> do
        xs <- items ctx
        pieces <- for xs $ \x -> renderBody (x ::: ctx)
        return (foldr (.) id pieces)
checkNode (NDefBlock l n nodes) = do
    blocks <- get
    case M.lookup n blocks of
        Just _  -> lift (Left (ShadowingBlock l n))
        Nothing -> do
            body <- checkNodes nodes
            -- The new state is built from the map as it was before the body
            -- was checked, plus this block.
            put (M.insert n body blocks)
            return (const (return id))
checkNode (NUseBlock l n) = do
    blocks <- get
    maybe (lift (Left (UnboundUseBlock l n))) return (M.lookup n blocks)

-- | Run an action, then put back the state that was current before it ran.
resettingState :: Monad m => StateT s m a -> StateT s m a
resettingState action = do
    saved <- get
    action <* put saved

-------------------------------------------------------------------------------
-- Expressions
-------------------------------------------------------------------------------

-- | Check that an expression is a list.
--
-- Returns an evaluator for the elements together with the element type.
-- Fails with 'NotList' if the expression's type is not 'TyList'; the
-- evaluator throws 'NotList' if the value turns out not to be a 'VList'.
checkList
    :: (Indexing v i, ThrowRuntime m)
    => LExpr (i Ty)
    -> Check v m (v Value -> m [Value], Ty)
checkList e@(L l _) = do
    (eval, ty) <- checkType e
    case ty of
        TyList _ elemTy -> return (eval >=> asList, elemTy)
        _               -> throwRuntime (NotList l ty)
  where
    asList (VList xs) = return xs
    asList other      = throwRuntime (NotList l (valueType other))

-- | Check that an expression is a boolean.
--
-- Fails with 'NotBool' if the expression's type is not 'TyBool'; the
-- evaluator throws 'NotBool' if the value turns out not to be a 'VBool'.
checkBool
    :: (Indexing v i, ThrowRuntime m)
    => LExpr (i Ty)
    -> Check v m (v Value -> m Bool)
checkBool e@(L l _) = do
    (eval, ty) <- checkType e
    case ty of
        TyBool -> return (eval >=> asBool)
        _      -> throwRuntime (NotBool l ty)
  where
    asBool (VBool b) = return b
    asBool other     = throwRuntime (NotBool l (valueType other))

-- | Check that an expression is a string.
--
-- Fails with 'NotString' if the expression's type is not 'TyString'; the
-- evaluator throws 'NotString' if the value turns out not to be a 'VString'.
checkString
    :: (Indexing v i, ThrowRuntime m)
    => LExpr (i Ty)
    -> Check v m (v Value -> m String)
checkString e@(L l _) = do
    (eval, ty) <- checkType e
    case ty of
        TyString _ -> return (eval >=> asString)
        _          -> throwRuntime (NotString l ty)
  where
    asString (VString s) = return s
    asString other       = throwRuntime (NotString l (valueType other))

-- | Infer the type of an expression and build an evaluator for it.
--
-- * A variable is looked up in the context with 'index'; its type comes
--   from 'extract'.
-- * A field access requires a 'TyRecord' containing the field
--   ('NotRecord' / 'FieldNotInRecord' otherwise).
-- * An application requires a 'TyFun' whose argument type equals the type
--   of the argument ('NotFunction' / 'FunArgDontMatch' otherwise).
--
-- The evaluators repeat the record and function checks on the actual values.
checkType
    :: (Indexing v i, ThrowRuntime m)
    => LExpr (i Ty)
    -> Check v m (v Value -> m Value, Ty)
checkType (L _ (EVar (L _ i))) =
    return (\ctx -> return (fst (index ctx i)), extract i)
checkType (L eLoc (EField e (L nameLoc name))) = do
    (eval, ty) <- checkType e
    case ty of
        TyRecord fields -> case M.lookup name fields of
            Just (_sel, fieldTy) -> return (eval >=> select, fieldTy)
            Nothing              -> throwRuntime (FieldNotInRecord nameLoc name ty)
        _ -> throwRuntime (NotRecord eLoc ty)
  where
    select rec@(VRecord r) = case M.lookup name r of
        Just y  -> return y
        Nothing -> throwRuntime (FieldNotInRecord nameLoc name (valueType rec))
    select other = throwRuntime (NotRecord eLoc (valueType other))
checkType (L eLoc (EApp f@(L fLoc _) x)) = do
    (evalF, fTy) <- checkType f
    (evalX, xTy) <- checkType x
    case fTy of
        TyFun argTy resTy
            | xTy == argTy -> return (apply evalF evalX, resTy)
            | otherwise    -> throwRuntime (FunArgDontMatch fLoc xTy argTy)
        _ -> throwRuntime (NotFunction eLoc fTy)
  where
    -- Evaluate the function first, then the argument, then apply.
    apply evalF evalX ctx = do
        fv <- evalF ctx
        xv <- evalX ctx
        case fv of
            VFun fn -> either throwRuntime return (fn xv)
            _       -> throwRuntime (NotFunction eLoc (valueType fv))