{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Binder.Class
( Bindable (..),
mkLet,
mkLet',
MonadBinder (..),
insertStms,
insertStm,
letBind,
letBindNames,
collectStms_,
bodyBind,
attributing,
auxing,
module Futhark.MonadFreshNames,
)
where
import qualified Data.Kind
import Futhark.IR
import Futhark.MonadFreshNames
-- | Lores whose patterns, expression decorations and bodies can be
-- built purely from their parts, and whose statements can be built
-- in any monad that supplies fresh names and a scope.
class
  ( ASTLore lore,
    SetType (LetDec lore),
    FParamInfo lore ~ DeclType,
    LParamInfo lore ~ Type,
    RetType lore ~ DeclExtType,
    BranchType lore ~ ExtType
  ) =>
  Bindable lore
  where
  -- | Build the pattern for an expression from the identifiers of
  -- its context part and its value part.
  mkExpPat ::
    [Ident] ->
    [Ident] ->
    Exp lore ->
    Pattern lore

  -- | Compute the expression decoration for a pattern/expression pair.
  mkExpDec ::
    Pattern lore ->
    Exp lore ->
    ExpDec lore

  -- | Assemble a body from its statements and its result.
  mkBody ::
    Stms lore ->
    Result ->
    Body lore

  -- | Create a statement binding the given names to an expression.
  mkLetNames ::
    (MonadFreshNames m, HasScope lore m) =>
    [VName] ->
    Exp lore ->
    m (Stm lore)

-- | A monad in which code of a specific lore can be constructed.
-- Instances keep track of the 'Stms' under construction (added with
-- 'addStm' / 'addStms'), provide fresh names, and provide a
-- 'LocalScope' for the lore.
class
  ( ASTLore (Lore m),
    MonadFreshNames m,
    Applicative m,
    Monad m,
    LocalScope (Lore m) m
  ) =>
  MonadBinder m
  where
  -- | The lore of the code being constructed.
  type Lore m :: Data.Kind.Type

  -- | Construct the expression decoration for a pattern and expression.
  mkExpDecM :: Pattern (Lore m) -> Exp (Lore m) -> m (ExpDec (Lore m))

  -- | Construct a body from statements and a result.
  mkBodyM :: Stms (Lore m) -> Result -> m (Body (Lore m))

  -- | Construct a statement binding the given names to an expression.
  mkLetNamesM :: [VName] -> Exp (Lore m) -> m (Stm (Lore m))

  -- | Add a single statement to the 'Stms' under construction.  By
  -- default this is 'addStms' applied to a singleton sequence.
  addStm :: Stm (Lore m) -> m ()
  addStm = addStms . oneStm

  -- | Add several statements to the 'Stms' under construction.
  addStms :: Stms (Lore m) -> m ()

  -- | Run an action and return its result paired with the statements
  -- it produced.  NOTE(review): 'censorStms' re-adds these statements
  -- itself, which suggests instances must NOT also add them to the
  -- enclosing state — confirm against instances.
  collectStms :: m a -> m (a, Stms (Lore m))

  -- | Run an action, applying 'certify' with the given certificates to
  -- every statement it adds.
  certifying :: Certificates -> m a -> m a
  certifying = censorStms . fmap . certify

-- | Run an action, but pass the statements it adds through the given
-- function before adding them to the enclosing 'Stms'.  The action's
-- own result is returned unchanged.
censorStms ::
  MonadBinder m =>
  (Stms (Lore m) -> Stms (Lore m)) ->
  m a ->
  m a
censorStms f action = do
  (x, stms) <- collectStms action
  addStms $ f stms
  pure x

-- | Run an action, combining the given attributes (on the left of
-- '<>') with the attributes of every statement it adds.
attributing :: MonadBinder m => Attrs -> m a -> m a
attributing attrs = censorStms $ fmap addAttrs
  where
    addAttrs (Let pat aux e) =
      Let pat aux {stmAuxAttrs = attrs <> stmAuxAttrs aux} e

-- | Run an action, combining the certificates and attributes of the
-- given 'StmAux' (on the left of '<>') with those of every statement
-- it adds.  The decoration stored in the given 'StmAux' is ignored,
-- which is why its lore may be arbitrary.
auxing :: MonadBinder m => StmAux anylore -> m a -> m a
auxing (StmAux cs attrs _) = censorStms $ fmap addAux
  where
    addAux (Let pat aux e) = Let pat aux' e
      where
        aux' =
          aux
            { stmAuxAttrs = attrs <> stmAuxAttrs aux,
              stmAuxCerts = cs <> stmAuxCerts aux
            }

-- | Add a statement binding the given pattern to the expression.  The
-- expression decoration is computed with 'mkExpDecM' and wrapped with
-- 'defAux'.
letBind ::
  MonadBinder m =>
  Pattern (Lore m) ->
  Exp (Lore m) ->
  m ()
letBind pat e = do
  dec <- mkExpDecM pat e
  addStm $ Let pat (defAux dec) e

-- | Construct a 'Stm' from the identifiers of the context part and the
-- value part of the pattern, plus the expression.  The auxiliary
-- information is produced by 'defAux'.
mkLet :: Bindable lore => [Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet ctx val e =
  let pat = mkExpPat ctx val e
      dec = mkExpDec pat e
   in Let pat (defAux dec) e

-- | Like 'mkLet', but take the certificates and attributes from the
-- given 'StmAux'.  That 'StmAux''s own decoration is discarded in
-- favour of the one computed by 'mkExpDec'.
mkLet' :: Bindable lore => [Ident] -> [Ident] -> StmAux a -> Exp lore -> Stm lore
mkLet' ctx val (StmAux cs attrs _) e =
  let pat = mkExpPat ctx val e
      dec = mkExpDec pat e
   in Let pat (StmAux cs attrs dec) e

-- | Add a statement binding the given names to the expression, built
-- with 'mkLetNamesM'.
letBindNames :: MonadBinder m => [VName] -> Exp (Lore m) -> m ()
letBindNames names e = addStm =<< mkLetNamesM names e

-- | As 'collectStms', but discard the action's ordinary result.
collectStms_ :: MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ = fmap snd . collectStms

-- | Add the statements of a body, then return the body's result.  The
-- body decoration is ignored.
bodyBind :: MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind (Body _ stms es) = do
  addStms stms
  pure es

-- | Put the given statements in front of the statements of a body.
-- The body is rebuilt with 'mkBody', so its old decoration is
-- recomputed rather than kept.
insertStms :: Bindable lore => Stms lore -> Body lore -> Body lore
insertStms stms1 (Body _ stms2 res) = mkBody (stms1 <> stms2) res

-- | Put a single statement in front of the statements of a body, as
-- 'insertStms' with a singleton sequence.
insertStm :: Bindable lore => Stm lore -> Body lore -> Body lore
insertStm = insertStms . oneStm