{-# LANGUAGE ConstraintKinds #-}

-- | Facilities for changing the lore of some fragment, with no
-- context.  We call this "rephrasing", for no deep reason.
module Futhark.Analysis.Rephrase
  ( rephraseProg,
    rephraseFunDef,
    rephraseExp,
    rephraseBody,
    rephraseStm,
    rephraseLambda,
    rephrasePattern,
    rephrasePatElem,
    Rephraser (..),
  )
where

import Futhark.IR

-- | A collection of functions that together allow us to rephrase some
-- IR fragment, in some monad @m@.  If we let @m@ be the 'Maybe'
-- monad, we can conveniently do rephrasing that might fail.  This is
-- useful if you want to see if some IR in e.g. the @Kernels@ lore
-- actually uses any @Kernels@-specific operations.
data Rephraser m from to = Rephraser
  { -- | Rephrase the decoration stored in the 'StmAux' of a statement.
    rephraseExpLore :: ExpDec from -> m (ExpDec to),
    -- | Rephrase the decoration of each element of a statement's pattern.
    rephraseLetBoundLore :: LetDec from -> m (LetDec to),
    -- | Rephrase the decoration of a function parameter.
    rephraseFParamLore :: FParamInfo from -> m (FParamInfo to),
    -- | Rephrase the decoration of a lambda parameter.
    rephraseLParamLore :: LParamInfo from -> m (LParamInfo to),
    -- | Rephrase the decoration attached to a body.
    rephraseBodyLore :: BodyDec from -> m (BodyDec to),
    -- | Rephrase a single return type (e.g. of a function definition).
    rephraseRetType :: RetType from -> m (RetType to),
    -- | Rephrase a single branch type encountered inside an expression.
    rephraseBranchType :: BranchType from -> m (BranchType to),
    -- | Rephrase a lore-specific operation.
    rephraseOp :: Op from -> m (Op to)
  }

-- | Rephrase an entire program.
--
-- The constant statements are rephrased first (in order), followed by
-- every function definition (in order).  If any step fails in @m@, the
-- whole rephrasing fails.
rephraseProg :: Monad m => Rephraser m from to -> Prog from -> m (Prog to)
rephraseProg rephraser (Prog consts funs) =
  Prog
    <$> traverse (rephraseStm rephraser) consts
    <*> traverse (rephraseFunDef rephraser) funs

-- | Rephrase a function definition.
--
-- The body, the parameters and the return types are rephrased (in that
-- order); every other field of the definition is kept as-is.
rephraseFunDef :: Monad m => Rephraser m from to -> FunDef from -> m (FunDef to)
rephraseFunDef rephraser fundec = do
  body' <- rephraseBody rephraser $ funDefBody fundec
  params' <- traverse (rephraseParam $ rephraseFParamLore rephraser) $ funDefParams fundec
  rettype' <- traverse (rephraseRetType rephraser) $ funDefRetType fundec
  -- All lore-dependent fields must be replaced in a single record
  -- update for the result to change type from @FunDef from@ to
  -- @FunDef to@.
  pure
    fundec
      { funDefBody = body',
        funDefParams = params',
        funDefRetType = rettype'
      }

-- | Rephrase an expression.
--
-- This runs 'mapExpM' with a 'Mapper' built from the given
-- 'Rephraser', so bodies, parameters, types and operations occurring
-- in the expression are rephraseed via the corresponding rephraser
-- functions.
rephraseExp :: Monad m => Rephraser m from to -> Exp from -> m (Exp to)
rephraseExp = mapExpM . mapper

-- | Rephrase a statement.
--
-- The pattern is rephrased with 'rephraseLetBoundLore', the decoration
-- in the 'StmAux' with 'rephraseExpLore', and the expression with
-- 'rephraseExp'.  The certificates and attributes of the 'StmAux' are
-- carried over unchanged.
rephraseStm :: Monad m => Rephraser m from to -> Stm from -> m (Stm to)
rephraseStm rephraser (Let pat (StmAux cs attrs dec) e) =
  Let
    <$> rephrasePattern (rephraseLetBoundLore rephraser) pat
    <*> (StmAux cs attrs <$> rephraseExpLore rephraser dec)
    <*> rephraseExp rephraser e

-- | Rephrase a pattern by applying the given function to every
-- decoration in it.  This is simply 'traverse' at the pattern type.
rephrasePattern ::
  Monad m =>
  (from -> m to) ->
  PatternT from ->
  m (PatternT to)
rephrasePattern = traverse

-- | Rephrase a pattern element.  The name of the element is kept;
-- only its decoration is passed through the given function.
rephrasePatElem :: Monad m => (from -> m to) -> PatElemT from -> m (PatElemT to)
rephrasePatElem rephraser (PatElem ident dec) =
  PatElem ident <$> rephraser dec

-- | Rephrase a parameter.  The name of the parameter is kept; only
-- its decoration is passed through the given function.
rephraseParam :: Monad m => (from -> m to) -> Param from -> m (Param to)
rephraseParam rephraser (Param name dec) =
  Param name <$> rephraser dec

-- | Rephrase a body.
--
-- The body decoration is rephrased with 'rephraseBodyLore', then each
-- statement (in order) with 'rephraseStm'.  The result of the body is
-- carried over unchanged.
rephraseBody :: Monad m => Rephraser m from to -> Body from -> m (Body to)
rephraseBody rephraser (Body dec stms res) =
  Body
    <$> rephraseBodyLore rephraser dec
    -- 'Stms' is traversable in its own right (see 'rephraseProg'), so
    -- there is no need to go via an intermediate list.
    <*> traverse (rephraseStm rephraser) stms
    <*> pure res

-- | Rephrase a lambda.
--
-- The body and the parameters are rephrased (in that order); every
-- other field of the lambda is kept as-is.
rephraseLambda :: Monad m => Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda rephraser lam = do
  body' <- rephraseBody rephraser $ lambdaBody lam
  params' <- traverse (rephraseParam $ rephraseLParamLore rephraser) $ lambdaParams lam
  -- Both lore-dependent fields must be replaced in a single record
  -- update for the result to change type from @Lambda from@ to
  -- @Lambda to@.
  pure lam {lambdaBody = body', lambdaParams = params'}

-- | Build the 'Mapper' used by 'rephraseExp' from a 'Rephraser'.
--
-- Starts from 'identityMapper' and overrides the handlers for bodies,
-- return types, branch types, function and lambda parameters, and
-- operations; all remaining handlers keep their 'identityMapper'
-- definitions.
mapper :: Monad m => Rephraser m from to -> Mapper from to m
mapper rephraser =
  identityMapper
    { -- The scope argument supplied to 'mapOnBody' is ignored;
      -- rephrasing is context-free.
      mapOnBody = const $ rephraseBody rephraser,
      mapOnRetType = rephraseRetType rephraser,
      mapOnBranchType = rephraseBranchType rephraser,
      mapOnFParam = rephraseParam (rephraseFParamLore rephraser),
      mapOnLParam = rephraseParam (rephraseLParamLore rephraser),
      mapOnOp = rephraseOp rephraser
    }