{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE FlexibleInstances #-} module YesodDsl.ParserState (ParserMonad, initParserState, getParserState, getPath, getParsed, setParserState, runParser, pushScope, popScope, declare, declareGlobal, SymType(..), ParserState, mkLoc, parseErrorCount, withSymbol, withGlobalSymbol, withSymbolNow, pError, hasReserved, getEntitySymbol, symbolMatches, requireClass, requireEnumName, requireEntity, requireEntityResult, requireEntityOrClass, requireEntityId, requireEntityField, requireEntityFieldSelectedOrResult, requireField, requireEnum, requireEnumValue, requireParam, requireFunction, setCurrentHandlerType, getCurrentHandlerType, requireHandlerType, validateExtractField, validateInsert, beginHandler, statement, lastStatement, postValidation) where import qualified Data.Map as Map import Control.Monad.State.Lazy import YesodDsl.AST import YesodDsl.Lexer import qualified Data.List as L import Data.Maybe import System.IO data SymType = SEnum EnumType | SEnumName EnumName | SClass Class | SEntity EntityName | SEntityResult EntityName | SEntityId EntityName | SField Field | SFieldType FieldType | SUnique Unique | SRoute Route | SHandler Handler | SParam | SForParam FieldRef | SReserved | SFunction instance Show SymType where show st = case st of SEnum _ -> "enum" SEnumName _ -> "enum name" SClass _ -> "class" SEntity _ -> "entity" SEntityResult _ -> "get-entity" SEntityId _ -> "entity id" SField _ -> "field" SFieldType _ -> "field type" SUnique _ -> "unique" SRoute _ -> "route" SHandler _ -> "handler" SParam -> "param" SForParam _ -> "for-param" SReserved -> "reserved" SFunction -> "function" data Sym = Sym Int Location SymType deriving (Show) type ParserMonad = StateT ParserState IO type Syms = Map.Map String [Sym] type EntityValidation = (EntityName, Entity -> ParserMonad ()) type PendingValidation = Syms -> ParserMonad () data ParserState = ParserState { psSyms :: Syms, psScopeId :: Int, psPath :: FilePath, psParsed :: [FilePath], psErrors :: Int, psHandlerType :: Maybe HandlerType, psPendingValidations :: [[PendingValidation]], psEntityValidations :: [EntityValidation], psLastStatement :: Maybe (Location, String), psChecks :: [(Location,String,FieldType)] } deriving (Show) instance Show (Syms -> ParserMonad ()) where show _ = "" instance Show (Entity -> ParserMonad ()) where show _ = "" initParserState :: ParserState initParserState = ParserState { psSyms = Map.fromList [ ("ClassInstance", [ Sym 0 (Loc "" 0 0) (SEntity "ClassInstance") ]) ], psScopeId = 0, psPath = "", psParsed = [], psErrors = 0, psHandlerType = Nothing, psPendingValidations = [], psEntityValidations = [], psLastStatement = Nothing, psChecks = [] } getParserState :: ParserMonad ParserState getParserState = get getPath :: ParserMonad FilePath getPath = gets psPath getParsed :: ParserMonad [FilePath] getParsed = gets psParsed setCurrentHandlerType :: HandlerType -> ParserMonad () setCurrentHandlerType ht = modify $ \ps -> ps { psHandlerType = Just ht } requireHandlerType :: Location -> String -> (HandlerType -> Bool) -> ParserMonad () requireHandlerType l n f = do mht <- gets psHandlerType when (fmap f mht /= Just True) $ pError l $ n ++ " not allowed " ++ (fromMaybe "outside handler" $ mht >>= \ht' -> return $ "in " ++ show ht' ++ " handler") getCurrentHandlerType :: ParserMonad (Maybe HandlerType) getCurrentHandlerType = gets psHandlerType setParserState :: ParserState -> ParserMonad () setParserState = put runEntityValidation :: Module -> EntityValidation -> ParserMonad () runEntityValidation m (en,f) = do case L.find (\e -> entityName e == en) $ modEntities m of Just e -> f e Nothing -> return () validateExtractField :: Location -> String -> ParserMonad () validateExtractField l s = if s `elem` validFields then return () else pError l $ "Unknown subfield '" ++ s ++ "' to extract" where validFields = [ "century", "day", "decade", "dow", "doy", "epoch", "hour", "isodow", "microseconds", "millennium", "milliseconds", "minute", "month", "quarter", "second", "timezone", "timezone_hour", "timezone_minute", "week", "year" ] validateInsert :: Location -> Entity -> Maybe (Maybe VariableName, [FieldRefMapping]) -> ParserMonad () validateInsert l e (Just (Nothing, ifs)) = do case [ fieldName f | f <- entityFields e, (not . fieldOptional) f, isNothing (fieldDefault f) ] L.\\ [ fn | (fn,_,_) <- ifs ] of fs@(_:_) -> pError l $ "Missing required fields without default value: " ++ (show fs) _ -> return () validateInsert _ _ _ = return () beginHandler :: ParserMonad () beginHandler = modify $ \ps -> ps { psLastStatement = Nothing } statement :: Location -> String -> ParserMonad () statement l s= do mlast <- gets psLastStatement fromMaybe (return ()) (mlast >>= \(l',s') -> do return $ pError l $ "'" ++ s ++ "' not allowed after '" ++ s' ++ "' in " ++ show l') lastStatement :: Location -> String -> ParserMonad () lastStatement l s = do statement l s modify $ \ps -> ps { psLastStatement = listToMaybe $ catMaybes [ psLastStatement ps, Just (l,s) ] } postValidation :: Module -> ParserState -> IO Int postValidation m ps = do ps' <- execStateT f ps return $ psErrors ps' where f = do forM_ (psEntityValidations ps) (runEntityValidation m) when (isNothing $ modName m) $ globalError "Missing top-level module name" runParser :: FilePath -> ParserState -> ParserMonad a -> IO (a,ParserState) runParser path ps m = do (r,ps') <- runStateT m $ ps { psPath = path, psParsed = path:psParsed ps } return (r, ps' { psPath = psPath ps} ) pushScope :: ParserMonad () pushScope = do modify $ \ps -> ps { psScopeId = psScopeId ps + 1, psPendingValidations = []:psPendingValidations ps } popScope :: ParserMonad () popScope = do (vs:_) <- gets psPendingValidations syms <- gets psSyms forM_ vs $ \v -> v syms modify $ \ps -> ps { psSyms = Map.filter (not . null) $ Map.map (filter (\(Sym s _ _) -> s < psScopeId ps)) $ psSyms ps, psScopeId = psScopeId ps - 1, psPendingValidations = tail $ psPendingValidations ps } pError :: Location -> String -> ParserMonad () pError l e = do lift $ hPutStrLn stderr $ show l ++ ": " ++ e modify $ \ps -> ps { psErrors = psErrors ps + 1 } globalError :: String -> ParserMonad () globalError e = do lift $ hPutStrLn stderr e modify $ \ps -> ps { psErrors = psErrors ps + 1 } declare :: Location -> String -> SymType -> ParserMonad () declare l n t = declare' False l n t declareGlobal :: Location -> String -> SymType -> ParserMonad () declareGlobal l n t = declare' True l n t declare' :: Bool -> Location -> String -> SymType -> ParserMonad () declare' global l n t = do ps <- get let scopeId = if global then 0 else psScopeId ps sym = [Sym scopeId l t ] case Map.lookup n (psSyms ps) of Just ((Sym s l' _):_) -> if s == scopeId then do pError l $ "'" ++ n ++ "' already declared in " ++ show l' return () else put $ ps { psSyms = Map.adjust (sym++) n (psSyms ps) } _ -> put $ ps { psSyms = Map.insert n sym $ psSyms ps } mkLoc :: Token -> ParserMonad Location mkLoc t = do path <- gets psPath return $ Loc path (tokenLineNum t) (tokenColNum t) parseErrorCount :: ParserState -> Int parseErrorCount = psErrors addEntityValidation :: EntityValidation -> ParserMonad () addEntityValidation ev = modify $ \ps -> ps { psEntityValidations = psEntityValidations ps ++ [ev] } addPendingValidation :: (Syms -> ParserMonad ()) -> ParserMonad () addPendingValidation v = modify $ \ps -> let vs = psPendingValidations ps in ps { psPendingValidations = (v:(head vs)):tail vs } addGlobalPendingValidation :: (Syms -> ParserMonad ()) -> ParserMonad () addGlobalPendingValidation v = modify $ \ps -> let vs = psPendingValidations ps in ps { psPendingValidations = init vs ++ [v : last vs] } hasReserved :: String -> ParserMonad Bool hasReserved n = do syms <- gets psSyms return $ fromMaybe False $  Map.lookup n syms >>= return . (not . null . (filter match)) where match (Sym _ _ SReserved) = True match _ = False withSymbol' :: Syms -> a -> Location -> String -> (Location -> Location -> SymType -> ParserMonad a) -> ParserMonad a withSymbol' syms d l n f = do case Map.lookup n syms >>= return . (filter notReserved) of Just ((Sym _ l' st):_) -> f l l' st _ -> pError l ("Reference to an undeclared identifier '" ++ n ++ "'") >> return d notReserved :: Sym -> Bool notReserved (Sym _ _ SReserved) = False notReserved _ = True withSymbolNow :: a -> Location -> String -> (Location -> Location -> SymType -> ParserMonad a) -> ParserMonad a withSymbolNow d l n f = do syms <- gets psSyms withSymbol' syms d l n f withSymbol :: Location -> String -> (Location -> Location -> SymType -> ParserMonad ()) -> ParserMonad () withSymbol l n f = addPendingValidation $ \syms -> void $ withSymbol' syms () l n f withGlobalSymbol :: Location -> String -> (Location -> Location -> SymType -> ParserMonad ()) -> ParserMonad () withGlobalSymbol l n f = addGlobalPendingValidation $ \syms -> void $ withSymbol' syms () l n f getEntitySymbol :: Location -> Location -> SymType -> ParserMonad (Maybe EntityName) getEntitySymbol _ _ (SEntity en) = return $ Just en getEntitySymbol _ _ _ = return Nothing symbolMatches :: String -> (SymType -> Bool) -> ParserMonad Bool symbolMatches n f = do syms <- gets psSyms case Map.lookup n syms >>= return . (filter notReserved) of Just ((Sym _ _ st):_) -> return $ f st _ -> return False requireClass :: (Class -> ParserMonad ()) -> (Location -> Location -> SymType -> ParserMonad ()) requireClass f = f' where f' _ _ (SClass c) = f c f' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected class)" requireEntity :: (Entity -> ParserMonad ()) -> (Location -> Location -> SymType -> ParserMonad ()) requireEntity f = f' where f' _ _ (SEntity en) = addEntityValidation (en, f) f' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected entity)" requireEntityResult :: (Entity -> ParserMonad ()) -> (Location -> Location -> SymType -> ParserMonad ()) requireEntityResult f = f' where f' _ _ (SEntityResult en) = addEntityValidation (en, f) f' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected get-entity)" requireEntityOrClass :: (Location -> Location -> SymType -> ParserMonad ()) requireEntityOrClass = f' where f' _ _ (SEntity _) = return () f' _ _ (SClass _) = return () f' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected entity or class)" requireEntityId :: (EntityName -> ParserMonad (Maybe a)) -> (Location -> Location -> SymType -> ParserMonad (Maybe a)) requireEntityId f = f' where f' _ _ (SEntityId en) = f en f' l1 l2 st = pError l1 ("Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected entity id)") >> return Nothing requireEnumName :: (EntityName -> ParserMonad (Maybe a)) -> (Location -> Location -> SymType -> ParserMonad (Maybe a)) requireEnumName f = f' where f' _ _ (SEnumName en) = f en f' l1 l2 st = pError l1 ("Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected enum name)") >> return Nothing requireEntityField :: Location -> FieldName -> ((Entity,Field) -> ParserMonad ()) -> (Location -> Location -> SymType -> ParserMonad ()) requireEntityField l fn fun = fun' where fun' _ _ (SEntity en) = addEntityValidation (en, \e -> case L.find (\f -> fieldName f == fn) $ entityFields e ++ entityClassFields e of Just f -> fun (e,f) Nothing -> pError l ("Reference to undeclared field '" ++ fn ++ "' of entity '" ++ entityName e ++ "'")) fun' l1 l2 st = pError l1 ( "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected entity)") requireEntityFieldSelectedOrResult :: Location -> FieldName -> (Location -> Location -> SymType -> ParserMonad ()) requireEntityFieldSelectedOrResult l fn = fun' where fun' _ _ (SEntity en) = validate en fun' _ _ (SEntityResult en) = validate en fun' l1 l2 st = pError l1 ( "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected entity or get-entity)") validate en = addEntityValidation (en, \e -> case L.find (\f -> fieldName f == fn) $ entityFields e ++ entityClassFields e of Just _ -> return () Nothing -> pError l ("Reference to undeclared field '" ++ fn ++ "' of entity '" ++ entityName e ++ "'")) requireEnum :: Location -> Location -> SymType -> ParserMonad () requireEnum = fun' where fun' _ _ (SEnum e) = return () fun' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected enum)" requireEnumValue :: Location -> EnumValue -> (Location -> Location -> SymType -> ParserMonad ()) requireEnumValue l ev = fun' where fun' _ _ (SEnum e) = case L.find (\ev' -> ev' == ev) $ enumValues e of Just _ -> return () Nothing -> pError l $ "Reference to undeclared enum value '" ++ ev ++ "' of enum '" ++ enumName e ++ "'" fun' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected enum)" requireParam :: (Location -> Location -> SymType -> ParserMonad ()) requireParam = fun' where fun' _ _ SParam = return () fun' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected param)" requireField:: (Field -> ParserMonad ()) -> (Location -> Location -> SymType -> ParserMonad ()) requireField f = f' where f' _ _ (SField sf) = f sf f' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected field)" requireFunction :: (Location -> Location -> SymType -> ParserMonad ()) requireFunction = fun' where fun' _ _ SFunction = return () fun' l1 l2 st = pError l1 $ "Reference to " ++ show st ++ " declared in " ++ show l2 ++ " (expected function)"