-- | Facilities for changing the representation of some fragment,
-- within a monadic context.  We call this "rephrasing", for no deep
-- reason.
module Futhark.IR.Rephrase
  ( rephraseProg,
    rephraseFunDef,
    rephraseExp,
    rephraseBody,
    rephraseStm,
    rephraseLambda,
    rephrasePat,
    rephrasePatElem,
    Rephraser (..),
    RephraseOp (..),
  )
where

import Data.Bitraversable
import Futhark.IR.Syntax
import Futhark.IR.Traversals

-- | A collection of functions that together allow us to rephrase some
-- IR fragment, in some monad @m@.  There is one function per kind of
-- representation-specific decoration, plus one for 'Op'.  If we let
-- @m@ be the 'Maybe' monad, we can conveniently do rephrasing that
-- might fail.  This is useful if you want to see if some IR in e.g. the
-- @Kernels@ rep actually uses any @Kernels@-specific operations.
data Rephraser m from to = Rephraser
  { -- | Convert the expression decoration stored in a statement's 'StmAux'.
    rephraseExpDec :: ExpDec from -> m (ExpDec to),
    -- | Convert the decoration of each let-bound pattern element.
    rephraseLetBoundDec :: LetDec from -> m (LetDec to),
    -- | Convert the decoration of a function parameter.
    rephraseFParamDec :: FParamInfo from -> m (FParamInfo to),
    -- | Convert the decoration of a lambda parameter.
    rephraseLParamDec :: LParamInfo from -> m (LParamInfo to),
    -- | Convert the decoration attached to a 'Body'.
    rephraseBodyDec :: BodyDec from -> m (BodyDec to),
    -- | Convert a function return type.
    rephraseRetType :: RetType from -> m (RetType to),
    -- | Convert a 'BranchType'.
    rephraseBranchType :: BranchType from -> m (BranchType to),
    -- | Convert a representation-specific operation.
    rephraseOp :: Op from -> m (Op to)
  }

-- | Rephrase an entire program: its top-level constants and all of its
-- function definitions.  The remaining fields of the 'Prog' are copied
-- unchanged.
rephraseProg :: (Monad m) => Rephraser m from to -> Prog from -> m (Prog to)
rephraseProg rephraser prog = do
  consts <- mapM (rephraseStm rephraser) (progConsts prog)
  funs <- mapM (rephraseFunDef rephraser) (progFuns prog)
  pure $ prog {progConsts = consts, progFuns = funs}

-- | Rephrase a function definition: its body, its parameters, and its
-- return types.  The remaining fields of the 'FunDef' are copied
-- unchanged.
rephraseFunDef :: (Monad m) => Rephraser m from to -> FunDef from -> m (FunDef to)
rephraseFunDef rephraser fundec = do
  body' <- rephraseBody rephraser $ funDefBody fundec
  params' <- mapM (rephraseParam $ rephraseFParamDec rephraser) $ funDefParams fundec
  -- Each return type is paired with a 'RetAls', which does not depend on
  -- the representation, so only the first component is rephrased.
  rettype' <- mapM (bitraverse (rephraseRetType rephraser) pure) $ funDefRetType fundec
  pure fundec {funDefBody = body', funDefParams = params', funDefRetType = rettype'}

-- | Rephrase an expression.  The 'Rephraser' is turned into a 'Mapper'
-- (see 'mapper') and handed to 'mapExpM'.  That mapper rephrases the
-- bodies, return and branch types, parameters and 'Op's that 'mapExpM'
-- visits.
rephraseExp :: (Monad m) => Rephraser m from to -> Exp from -> m (Exp to)
rephraseExp = mapExpM . mapper

-- | Rephrase a statement.  This covers the decorations of its pattern,
-- the expression decoration held in its 'StmAux', and the expression
-- itself.  Certificates and attributes are carried over unchanged.
rephraseStm :: (Monad m) => Rephraser m from to -> Stm from -> m (Stm to)
rephraseStm rephraser (Let pat (StmAux cs attrs dec) e) =
  Let
    <$> rephrasePat (rephraseLetBoundDec rephraser) pat
    <*> (StmAux cs attrs <$> rephraseExpDec rephraser dec)
    <*> rephraseExp rephraser e

-- | Rephrase a pattern by applying the given function to its
-- decorations, via the 'Traversable' instance of 'Pat'.
rephrasePat ::
  (Monad m) =>
  (from -> m to) ->
  Pat from ->
  m (Pat to)
rephrasePat = traverse

-- | Rephrase a pattern element by converting its decoration.  The bound
-- name is kept as-is.
rephrasePatElem :: (Monad m) => (from -> m to) -> PatElem from -> m (PatElem to)
rephrasePatElem rephraser (PatElem ident from) =
  PatElem ident <$> rephraser from

-- | Rephrase a parameter by converting its decoration.  Its attributes
-- and name are kept as-is.
rephraseParam :: (Monad m) => (from -> m to) -> Param from -> m (Param to)
rephraseParam rephraser (Param attrs name from) =
  Param attrs name <$> rephraser from

-- | Rephrase a body: its decoration and each of its statements, in
-- order.  The 'Result' does not depend on the representation and is
-- kept unchanged.
rephraseBody :: (Monad m) => Rephraser m from to -> Body from -> m (Body to)
rephraseBody rephraser (Body dec stms res) =
  Body
    <$> rephraseBodyDec rephraser dec
    -- 'Stms' is a 'Seq' (as 'rephraseProg' also relies on), so traverse
    -- it directly rather than round-tripping through a list.
    <*> traverse (rephraseStm rephraser) stms
    <*> pure res

-- | Rephrase a lambda: its body and its parameters.  Any remaining
-- fields of the 'Lambda' are copied unchanged.
rephraseLambda :: (Monad m) => Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda rephraser lam = do
  body' <- rephraseBody rephraser $ lambdaBody lam
  params' <- mapM (rephraseParam $ rephraseLParamDec rephraser) $ lambdaParams lam
  pure lam {lambdaBody = body', lambdaParams = params'}

-- | Build a 'Mapper' that applies the given 'Rephraser' to the bodies,
-- return and branch types, parameters and 'Op's visited by 'mapExpM'.
-- Every other mapping function is inherited from 'identityMapper'.
mapper :: (Monad m) => Rephraser m from to -> Mapper from to m
mapper rephraser =
  identityMapper
    { -- The first argument passed to 'mapOnBody' is not needed here.
      mapOnBody = const $ rephraseBody rephraser,
      mapOnRetType = rephraseRetType rephraser,
      mapOnBranchType = rephraseBranchType rephraser,
      mapOnFParam = rephraseParam (rephraseFParamDec rephraser),
      mapOnLParam = rephraseParam (rephraseLParamDec rephraser),
      mapOnOp = rephraseOp rephraser
    }

-- | Operations whose contained IR fragments can be rephrased from one
-- representation to another.
class RephraseOp op where
  -- | Rephrase the fragments inside an operation using the supplied
  -- 'Rephraser'.
  rephraseInOp ::
    (Monad m) =>
    Rephraser m from to ->
    op from ->
    m (op to)

-- | 'NoOp' carries no IR fragments, so rephrasing it just produces a
-- fresh 'NoOp' at the target representation.
instance RephraseOp NoOp where
  rephraseInOp _ NoOp = pure NoOp