-- |
-- AST traversal helpers
--
module Language.PureScript.AST.Traversals where

import Prelude
import Protolude (swap)

import Control.Monad ((<=<), (>=>))
import Control.Monad.Trans.State (StateT(..))

import Data.Foldable (fold)
import Data.Functor.Identity (runIdentity)
import Data.List (mapAccumL)
import Data.Maybe (mapMaybe)
import Data.List.NonEmpty qualified as NEL
import Data.Map qualified as M
import Data.Set qualified as S

import Language.PureScript.AST.Binders (Binder(..), binderNames)
import Language.PureScript.AST.Declarations (CaseAlternative(..), DataConstructorDeclaration(..), Declaration(..), DoNotationElement(..), Expr(..), Guard(..), GuardedExpr(..), TypeDeclarationData(..), TypeInstanceBody(..), pattern ValueDecl, ValueDeclarationData(..), mapTypeInstanceBody, traverseTypeInstanceBody)
import Language.PureScript.AST.Literals (Literal(..))
import Language.PureScript.Names (pattern ByNullSourcePos, Ident)
import Language.PureScript.Traversals (sndM, sndM', thirdM)
import Language.PureScript.TypeClassDictionaries (TypeClassDictionaryInScope(..))
import Language.PureScript.Types (Constraint(..), SourceType, mapConstraintArgs)

-- | Effectfully rebuild a 'GuardedExpr': the first function is applied to
-- every guard (in order) and the second to the right-hand side.
guardedExprM :: Applicative m
             => (Guard -> m Guard)
             -> (Expr -> m Expr)
             -> GuardedExpr
             -> m GuardedExpr
guardedExprM f g (GuardedExpr guards rhs) =
  GuardedExpr <$> traverse f guards <*> g rhs

-- | Pure counterpart of 'guardedExprM': map the first function over every
-- guard and the second over the right-hand side.
mapGuardedExpr :: (Guard -> Guard)
               -> (Expr -> Expr)
               -> GuardedExpr
               -> GuardedExpr
mapGuardedExpr f g (GuardedExpr guards rhs) =
  GuardedExpr (fmap f guards) (g rhs)

-- | Effectfully traverse the elements of an array literal or the field values
-- of an object literal. Every other kind of literal is returned unchanged.
litM :: Monad m => (a -> m a) -> Literal a -> m (Literal a)
litM go (ObjectLiteral as) = ObjectLiteral <$> traverse (sndM go) as
litM go (ArrayLiteral as) = ArrayLiteral <$> traverse go as
litM _ other = pure other

-- | Build bottom-up rewriters for declarations, expressions and binders.
--
-- For every node the children are rewritten first and the matching user
-- function (@f@ for declarations, @g@ for expressions, @h@ for binders) is
-- then applied to the rebuilt node. Node kinds that are not matched below are
-- handed to the user function directly.
everywhereOnValues
  :: (Declaration -> Declaration)
  -> (Expr -> Expr)
  -> (Binder -> Binder)
  -> ( Declaration -> Declaration
     , Expr -> Expr
     , Binder -> Binder
     )
everywhereOnValues f g h = (f', g', h')
  where
  f' :: Declaration -> Declaration
  f' (DataBindingGroupDeclaration ds) = f (DataBindingGroupDeclaration (fmap f' ds))
  f' (ValueDecl sa name nameKind bs val) =
    f (ValueDecl sa name nameKind (fmap h' bs) (fmap (mapGuardedExpr handleGuard g') val))
  f' (BoundValueDeclaration sa b expr) = f (BoundValueDeclaration sa (h' b) (g' expr))
  f' (BindingGroupDeclaration ds) =
    f (BindingGroupDeclaration (fmap (\(name, nameKind, val) -> (name, nameKind, g' val)) ds))
  f' (TypeClassDeclaration sa name args implies deps ds) =
    f (TypeClassDeclaration sa name args implies deps (fmap f' ds))
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) =
    f (TypeInstanceDeclaration sa na ch idx name cs className args (mapTypeInstanceBody (fmap f') ds))
  f' other = f other

  g' :: Expr -> Expr
  g' (Literal ss l) = g (Literal ss (lit g' l))
  g' (UnaryMinus ss v) = g (UnaryMinus ss (g' v))
  g' (BinaryNoParens op v1 v2) = g (BinaryNoParens (g' op) (g' v1) (g' v2))
  g' (Parens v) = g (Parens (g' v))
  g' (Accessor prop v) = g (Accessor prop (g' v))
  g' (ObjectUpdate obj vs) = g (ObjectUpdate (g' obj) (fmap (fmap g') vs))
  g' (ObjectUpdateNested obj vs) = g (ObjectUpdateNested (g' obj) (fmap g' vs))
  g' (Abs binder v) = g (Abs (h' binder) (g' v))
  g' (App v1 v2) = g (App (g' v1) (g' v2))
  g' (VisibleTypeApp v ty) = g (VisibleTypeApp (g' v) ty)
  g' (Unused v) = g (Unused (g' v))
  g' (IfThenElse v1 v2 v3) = g (IfThenElse (g' v1) (g' v2) (g' v3))
  g' (Case vs alts) = g (Case (fmap g' vs) (fmap handleCaseAlternative alts))
  g' (TypedValue check v ty) = g (TypedValue check (g' v) ty)
  g' (Let w ds v) = g (Let w (fmap f' ds) (g' v))
  g' (Do m es) = g (Do m (fmap handleDoNotationElement es))
  g' (Ado m es v) = g (Ado m (fmap handleDoNotationElement es) (g' v))
  g' (PositionedValue pos com v) = g (PositionedValue pos com (g' v))
  g' other = g other

  h' :: Binder -> Binder
  h' (ConstructorBinder ss ctor bs) = h (ConstructorBinder ss ctor (fmap h' bs))
  h' (BinaryNoParensBinder b1 b2 b3) = h (BinaryNoParensBinder (h' b1) (h' b2) (h' b3))
  h' (ParensInBinder b) = h (ParensInBinder (h' b))
  h' (LiteralBinder ss l) = h (LiteralBinder ss (lit h' l))
  h' (NamedBinder ss name b) = h (NamedBinder ss name (h' b))
  h' (PositionedBinder pos com b) = h (PositionedBinder pos com (h' b))
  h' (TypedBinder t b) = h (TypedBinder t (h' b))
  h' other = h other

  -- Pure analogue of 'litM'.
  lit :: (a -> a) -> Literal a -> Literal a
  lit go (ArrayLiteral as) = ArrayLiteral (fmap go as)
  lit go (ObjectLiteral as) = ObjectLiteral (fmap (fmap go) as)
  lit _ other = other

  handleCaseAlternative :: CaseAlternative -> CaseAlternative
  handleCaseAlternative ca =
    ca { caseAlternativeBinders = fmap h' (caseAlternativeBinders ca)
       , caseAlternativeResult = fmap (mapGuardedExpr handleGuard g') (caseAlternativeResult ca)
       }

  handleDoNotationElement :: DoNotationElement -> DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue (g' v)
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind (h' b) (g' v)
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet (fmap f' ds)
  handleDoNotationElement (PositionedDoNotationElement pos com e) =
    PositionedDoNotationElement pos com (handleDoNotationElement e)

  handleGuard :: Guard -> Guard
  handleGuard (ConditionGuard e) = ConditionGuard (g' e)
  handleGuard (PatternGuard b e) = PatternGuard (h' b) (g' e)

-- | Build top-down monadic rewriters for declarations, expressions and
-- binders.
--
-- The user function for a node runs first (the returned functions are
-- @f' <=< f@, @g' <=< g@ and @h' <=< h@). The traversal then descends into the
-- children of the node that the user function returned, applying the user
-- function to each child before descending into it.
everywhereOnValuesTopDownM
  :: forall m
   . (Monad m)
  => (Declaration -> m Declaration)
  -> (Expr -> m Expr)
  -> (Binder -> m Binder)
  -> ( Declaration -> m Declaration
     , Expr -> m Expr
     , Binder -> m Binder
     )
everywhereOnValuesTopDownM f g h = (f' <=< f, g' <=< g, h' <=< h)
  where
  -- NOTE(review): in each of f', g' and h' below, the final @other@ clause
  -- applies the user function again, although the caller (@f' <=< f@,
  -- @g v >>= g'@, ...) has already applied it. Node kinds not matched here
  -- therefore see the user function twice. Kept as-is because existing
  -- callers may rely on it; confirm whether @return other@ was intended.
  f' :: Declaration -> m Declaration
  f' (DataBindingGroupDeclaration ds) = DataBindingGroupDeclaration <$> traverse (f' <=< f) ds
  f' (ValueDecl sa name nameKind bs val) =
    ValueDecl sa name nameKind
      <$> traverse (h' <=< h) bs
      <*> traverse (guardedExprM handleGuard (g' <=< g)) val
  f' (BindingGroupDeclaration ds) =
    BindingGroupDeclaration <$> traverse (\(name, nameKind, val) -> (name, nameKind, ) <$> (g val >>= g')) ds
  f' (TypeClassDeclaration sa name args implies deps ds) =
    TypeClassDeclaration sa name args implies deps <$> traverse (f' <=< f) ds
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) =
    TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse (f' <=< f)) ds
  f' (BoundValueDeclaration sa b expr) = BoundValueDeclaration sa <$> (h' <=< h) b <*> (g' <=< g) expr
  f' other = f other

  g' :: Expr -> m Expr
  g' (Literal ss l) = Literal ss <$> litM (g >=> g') l
  g' (UnaryMinus ss v) = UnaryMinus ss <$> (g v >>= g')
  g' (BinaryNoParens op v1 v2) = BinaryNoParens <$> (g op >>= g') <*> (g v1 >>= g') <*> (g v2 >>= g')
  g' (Parens v) = Parens <$> (g v >>= g')
  g' (Accessor prop v) = Accessor prop <$> (g v >>= g')
  g' (ObjectUpdate obj vs) = ObjectUpdate <$> (g obj >>= g') <*> traverse (sndM (g' <=< g)) vs
  g' (ObjectUpdateNested obj vs) = ObjectUpdateNested <$> (g obj >>= g') <*> traverse (g' <=< g) vs
  g' (Abs binder v) = Abs <$> (h binder >>= h') <*> (g v >>= g')
  g' (App v1 v2) = App <$> (g v1 >>= g') <*> (g v2 >>= g')
  g' (VisibleTypeApp v ty) = VisibleTypeApp <$> (g v >>= g') <*> pure ty
  g' (Unused v) = Unused <$> (g v >>= g')
  g' (IfThenElse v1 v2 v3) = IfThenElse <$> (g v1 >>= g') <*> (g v2 >>= g') <*> (g v3 >>= g')
  g' (Case vs alts) = Case <$> traverse (g' <=< g) vs <*> traverse handleCaseAlternative alts
  g' (TypedValue check v ty) = TypedValue check <$> (g v >>= g') <*> pure ty
  g' (Let w ds v) = Let w <$> traverse (f' <=< f) ds <*> (g v >>= g')
  g' (Do m es) = Do m <$> traverse handleDoNotationElement es
  g' (Ado m es v) = Ado m <$> traverse handleDoNotationElement es <*> (g v >>= g')
  g' (PositionedValue pos com v) = PositionedValue pos com <$> (g v >>= g')
  g' other = g other

  h' :: Binder -> m Binder
  h' (LiteralBinder ss l) = LiteralBinder ss <$> litM (h >=> h') l
  h' (ConstructorBinder ss ctor bs) = ConstructorBinder ss ctor <$> traverse (h' <=< h) bs
  h' (BinaryNoParensBinder b1 b2 b3) = BinaryNoParensBinder <$> (h b1 >>= h') <*> (h b2 >>= h') <*> (h b3 >>= h')
  h' (ParensInBinder b) = ParensInBinder <$> (h b >>= h')
  h' (NamedBinder ss name b) = NamedBinder ss name <$> (h b >>= h')
  h' (PositionedBinder pos com b) = PositionedBinder pos com <$> (h b >>= h')
  h' (TypedBinder t b) = TypedBinder t <$> (h b >>= h')
  h' other = h other

  handleCaseAlternative :: CaseAlternative -> m CaseAlternative
  handleCaseAlternative (CaseAlternative bs val) =
    CaseAlternative
      <$> traverse (h' <=< h) bs
      <*> traverse (guardedExprM handleGuard (g' <=< g)) val

  handleDoNotationElement :: DoNotationElement -> m DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue <$> (g' <=< g) v
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind <$> (h' <=< h) b <*> (g' <=< g) v
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet <$> traverse (f' <=< f) ds
  handleDoNotationElement (PositionedDoNotationElement pos com e) =
    PositionedDoNotationElement pos com <$> handleDoNotationElement e

  handleGuard :: Guard -> m Guard
  handleGuard (ConditionGuard e) = ConditionGuard <$> (g' <=< g) e
  handleGuard (PatternGuard b e) = PatternGuard <$> (h' <=< h) b <*> (g' <=< g) e

-- | Monadic version of 'everywhereOnValues'.
--
-- Traversal is bottom-up: the children are rebuilt first and the resulting
-- node is then passed to the user function with @>>=@.
everywhereOnValuesM
  :: forall m
   . (Monad m)
  => (Declaration -> m Declaration)
  -> (Expr -> m Expr)
  -> (Binder -> m Binder)
  -> ( Declaration -> m Declaration
     , Expr -> m Expr
     , Binder -> m Binder
     )
everywhereOnValuesM f g h = (f', g', h')
  where
  f' :: Declaration -> m Declaration
  f' (DataBindingGroupDeclaration ds) = (DataBindingGroupDeclaration <$> traverse f' ds) >>= f
  f' (ValueDecl sa name nameKind bs val) =
    ValueDecl sa name nameKind <$> traverse h' bs <*> traverse (guardedExprM handleGuard g') val >>= f
  f' (BindingGroupDeclaration ds) =
    (BindingGroupDeclaration <$> traverse (\(name, nameKind, val) -> (name, nameKind, ) <$> g' val) ds) >>= f
  f' (BoundValueDeclaration sa b expr) = (BoundValueDeclaration sa <$> h' b <*> g' expr) >>= f
  f' (TypeClassDeclaration sa name args implies deps ds) =
    (TypeClassDeclaration sa name args implies deps <$> traverse f' ds) >>= f
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) =
    (TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse f') ds) >>= f
  f' other = f other

  g' :: Expr -> m Expr
  g' (Literal ss l) = (Literal ss <$> litM g' l) >>= g
  g' (UnaryMinus ss v) = (UnaryMinus ss <$> g' v) >>= g
  g' (BinaryNoParens op v1 v2) = (BinaryNoParens <$> g' op <*> g' v1 <*> g' v2) >>= g
  g' (Parens v) = (Parens <$> g' v) >>= g
  g' (Accessor prop v) = (Accessor prop <$> g' v) >>= g
  g' (ObjectUpdate obj vs) = (ObjectUpdate <$> g' obj <*> traverse (sndM g') vs) >>= g
  g' (ObjectUpdateNested obj vs) = (ObjectUpdateNested <$> g' obj <*> traverse g' vs) >>= g
  g' (Abs binder v) = (Abs <$> h' binder <*> g' v) >>= g
  g' (App v1 v2) = (App <$> g' v1 <*> g' v2) >>= g
  g' (VisibleTypeApp v ty) = (VisibleTypeApp <$> g' v <*> pure ty) >>= g
  g' (Unused v) = (Unused <$> g' v) >>= g
  g' (IfThenElse v1 v2 v3) = (IfThenElse <$> g' v1 <*> g' v2 <*> g' v3) >>= g
  g' (Case vs alts) = (Case <$> traverse g' vs <*> traverse handleCaseAlternative alts) >>= g
  g' (TypedValue check v ty) = (TypedValue check <$> g' v <*> pure ty) >>= g
  g' (Let w ds v) = (Let w <$> traverse f' ds <*> g' v) >>= g
  g' (Do m es) = (Do m <$> traverse handleDoNotationElement es) >>= g
  g' (Ado m es v) = (Ado m <$> traverse handleDoNotationElement es <*> g' v) >>= g
  g' (PositionedValue pos com v) = (PositionedValue pos com <$> g' v) >>= g
  g' other = g other

  h' :: Binder -> m Binder
  h' (LiteralBinder ss l) = (LiteralBinder ss <$> litM h' l) >>= h
  h' (ConstructorBinder ss ctor bs) = (ConstructorBinder ss ctor <$> traverse h' bs) >>= h
  h' (BinaryNoParensBinder b1 b2 b3) = (BinaryNoParensBinder <$> h' b1 <*> h' b2 <*> h' b3) >>= h
  h' (ParensInBinder b) = (ParensInBinder <$> h' b) >>= h
  h' (NamedBinder ss name b) = (NamedBinder ss name <$> h' b) >>= h
  h' (PositionedBinder pos com b) = (PositionedBinder pos com <$> h' b) >>= h
  h' (TypedBinder t b) = (TypedBinder t <$> h' b) >>= h
  h' other = h other

  handleCaseAlternative :: CaseAlternative -> m CaseAlternative
  handleCaseAlternative (CaseAlternative bs val) =
    CaseAlternative <$> traverse h' bs <*> traverse (guardedExprM handleGuard g') val

  handleDoNotationElement :: DoNotationElement -> m DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue <$> g' v
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind <$> h' b <*> g' v
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet <$> traverse f' ds
  handleDoNotationElement (PositionedDoNotationElement pos com e) =
    PositionedDoNotationElement pos com <$> handleDoNotationElement e

  handleGuard :: Guard -> m Guard
  handleGuard (ConditionGuard e) = ConditionGuard <$> g' e
  handleGuard (PatternGuard b e) = PatternGuard <$> h' b <*> g' e

-- | Fold over declarations, expressions, binders, case alternatives and
-- do-notation elements, combining results with the supplied operator.
--
-- A node's own result (from @f@, @g@, @h@, @i@ or @j@) comes first and is
-- followed by the results of its children, left to right. Guards have no
-- callback of their own; their binders and expressions are visited through
-- the binder and expression callbacks.
everythingOnValues
  :: forall r
   . (r -> r -> r)
  -> (Declaration -> r)
  -> (Expr -> r)
  -> (Binder -> r)
  -> (CaseAlternative -> r)
  -> (DoNotationElement -> r)
  -> ( Declaration -> r
     , Expr -> r
     , Binder -> r
     , CaseAlternative -> r
     , DoNotationElement -> r
     )
everythingOnValues (<>.) f g h i j = (f', g', h', i', j')
  where
  f' :: Declaration -> r
  f' d@(DataBindingGroupDeclaration ds) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(ValueDeclaration vd) =
    foldl (<>.) (f d) (fmap h' (valdeclBinders vd) ++ concatMap (\(GuardedExpr grd v) -> fmap k' grd ++ [g' v]) (valdeclExpression vd))
  f' d@(BindingGroupDeclaration ds) = foldl (<>.) (f d) (fmap (\(_, _, val) -> g' val) ds)
  f' d@(TypeClassDeclaration _ _ _ _ _ ds) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(BoundValueDeclaration _ b expr) = f d <>. h' b <>. g' expr
  f' d = f d

  g' :: Expr -> r
  g' v@(Literal _ l) = lit (g v) g' l
  g' v@(UnaryMinus _ v1) = g v <>. g' v1
  g' v@(BinaryNoParens op v1 v2) = g v <>. g' op <>. g' v1 <>. g' v2
  g' v@(Parens v1) = g v <>. g' v1
  g' v@(Accessor _ v1) = g v <>. g' v1
  g' v@(ObjectUpdate obj vs) = foldl (<>.) (g v <>. g' obj) (fmap (g' . snd) vs)
  g' v@(ObjectUpdateNested obj vs) = foldl (<>.) (g v <>. g' obj) (fmap g' vs)
  g' v@(Abs b v1) = g v <>. h' b <>. g' v1
  g' v@(App v1 v2) = g v <>. g' v1 <>. g' v2
  g' v@(VisibleTypeApp v' _) = g v <>. g' v'
  g' v@(Unused v1) = g v <>. g' v1
  g' v@(IfThenElse v1 v2 v3) = g v <>. g' v1 <>. g' v2 <>. g' v3
  g' v@(Case vs alts) = foldl (<>.) (foldl (<>.) (g v) (fmap g' vs)) (fmap i' alts)
  g' v@(TypedValue _ v1 _) = g v <>. g' v1
  g' v@(Let _ ds v1) = foldl (<>.) (g v) (fmap f' ds) <>. g' v1
  g' v@(Do _ es) = foldl (<>.) (g v) (fmap j' es)
  g' v@(Ado _ es v1) = foldl (<>.) (g v) (fmap j' es) <>. g' v1
  g' v@(PositionedValue _ _ v1) = g v <>. g' v1
  g' v = g v

  h' :: Binder -> r
  h' b@(LiteralBinder _ l) = lit (h b) h' l
  h' b@(ConstructorBinder _ _ bs) = foldl (<>.) (h b) (fmap h' bs)
  h' b@(BinaryNoParensBinder b1 b2 b3) = h b <>. h' b1 <>. h' b2 <>. h' b3
  h' b@(ParensInBinder b1) = h b <>. h' b1
  h' b@(NamedBinder _ _ b1) = h b <>. h' b1
  h' b@(PositionedBinder _ _ b1) = h b <>. h' b1
  h' b@(TypedBinder _ b1) = h b <>. h' b1
  h' b = h b

  -- Fold the elements of an array / object literal onto an initial result.
  lit :: r -> (a -> r) -> Literal a -> r
  lit r go (ArrayLiteral as) = foldl (<>.) r (fmap go as)
  lit r go (ObjectLiteral as) = foldl (<>.) r (fmap (go . snd) as)
  lit r _ _ = r

  i' :: CaseAlternative -> r
  i' ca@(CaseAlternative bs gs) =
    foldl (<>.) (i ca) (fmap h' bs ++ concatMap (\(GuardedExpr grd val) -> fmap k' grd ++ [g' val]) gs)

  j' :: DoNotationElement -> r
  j' e@(DoNotationValue v) = j e <>. g' v
  j' e@(DoNotationBind b v) = j e <>. h' b <>. g' v
  j' e@(DoNotationLet ds) = foldl (<>.) (j e) (fmap f' ds)
  j' e@(PositionedDoNotationElement _ _ e1) = j e <>. j' e1

  k' :: Guard -> r
  k' (ConditionGuard e) = g' e
  k' (PatternGuard b e) = h' b <>. g' e

-- | Like 'everythingOnValues', but threads a context value downwards.
--
-- Each callback receives the current context and a node, and returns a pair:
-- the context to use for that node's children and the node's own result. The
-- node's result is combined before its children's. @r0@ is the result for a
-- node with nothing to visit and is the seed of every fold.
--
-- NOTE(review): unlike 'everythingOnValues', 'BoundValueDeclaration' is not
-- descended into here (it falls through to @r0@); confirm this is intended.
everythingWithContextOnValues
  :: forall s r
   . s
  -> r
  -> (r -> r -> r)
  -> (s -> Declaration -> (s, r))
  -> (s -> Expr -> (s, r))
  -> (s -> Binder -> (s, r))
  -> (s -> CaseAlternative -> (s, r))
  -> (s -> DoNotationElement -> (s, r))
  -> ( Declaration -> r
     , Expr -> r
     , Binder -> r
     , CaseAlternative -> r
     , DoNotationElement -> r)
everythingWithContextOnValues s0 r0 (<>.) f g h i j = (f'' s0, g'' s0, h'' s0, i'' s0, j'' s0)
  where
  -- The double-primed functions run the user callback and then descend with
  -- the context it returned; the single-primed ones only descend.
  f'' :: s -> Declaration -> r
  f'' s d = let (s', r) = f s d in r <>. f' s' d

  f' :: s -> Declaration -> r
  f' s (DataBindingGroupDeclaration ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  f' s (ValueDeclaration vd) =
    foldl (<>.) r0 (fmap (h'' s) (valdeclBinders vd) ++ concatMap (\(GuardedExpr grd v) -> fmap (k' s) grd ++ [g'' s v]) (valdeclExpression vd))
  f' s (BindingGroupDeclaration ds) = foldl (<>.) r0 (fmap (\(_, _, val) -> g'' s val) ds)
  f' s (TypeClassDeclaration _ _ _ _ _ ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  f' s (TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldl (<>.) r0 (fmap (f'' s) ds)
  f' _ _ = r0

  g'' :: s -> Expr -> r
  g'' s v = let (s', r) = g s v in r <>. g' s' v

  g' :: s -> Expr -> r
  g' s (Literal _ l) = lit g'' s l
  g' s (UnaryMinus _ v1) = g'' s v1
  g' s (BinaryNoParens op v1 v2) = g'' s op <>. g'' s v1 <>. g'' s v2
  g' s (Parens v1) = g'' s v1
  g' s (Accessor _ v1) = g'' s v1
  g' s (ObjectUpdate obj vs) = foldl (<>.) (g'' s obj) (fmap (g'' s . snd) vs)
  g' s (ObjectUpdateNested obj vs) = foldl (<>.) (g'' s obj) (fmap (g'' s) vs)
  g' s (Abs binder v1) = h'' s binder <>. g'' s v1
  g' s (App v1 v2) = g'' s v1 <>. g'' s v2
  g' s (VisibleTypeApp v _) = g'' s v
  g' s (Unused v) = g'' s v
  g' s (IfThenElse v1 v2 v3) = g'' s v1 <>. g'' s v2 <>. g'' s v3
  g' s (Case vs alts) = foldl (<>.) (foldl (<>.) r0 (fmap (g'' s) vs)) (fmap (i'' s) alts)
  g' s (TypedValue _ v1 _) = g'' s v1
  g' s (Let _ ds v1) = foldl (<>.) r0 (fmap (f'' s) ds) <>. g'' s v1
  g' s (Do _ es) = foldl (<>.) r0 (fmap (j'' s) es)
  g' s (Ado _ es v1) = foldl (<>.) r0 (fmap (j'' s) es) <>. g'' s v1
  g' s (PositionedValue _ _ v1) = g'' s v1
  g' _ _ = r0

  h'' :: s -> Binder -> r
  h'' s b = let (s', r) = h s b in r <>. h' s' b

  h' :: s -> Binder -> r
  h' s (LiteralBinder _ l) = lit h'' s l
  h' s (ConstructorBinder _ _ bs) = foldl (<>.) r0 (fmap (h'' s) bs)
  h' s (BinaryNoParensBinder b1 b2 b3) = h'' s b1 <>. h'' s b2 <>. h'' s b3
  h' s (ParensInBinder b) = h'' s b
  h' s (NamedBinder _ _ b1) = h'' s b1
  h' s (PositionedBinder _ _ b1) = h'' s b1
  h' s (TypedBinder _ b1) = h'' s b1
  h' _ _ = r0

  lit :: (s -> a -> r) -> s -> Literal a -> r
  lit go s (ArrayLiteral as) = foldl (<>.) r0 (fmap (go s) as)
  lit go s (ObjectLiteral as) = foldl (<>.) r0 (fmap (go s . snd) as)
  lit _ _ _ = r0

  i'' :: s -> CaseAlternative -> r
  i'' s ca = let (s', r) = i s ca in r <>. i' s' ca

  i' :: s -> CaseAlternative -> r
  i' s (CaseAlternative bs gs) =
    foldl (<>.) r0 (fmap (h'' s) bs ++ concatMap (\(GuardedExpr grd val) -> fmap (k' s) grd ++ [g'' s val]) gs)

  j'' :: s -> DoNotationElement -> r
  j'' s e = let (s', r) = j s e in r <>. j' s' e

  j' :: s -> DoNotationElement -> r
  j' s (DoNotationValue v) = g'' s v
  j' s (DoNotationBind b v) = h'' s b <>. g'' s v
  j' s (DoNotationLet ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  j' s (PositionedDoNotationElement _ _ e1) = j'' s e1

  k' :: s -> Guard -> r
  k' s (ConditionGuard e) = g'' s e
  k' s (PatternGuard b e) = h'' s b <>. g'' s e

-- | Pure specialisation of 'everywhereWithContextOnValuesM': each callback is
-- lifted with 'pure' and the results are unwrapped with 'runIdentity'.
everywhereWithContextOnValues
  :: forall s
   . s
  -> (s -> Declaration       -> (s, Declaration))
  -> (s -> Expr              -> (s, Expr))
  -> (s -> Binder            -> (s, Binder))
  -> (s -> CaseAlternative   -> (s, CaseAlternative))
  -> (s -> DoNotationElement -> (s, DoNotationElement))
  -> (s -> Guard             -> (s, Guard))
  -> ( Declaration       -> Declaration
     , Expr              -> Expr
     , Binder            -> Binder
     , CaseAlternative   -> CaseAlternative
     , DoNotationElement -> DoNotationElement
     , Guard             -> Guard
     )
everywhereWithContextOnValues s f g h i j k =
  (runIdentity . f', runIdentity . g', runIdentity . h', runIdentity . i', runIdentity . j', runIdentity . k')
  where
  (f', g', h', i', j', k') = everywhereWithContextOnValuesM s (wrap f) (wrap g) (wrap h) (wrap i) (wrap j) (wrap k)
  -- Lift a two-argument pure callback into the monad.
  wrap = ((pure .) .)

-- | Top-down monadic rewriting with a context value threaded downwards.
--
-- For each node the user callback runs first and returns
-- @(new context, possibly rewritten node)@. The rewritten node's children are
-- then traversed with the new context. In case-alternative guards the context
-- returned by one guard is passed on to the next guard and to the guarded
-- expression (see @guardedExprM'@).
--
-- NOTE(review): value-declaration guards go through @guardedExprM (k' s)@.
-- That bypasses the guard callback @k@ and does not thread context from guards
-- to the body, unlike case alternatives. 'BoundValueDeclaration' is also not
-- descended into. Both are left unchanged; confirm whether they are intended.
everywhereWithContextOnValuesM
  :: forall m s
   . (Monad m)
  => s
  -> (s -> Declaration       -> m (s, Declaration))
  -> (s -> Expr              -> m (s, Expr))
  -> (s -> Binder            -> m (s, Binder))
  -> (s -> CaseAlternative   -> m (s, CaseAlternative))
  -> (s -> DoNotationElement -> m (s, DoNotationElement))
  -> (s -> Guard             -> m (s, Guard))
  -> ( Declaration       -> m Declaration
     , Expr              -> m Expr
     , Binder            -> m Binder
     , CaseAlternative   -> m CaseAlternative
     , DoNotationElement -> m DoNotationElement
     , Guard             -> m Guard
     )
everywhereWithContextOnValuesM s0 f g h i j k = (f'' s0, g'' s0, h'' s0, i'' s0, j'' s0, k'' s0)
  where
  -- The double-primed functions run the user callback and then descend with
  -- the context it returned; the single-primed ones only descend.
  f'' s = uncurry f' <=< f s

  f' s (DataBindingGroupDeclaration ds) = DataBindingGroupDeclaration <$> traverse (f'' s) ds
  f' s (ValueDecl sa name nameKind bs val) =
    ValueDecl sa name nameKind <$> traverse (h'' s) bs <*> traverse (guardedExprM (k' s) (g'' s)) val
  f' s (BindingGroupDeclaration ds) = BindingGroupDeclaration <$> traverse (thirdM (g'' s)) ds
  f' s (TypeClassDeclaration sa name args implies deps ds) =
    TypeClassDeclaration sa name args implies deps <$> traverse (f'' s) ds
  f' s (TypeInstanceDeclaration sa na ch idx name cs className args ds) =
    TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse (f'' s)) ds
  f' _ other = return other

  g'' s = uncurry g' <=< g s

  g' s (Literal ss l) = Literal ss <$> lit g'' s l
  g' s (UnaryMinus ss v) = UnaryMinus ss <$> g'' s v
  g' s (BinaryNoParens op v1 v2) = BinaryNoParens <$> g'' s op <*> g'' s v1 <*> g'' s v2
  g' s (Parens v) = Parens <$> g'' s v
  g' s (Accessor prop v) = Accessor prop <$> g'' s v
  g' s (ObjectUpdate obj vs) = ObjectUpdate <$> g'' s obj <*> traverse (sndM (g'' s)) vs
  g' s (ObjectUpdateNested obj vs) = ObjectUpdateNested <$> g'' s obj <*> traverse (g'' s) vs
  -- The lambda's binder goes through h'' (not h') so the binder callback runs
  -- on it, as it does for every other binder position in this traversal.
  g' s (Abs binder v) = Abs <$> h'' s binder <*> g'' s v
  g' s (App v1 v2) = App <$> g'' s v1 <*> g'' s v2
  g' s (VisibleTypeApp v ty) = VisibleTypeApp <$> g'' s v <*> pure ty
  g' s (Unused v) = Unused <$> g'' s v
  g' s (IfThenElse v1 v2 v3) = IfThenElse <$> g'' s v1 <*> g'' s v2 <*> g'' s v3
  g' s (Case vs alts) = Case <$> traverse (g'' s) vs <*> traverse (i'' s) alts
  g' s (TypedValue check v ty) = TypedValue check <$> g'' s v <*> pure ty
  g' s (Let w ds v) = Let w <$> traverse (f'' s) ds <*> g'' s v
  g' s (Do m es) = Do m <$> traverse (j'' s) es
  g' s (Ado m es v) = Ado m <$> traverse (j'' s) es <*> g'' s v
  g' s (PositionedValue pos com v) = PositionedValue pos com <$> g'' s v
  g' _ other = return other

  h'' s = uncurry h' <=< h s

  h' s (LiteralBinder ss l) = LiteralBinder ss <$> lit h'' s l
  h' s (ConstructorBinder ss ctor bs) = ConstructorBinder ss ctor <$> traverse (h'' s) bs
  h' s (BinaryNoParensBinder b1 b2 b3) = BinaryNoParensBinder <$> h'' s b1 <*> h'' s b2 <*> h'' s b3
  h' s (ParensInBinder b) = ParensInBinder <$> h'' s b
  h' s (NamedBinder ss name b) = NamedBinder ss name <$> h'' s b
  h' s (PositionedBinder pos com b) = PositionedBinder pos com <$> h'' s b
  h' s (TypedBinder t b) = TypedBinder t <$> h'' s b
  h' _ other = return other

  lit :: (s -> a -> m a) -> s -> Literal a -> m (Literal a)
  lit go s (ArrayLiteral as) = ArrayLiteral <$> traverse (go s) as
  lit go s (ObjectLiteral as) = ObjectLiteral <$> traverse (sndM (go s)) as
  lit _ _ other = return other

  i'' s = uncurry i' <=< i s

  i' s (CaseAlternative bs val) = CaseAlternative <$> traverse (h'' s) bs <*> traverse (guardedExprM' s) val

  -- A specialized `guardedExprM` that keeps track of the context `s`
  -- after traversing `guards`, such that it's also exposed to `expr`.
  guardedExprM' :: s -> GuardedExpr -> m GuardedExpr
  guardedExprM' s (GuardedExpr guards expr) = do
    (guards', s') <- runStateT (traverse (StateT . goGuard) guards) s
    GuardedExpr guards' <$> g'' s' expr

  -- Like k'', but `s` is tracked.
  goGuard :: Guard -> s -> m (Guard, s)
  goGuard x s = k s x >>= fmap swap . sndM' k'

  j'' s = uncurry j' <=< j s

  j' s (DoNotationValue v) = DoNotationValue <$> g'' s v
  j' s (DoNotationBind b v) = DoNotationBind <$> h'' s b <*> g'' s v
  j' s (DoNotationLet ds) = DoNotationLet <$> traverse (f'' s) ds
  j' s (PositionedDoNotationElement pos com e1) = PositionedDoNotationElement pos com <$> j'' s e1

  k'' s = uncurry k' <=< k s

  k' s (ConditionGuard e) = ConditionGuard <$> g'' s e
  k' s (PatternGuard b e) = PatternGuard <$> h'' s b <*> g'' s e

-- | An identifier tagged with where it was bound: locally (lambda, let, case,
-- do and guard binders) or at the top level (declaration names).
data ScopedIdent = LocalIdent Ident | ToplevelIdent Ident
  deriving (Show, Eq, Ord)

-- | True if the identifier is in the set, either as a local or as a
-- top-level identifier.
inScope :: Ident -> S.Set ScopedIdent -> Bool
inScope i s = (LocalIdent i `S.member` s) || (ToplevelIdent i `S.member` s)

-- | Monoidal fold over values that tracks the set of identifiers in scope.
--
-- Every callback receives the scope at the node being visited; a node's own
-- result is combined with '<>' before its children's. The scope grows as the
-- traversal descends:
--
-- * top-level names from data binding groups, binding groups and value
--   declarations;
-- * local names from lambda binders, @let@ bindings, case binders, pattern
--   guards and do-notation binds.
--
-- In @do@ blocks the scope accumulates from one element to the next. In
-- @ado@ blocks every element is visited with the enclosing scope, and only the
-- final expression sees the names they bind.
everythingWithScope
  :: forall r
   . (Monoid r)
  => (S.Set ScopedIdent -> Declaration -> r)
  -> (S.Set ScopedIdent -> Expr -> r)
  -> (S.Set ScopedIdent -> Binder -> r)
  -> (S.Set ScopedIdent -> CaseAlternative -> r)
  -> (S.Set ScopedIdent -> DoNotationElement -> r)
  -> ( S.Set ScopedIdent -> Declaration -> r
     , S.Set ScopedIdent -> Expr -> r
     , S.Set ScopedIdent -> Binder -> r
     , S.Set ScopedIdent -> CaseAlternative -> r
     , S.Set ScopedIdent -> DoNotationElement -> r
     )
everythingWithScope f g h i j = (f'', g'', h'', i'', \s -> snd . j'' s)
  where
  -- The double-primed functions combine the user callback's result with the
  -- result of descending into the node (the single-primed functions).
  f'' :: S.Set ScopedIdent -> Declaration -> r
  f'' s a = f s a <> f' s a

  f' :: S.Set ScopedIdent -> Declaration -> r
  f' s (DataBindingGroupDeclaration ds) =
    let s' = S.union s (S.fromList (map ToplevelIdent (mapMaybe getDeclIdent (NEL.toList ds))))
    in foldMap (f'' s') ds
  f' s (ValueDecl _ name _ bs val) =
    let s' = S.insert (ToplevelIdent name) s
        s'' = S.union s' (S.fromList (concatMap localBinderNames bs))
    in foldMap (h'' s') bs <> foldMap (l' s'') val
  f' s (BindingGroupDeclaration ds) =
    let s' = S.union s (S.fromList (NEL.toList (fmap (\((_, name), _, _) -> ToplevelIdent name) ds)))
    in foldMap (\(_, _, val) -> g'' s' val) ds
  f' s (TypeClassDeclaration _ _ _ _ _ ds) = foldMap (f'' s) ds
  f' s (TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldMap (f'' s) ds
  f' _ _ = mempty

  g'' :: S.Set ScopedIdent -> Expr -> r
  g'' s a = g s a <> g' s a

  g' :: S.Set ScopedIdent -> Expr -> r
  g' s (Literal _ l) = lit g'' s l
  g' s (UnaryMinus _ v1) = g'' s v1
  g' s (BinaryNoParens op v1 v2) = g'' s op <> g'' s v1 <> g'' s v2
  g' s (Parens v1) = g'' s v1
  g' s (Accessor _ v1) = g'' s v1
  g' s (ObjectUpdate obj vs) = g'' s obj <> foldMap (g'' s . snd) vs
  g' s (ObjectUpdateNested obj vs) = g'' s obj <> foldMap (g'' s) vs
  g' s (Abs b v1) =
    let s' = S.union (S.fromList (localBinderNames b)) s
    in h'' s b <> g'' s' v1
  g' s (App v1 v2) = g'' s v1 <> g'' s v2
  g' s (VisibleTypeApp v _) = g'' s v
  g' s (Unused v) = g'' s v
  g' s (IfThenElse v1 v2 v3) = g'' s v1 <> g'' s v2 <> g'' s v3
  g' s (Case vs alts) = foldMap (g'' s) vs <> foldMap (i'' s) alts
  g' s (TypedValue _ v1 _) = g'' s v1
  g' s (Let _ ds v1) =
    let s' = S.union s (S.fromList (map LocalIdent (mapMaybe getDeclIdent ds)))
    in foldMap (f'' s') ds <> g'' s' v1
  g' s (Do _ es) = fold . snd . mapAccumL j'' s $ es
  g' s (Ado _ es v1) =
    -- Every ado element is visited in the enclosing scope (its bindings are
    -- not visible to sibling elements). The names they bind are in scope only
    -- for the final expression. The elements' own results are kept; previously
    -- they were discarded, so nothing inside an ado block was visited.
    let visited = fmap (j'' s) es
        s' = S.union s (foldMap fst visited)
    in foldMap snd visited <> g'' s' v1
  g' s (PositionedValue _ _ v1) = g'' s v1
  g' _ _ = mempty

  h'' :: S.Set ScopedIdent -> Binder -> r
  h'' s a = h s a <> h' s a

  h' :: S.Set ScopedIdent -> Binder -> r
  h' s (LiteralBinder _ l) = lit h'' s l
  h' s (ConstructorBinder _ _ bs) = foldMap (h'' s) bs
  h' s (BinaryNoParensBinder b1 b2 b3) = foldMap (h'' s) [b1, b2, b3]
  h' s (ParensInBinder b) = h'' s b
  h' s (NamedBinder _ name b1) = h'' (S.insert (LocalIdent name) s) b1
  h' s (PositionedBinder _ _ b1) = h'' s b1
  h' s (TypedBinder _ b1) = h'' s b1
  h' _ _ = mempty

  lit :: (S.Set ScopedIdent -> a -> r) -> S.Set ScopedIdent -> Literal a -> r
  lit go s (ArrayLiteral as) = foldMap (go s) as
  lit go s (ObjectLiteral as) = foldMap (go s . snd) as
  lit _ _ _ = mempty

  i'' :: S.Set ScopedIdent -> CaseAlternative -> r
  i'' s a = i s a <> i' s a

  i' :: S.Set ScopedIdent -> CaseAlternative -> r
  i' s (CaseAlternative bs gs) =
    let s' = S.union s (S.fromList (concatMap localBinderNames bs))
    in foldMap (h'' s) bs <> foldMap (l' s') gs

  -- Do-notation elements also return the scope for the following element.
  j'' :: S.Set ScopedIdent -> DoNotationElement -> (S.Set ScopedIdent, r)
  j'' s a = let (s', r) = j' s a in (s', j s a <> r)

  j' :: S.Set ScopedIdent -> DoNotationElement -> (S.Set ScopedIdent, r)
  j' s (DoNotationValue v) = (s, g'' s v)
  j' s (DoNotationBind b v) =
    let s' = S.union (S.fromList (localBinderNames b)) s
    in (s', h'' s b <> g'' s v)
  j' s (DoNotationLet ds) =
    let s' = S.union s (S.fromList (map LocalIdent (mapMaybe getDeclIdent ds)))
    in (s', foldMap (f'' s') ds)
  j' s (PositionedDoNotationElement _ _ e1) = j'' s e1

  -- Guards return the scope for the following guard / guarded expression.
  k' :: S.Set ScopedIdent -> Guard -> (S.Set ScopedIdent, r)
  k' s (ConditionGuard e) = (s, g'' s e)
  k' s (PatternGuard b e) =
    -- NOTE(review): the guard's expression is visited with the names bound by
    -- its own pattern already in scope (s'). DoNotationBind uses the outer
    -- scope for its expression instead; confirm which is intended.
    let s' = S.union (S.fromList (localBinderNames b)) s
    in (s', h'' s b <> g'' s' e)

  -- Visit a guarded expression: guards left to right, each seeing the scope
  -- produced by the previous one, then the right-hand side.
  l' s (GuardedExpr [] e) = g'' s e
  l' s (GuardedExpr (grd:gs) e) =
    let (s', r) = k' s grd
    in r <> l' s' (GuardedExpr gs e)

  getDeclIdent :: Declaration -> Maybe Ident
  getDeclIdent (ValueDeclaration vd) = Just (valdeclIdent vd)
  getDeclIdent (TypeDeclaration td) = Just (tydeclIdent td)
  getDeclIdent _ = Nothing

  localBinderNames = map LocalIdent . binderNames

-- | Collect results of @f@ applied to the type annotations found in
-- declarations, expressions and binders, combined with 'mappend'.
--
-- Types are taken from:
--
-- * data, synonym, class, instance, kind, type and extern declarations;
-- * 'TypedValue', 'VisibleTypeApp', 'TypeClassDictionary' and
--   'DeferredDictionary' expressions;
-- * 'TypedBinder' binders.
--
-- Case alternatives and do-notation elements contribute nothing themselves;
-- their contents are still visited.
accumTypes
  :: (Monoid r)
  => (SourceType -> r)
  -> ( Declaration -> r
     , Expr -> r
     , Binder -> r
     , CaseAlternative -> r
     , DoNotationElement -> r
     )
accumTypes f = everythingOnValues mappend forDecls forValues forBinders (const mempty) (const mempty)
  where
  forDecls (DataDeclaration _ _ _ args dctors) =
    foldMap (foldMap f . snd) args <>
    foldMap (foldMap (f . snd) . dataCtorFields) dctors
  forDecls (ExternDataDeclaration _ _ ty) = f ty
  forDecls (ExternDeclaration _ _ ty) = f ty
  forDecls (TypeClassDeclaration _ _ args implies _ _) =
    foldMap (foldMap (foldMap f)) args <>
    foldMap (foldMap f . constraintArgs) implies
  forDecls (TypeInstanceDeclaration _ _ _ _ _ cs _ tys _) =
    foldMap (foldMap f . constraintArgs) cs <> foldMap f tys
  forDecls (TypeSynonymDeclaration _ _ args ty) =
    foldMap (foldMap f . snd) args <> f ty
  forDecls (KindDeclaration _ _ _ ty) = f ty
  forDecls (TypeDeclaration td) = f (tydeclType td)
  forDecls _ = mempty

  forValues (TypeClassDictionary c _ _) = foldMap f (constraintArgs c)
  forValues (DeferredDictionary _ tys) = foldMap f tys
  forValues (TypedValue _ _ ty) = f ty
  forValues (VisibleTypeApp _ ty) = f ty
  forValues _ = mempty

  forBinders (TypedBinder ty _) = f ty
  forBinders _ = mempty

-- |
-- Map a function over type annotations appearing inside a value.
--
-- Rewrites the type of every 'TypedValue' and, in every
-- 'TypeClassDictionary', both the constraint's arguments and the instance
-- types of the dictionaries stored in the scope map under the
-- 'ByNullSourcePos' key. Other annotations (e.g. on 'VisibleTypeApp') are left
-- untouched.
--
overTypes :: (SourceType -> SourceType) -> Expr -> Expr
overTypes f = let (_, f', _) = everywhereOnValues id g id in f'
  where
  g :: Expr -> Expr
  g (TypedValue checkTy val t) = TypedValue checkTy val (f t)
  g (TypeClassDictionary c sco hints) =
    TypeClassDictionary
      (mapConstraintArgs (fmap f) c)
      (updateCtx sco)
      hints
  g other = other
  updateDict fn dict = dict { tcdInstanceTypes = fn (tcdInstanceTypes dict) }
  updateScope = fmap . fmap . fmap . fmap $ updateDict $ fmap f
  updateCtx = M.alter updateScope ByNullSourcePos