module Hydra.Rewriting where

import Hydra.Core
import Hydra.Monads
import Hydra.Module
import Hydra.Lexical
import Hydra.Compute
--import Hydra.CoreDecoding
import Hydra.Mantle
import Hydra.Sorting

import qualified Control.Monad as CM
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Maybe as Y


-- | Turn arbitrary terms like 'compareTo 42' into terms like '\x.compareTo 42 x',
--   whose arity (in the absence of application terms) is equal to the depth of nested lambdas.
--   This function leaves application terms intact, simply rewriting their left and right subterms.
--
--   Arguments collected from enclosing applications are re-applied by @pad@; any
--   arguments still missing (relative to the arity of the head function) are
--   supplied as fresh variables @v1@, @v2@, ... bound by new lambdas.
--   NOTE(review): the generated names are not checked against variables already
--   present in the term.
expandLambdas :: Ord m => Term m -> GraphFlow m (Term m)
expandLambdas = rewriteTermM (expand []) pure
  where
    -- 'args' accumulates the (already expanded) right-hand sides of the
    -- applications we have descended through, innermost first.
    expand args recurse term = case term of
        TermAnnotated (Annotated term' ann) ->
          TermAnnotated <$> (Annotated <$> expand args recurse term' <*> pure ann)
        TermApplication (Application lhs rhs) -> do
          rhs' <- expandLambdas rhs
          expand (rhs':args) recurse lhs
        TermFunction f -> case f of
          FunctionCompareTo _ -> pad args 1 <$> recurse term
          FunctionElimination _ -> pad args 1 <$> recurse term
          FunctionLambda _ -> passThrough
          FunctionPrimitive name -> do
            prim <- requirePrimitiveFunction name
            return $ pad args (primitiveFunctionArity prim) term
        _ -> passThrough
      where
        passThrough = pad args 0 <$> recurse term

    -- Apply 'term' to the collected arguments followed by as many fresh
    -- variables as are needed to reach 'arity', then bind those variables.
    pad args arity term = L.foldl lam (L.foldl app term args') $ L.reverse variables
      where
        variables = L.take (max 0 (arity - L.length args)) ((\i -> Variable $ "v" ++ show i) <$> [1..])
        args' = args ++ (TermVariable <$> variables)

    app lhs rhs = TermApplication $ Application lhs rhs
    lam body v = TermFunction $ FunctionLambda $ Lambda v body

-- | Fold a function over a term and all of its (transitive) subterms.
--   With 'TraversalOrderPre' a node is visited before its children;
--   with 'TraversalOrderPost' it is visited after them.
foldOverTerm :: TraversalOrder -> (a -> Term m -> a) -> a -> Term m -> a
foldOverTerm order fld b0 term = case order of
    TraversalOrderPre -> L.foldl (foldOverTerm order fld) (fld b0 term) children
    TraversalOrderPost -> fld (L.foldl (foldOverTerm order fld) b0 children) term
  where
    children = subterms term

-- | Fold a function over a type and all of its (transitive) subtypes.
--   With 'TraversalOrderPre' a node is visited before its children;
--   with 'TraversalOrderPost' it is visited after them.
foldOverType :: TraversalOrder -> (a -> Type m -> a) -> a -> Type m -> a
foldOverType order fld b0 typ = case order of
    TraversalOrderPre -> L.foldl (foldOverType order fld) (fld b0 typ) children
    TraversalOrderPost -> fld (L.foldl (foldOverType order fld) b0 children) typ
  where
    children = subtypes typ

-- | The free type variables of a type scheme: the free variables of its body
--   minus the variables the scheme itself quantifies over.
freeVariablesInScheme :: Show m => TypeScheme m -> S.Set VariableType
freeVariablesInScheme (TypeScheme vars t) = S.difference (freeVariablesInType t) (S.fromList vars)

-- | The set of variables which occur in a term without being bound by an
--   enclosing lambda. Annotations are looked through.
--   NOTE(review): 'TermLet' is treated like any other compound term, i.e. its
--   variable is not considered bound -- confirm whether that is intended.
freeVariablesInTerm :: Term m -> S.Set Variable
freeVariablesInTerm term = case term of
  TermAnnotated (Annotated term1 _) -> freeVariablesInTerm term1
  TermFunction (FunctionLambda (Lambda var body)) -> S.delete var $ freeVariablesInTerm body
  TermVariable v -> S.singleton v
  _ -> S.unions (freeVariablesInTerm <$> subterms term)

-- | The set of type variables which occur in a type without being bound by an
--   enclosing 'TypeLambda'.
freeVariablesInType :: Type m -> S.Set VariableType
freeVariablesInType typ = case typ of
  -- A type lambda binds its variable within its body.
  TypeLambda (LambdaType v body) -> S.delete v $ freeVariablesInType body
  TypeVariable v -> S.singleton v
  _ -> S.unions (freeVariablesInType <$> subtypes typ)

-- | The namespaces a module depends on, excluding the module's own namespace.
--   Dependencies are gathered from both the data term and the schema term of
--   every element of the module; the three flags select whether element
--   references, primitive function references and nominal type references
--   are counted (see 'termDependencyNames').
moduleDependencyNamespaces :: Bool -> Bool -> Bool -> Module m -> S.Set Namespace
moduleDependencyNamespaces withEls withPrims withNoms modl = S.delete (moduleNamespace modl) names
  where
    names = S.map namespaceOfEager elNames
    elNames = S.unions (termDependencyNames withEls withPrims withNoms <$> terms)
    terms = (elementData <$> moduleElements modl) ++ (elementSchema <$> moduleElements modl)

-- | Returns 'True' when the given variable is NOT a member of
--   'freeVariablesInTerm' of the term, and 'False' when it is.
--   NOTE(review): this is the opposite of what the name suggests; behavior is
--   left unchanged because callers may rely on it -- confirm against call sites.
isFreeIn :: Variable -> Term m -> Bool
isFreeIn v term = not $ S.member v $ freeVariablesInTerm term

-- | Recursively remove term annotations, including within subterms
removeTermAnnotations :: Ord m => Term m -> Term m
removeTermAnnotations = rewriteTerm remove id
  where
    remove recurse term = case term of
      TermAnnotated (Annotated term' _) -> remove recurse term'
      _ -> recurse term

-- | Recursively remove type annotations, including within subtypes
removeTypeAnnotations :: Ord m => Type m -> Type m
removeTypeAnnotations = rewriteType remove id
  where
    -- Inspect the node itself (as 'removeTermAnnotations' does) rather than
    -- its already-rewritten form, so that no subtree is processed twice.
    remove recurse typ = case typ of
      TypeAnnotated (Annotated typ' _) -> remove recurse typ'
      _ -> recurse typ

-- | Replace every free occurrence of a type variable with the given type.
--   Occurrences under a 'TypeLambda' which rebinds the same variable are
--   left untouched.
--   NOTE(review): no renaming is performed, so free variables of the
--   replacement type may be captured by lambdas in the target type.
replaceFreeVariableType :: Ord m => VariableType -> Type m -> Type m -> Type m
replaceFreeVariableType v rep = rewriteType mapExpr id
  where
    mapExpr recurse t = case t of
      TypeLambda (LambdaType v' _) -> if v == v'
        then t
        -- 'recurse' rewrites the children of its argument, so it must be given
        -- the lambda itself in order for the body's top node to be examined.
        else recurse t
      TypeVariable v' -> if v == v' then rep else t
      _ -> recurse t

-- | Generic recursion scheme behind the rewriting functions in this module.
--   @fsub@ lifts a rewrite of a whole node into a rewrite of that node's
--   immediate children; @f@ is the user-supplied step, which receives the
--   "rewrite my children" function as its first argument.
rewrite :: ((a -> b) -> a -> b) -> ((a -> b) -> a -> b) -> a -> b
rewrite fsub f = recurse
  where
    recurse = f (fsub recurse)

-- | Rewrite a term. The first argument is applied at each node and is handed a
--   function which rewrites the node's immediate subterms; the second argument
--   maps annotations.
rewriteTerm :: Ord b => ((Term a -> Term b) -> Term a -> Term b) -> (a -> b) -> Term a -> Term b
rewriteTerm f mf = rewrite fsub f
  where
    fsub recurse term = case term of
        TermAnnotated (Annotated ex ann) -> TermAnnotated $ Annotated (recurse ex) (mf ann)
        TermApplication (Application lhs rhs) -> TermApplication $ Application (recurse lhs) (recurse rhs)
        TermElement name -> TermElement name
        TermFunction fun -> TermFunction $ case fun of
          FunctionCompareTo other -> FunctionCompareTo $ recurse other
          FunctionElimination e -> FunctionElimination $ case e of
            EliminationElement -> EliminationElement
            EliminationNominal name -> EliminationNominal name
            EliminationOptional (OptionalCases nothing just) -> EliminationOptional
              (OptionalCases (recurse nothing) (recurse just))
            EliminationRecord p -> EliminationRecord p
            EliminationUnion (CaseStatement n cases) -> EliminationUnion $ CaseStatement n (forField <$> cases)
          FunctionLambda (Lambda v body) -> FunctionLambda $ Lambda v $ recurse body
          FunctionPrimitive name -> FunctionPrimitive name
        TermLet (Let v t1 t2) -> TermLet $ Let v (recurse t1) (recurse t2)
        TermList els -> TermList $ recurse <$> els
        TermLiteral v -> TermLiteral v
        TermMap m -> TermMap $ M.fromList $ (\(k, v) -> (recurse k, recurse v)) <$> M.toList m
        TermNominal (Named name t) -> TermNominal (Named name $ recurse t)
        TermOptional m -> TermOptional $ recurse <$> m
        TermProduct tuple -> TermProduct (recurse <$> tuple)
        TermRecord (Record n fields) -> TermRecord $ Record n $ forField <$> fields
        TermSet s -> TermSet $ S.fromList $ recurse <$> S.toList s
        TermSum (Sum i s trm) -> TermSum $ Sum i s $ recurse trm
        TermUnion (Union n field) -> TermUnion $ Union n $ forField field
        TermVariable v -> TermVariable v
      where
        forField f = f {fieldTerm = recurse (fieldTerm f)}

-- | Monadic variant of 'rewriteTerm': both the per-node step and the
--   annotation mapping run in 'Flow'.
rewriteTermM :: Ord b =>
  ((Term a -> Flow s (Term b)) -> Term a -> (Flow s (Term b))) ->
  (a -> Flow s b) ->
  Term a ->
  Flow s (Term b)
rewriteTermM f mf = rewrite fsub f
  where
    fsub recurse term = case term of
        TermAnnotated (Annotated ex ma) -> TermAnnotated <$> (Annotated <$> recurse ex <*> mf ma)
        TermApplication (Application lhs rhs) -> TermApplication <$> (Application <$> recurse lhs <*> recurse rhs)
        TermElement name -> pure $ TermElement name
        TermFunction fun -> TermFunction <$> case fun of
          FunctionCompareTo other -> FunctionCompareTo <$> recurse other
          FunctionElimination e -> FunctionElimination <$> case e of
            EliminationElement -> pure EliminationElement
            EliminationNominal name -> pure $ EliminationNominal name
            EliminationOptional (OptionalCases nothing just) -> EliminationOptional <$>
              (OptionalCases <$> recurse nothing <*> recurse just)
            EliminationRecord p -> pure $ EliminationRecord p
            EliminationUnion (CaseStatement n cases) -> EliminationUnion <$> (CaseStatement n <$> (CM.mapM forField cases))
          FunctionLambda (Lambda v body) -> FunctionLambda <$> (Lambda v <$> recurse body)
          FunctionPrimitive name -> pure $ FunctionPrimitive name
        TermLet (Let v t1 t2) -> TermLet <$> (Let v <$> recurse t1 <*> recurse t2)
        TermList els -> TermList <$> (CM.mapM recurse els)
        TermLiteral v -> pure $ TermLiteral v
        TermMap m -> TermMap <$> (M.fromList <$> CM.mapM forPair (M.toList m))
        TermNominal (Named name t) -> TermNominal <$> (Named name <$> recurse t)
        TermOptional m -> TermOptional <$> (CM.mapM recurse m)
        TermProduct tuple -> TermProduct <$> (CM.mapM recurse tuple)
        TermRecord (Record n fields) -> TermRecord <$> (Record n <$> (CM.mapM forField fields))
        TermSet s -> TermSet <$> (S.fromList <$> (CM.mapM recurse $ S.toList s))
        TermSum (Sum i s trm) -> TermSum <$> (Sum i s <$> recurse trm)
        TermUnion (Union n field) -> TermUnion <$> (Union n <$> forField field)
        TermVariable v -> pure $ TermVariable v
      where
        -- Rewrite the key first, then the value.
        forPair (k, v) = (,) <$> recurse k <*> recurse v
        forField f = do
          t <- recurse (fieldTerm f)
          return f {fieldTerm = t}

-- | Map a function over every annotation in a term, leaving the term
--   structure otherwise unchanged.
rewriteTermMeta :: Ord b => (a -> b) -> Term a -> Term b
rewriteTermMeta = rewriteTerm mapExpr
  where
    mapExpr recurse term = recurse term

-- | Rewrite a type. The first argument is applied at each node and is handed a
--   function which rewrites the node's immediate subtypes; the second argument
--   maps annotations.
rewriteType :: ((Type a -> Type b) -> Type a -> Type b) -> (a -> b) -> Type a -> Type b
rewriteType f mf = rewrite fsub f
  where
    fsub recurse typ = case typ of
        TypeAnnotated (Annotated t ann) -> TypeAnnotated $ Annotated (recurse t) (mf ann)
        TypeApplication (ApplicationType lhs rhs) -> TypeApplication $ ApplicationType (recurse lhs) (recurse rhs)
        TypeElement t -> TypeElement $ recurse t
        TypeFunction (FunctionType dom cod) -> TypeFunction (FunctionType (recurse dom) (recurse cod))
        TypeLambda (LambdaType v b) -> TypeLambda (LambdaType v $ recurse b)
        TypeList t -> TypeList $ recurse t
        TypeLiteral lt -> TypeLiteral lt
        TypeMap (MapType kt vt) -> TypeMap (MapType (recurse kt) (recurse vt))
        TypeNominal name -> TypeNominal name
        TypeOptional t -> TypeOptional $ recurse t
        TypeProduct types -> TypeProduct (recurse <$> types)
        TypeRecord (RowType name extends fields) -> TypeRecord $ RowType name extends (forfield <$> fields)
        TypeSet t -> TypeSet $ recurse t
        TypeSum types -> TypeSum (recurse <$> types)
        TypeUnion (RowType name extends fields) -> TypeUnion $ RowType name extends (forfield <$> fields)
        TypeVariable v -> TypeVariable v
      where
        forfield f = f {fieldTypeType = recurse (fieldTypeType f)}

-- | Map a function over every annotation in a type, leaving the type
--   structure otherwise unchanged.
rewriteTypeMeta :: (a -> b) -> Type a -> Type b
rewriteTypeMeta = rewriteType mapExpr
  where
    mapExpr recurse term = recurse term

-- | Simplify applications of lambdas throughout a term:
--
--   * if the lambda's variable does not occur free in its body, the
--     application is replaced by the (simplified) body;
--   * if it does occur and the argument is a variable, that variable is
--     substituted for the lambda's variable in the body;
--   * otherwise the application is left as it is.
--
--   The shape of each node is inspected after 'stripTerm' has been applied.
simplifyTerm :: Ord m => Term m -> Term m
simplifyTerm = rewriteTerm simplify id
  where
    simplify recurse term = recurse $ case stripTerm term of
      TermApplication (Application lhs rhs) -> case stripTerm lhs of
        TermFunction (FunctionLambda (Lambda var body)) ->
          if S.member var (freeVariablesInTerm body)
            then case stripTerm rhs of
              TermVariable v -> simplifyTerm $ substituteVariable var v body
              _ -> term
            else simplifyTerm body
        _ -> term
      _ -> term

-- | Rename occurrences of one variable to another throughout a term, without
--   descending into lambdas which rebind the variable being replaced.
substituteVariable :: Ord m => Variable -> Variable -> Term m -> Term m
substituteVariable from to = rewriteTerm replace id
  where
    replace recurse term = case term of
      TermVariable x -> recurse $ (TermVariable $ if x == from then to else x)
      TermFunction (FunctionLambda (Lambda var _)) -> if var == from
        then term
        else recurse term
      _ -> recurse term

-- | The immediate subterms (children) of a term, in left-to-right order.
subterms :: Term m -> [Term m]
subterms term = case term of
  TermAnnotated (Annotated t _) -> [t]
  TermApplication (Application lhs rhs) -> [lhs, rhs]
  TermFunction f -> case f of
    FunctionCompareTo other -> [other]
    FunctionElimination e -> case e of
      EliminationOptional (OptionalCases nothing just) -> [nothing, just]
      EliminationUnion (CaseStatement _ cases) -> fieldTerm <$> cases
      _ -> []
    FunctionLambda (Lambda _ body) -> [body]
    _ -> []
  TermLet (Let _ t1 t2) -> [t1, t2]
  TermList els -> els
  TermMap m -> L.concat ((\(k, v) -> [k, v]) <$> M.toList m)
  TermNominal (Named _ t) -> [t]
  TermOptional m -> Y.maybeToList m
  TermProduct tuple -> tuple
  TermRecord (Record _ fields) -> fieldTerm <$> fields
  TermSet s -> S.toList s
  TermSum (Sum _ _ trm) -> [trm]
  TermUnion (Union _ field) -> [fieldTerm field]
  _ -> []

-- | The immediate subtypes (children) of a type, in left-to-right order.
subtypes :: Type m -> [Type m]
subtypes typ = case typ of
  TypeAnnotated (Annotated t _) -> [t]
  TypeApplication (ApplicationType lhs rhs) -> [lhs, rhs]
  TypeElement et -> [et]
  TypeFunction (FunctionType dom cod) -> [dom, cod]
  TypeLambda (LambdaType _ body) -> [body]
  TypeList lt -> [lt]
  TypeLiteral _ -> []
  TypeMap (MapType kt vt) -> [kt, vt]
  TypeNominal _ -> []
  TypeOptional ot -> [ot]
  TypeProduct types -> types
  TypeRecord rt -> fieldTypeType <$> rowTypeFields rt
  TypeSet st -> [st]
  TypeSum types -> types
  TypeUnion rt -> fieldTypeType <$> rowTypeFields rt
  TypeVariable _ -> []

-- | The names a term refers to. The three flags select, respectively, whether
--   element references ('TermElement'), primitive function references
--   ('FunctionPrimitive') and nominal type names ('TermNominal') are included.
termDependencyNames :: Bool -> Bool -> Bool -> Term m -> S.Set Name
termDependencyNames withEls withPrims withNoms = foldOverTerm TraversalOrderPre addNames S.empty
  where
    addNames names term = case term of
      TermElement name -> if withEls then S.insert name names else names
      TermFunction (FunctionPrimitive name) -> if withPrims then S.insert name names else names
      TermNominal (Named name _) -> if withNoms then S.insert name names else names
      _ -> names

-- | Topologically sort elements by the names their data terms depend on
--   (elements, primitives and nominal types all count as dependencies).
--   The adjacency list is handed to 'topologicalSort'; presumably 'Nothing'
--   signals that no ordering exists (e.g. a dependency cycle) -- confirm
--   against "Hydra.Sorting".
topologicalSortElements :: [Element m] -> Maybe [Name]
topologicalSortElements els = topologicalSort $ adjlist <$> els
  where
    adjlist e = (elementName e, S.toList $ termDependencyNames True True True $ elementData e)