-- | Lambda-lifting of typed, monomorphic Futhark programs without
-- modules.  After this pass, the program will no longer contain any
-- 'LetFun's or 'Lambda's.
module Futhark.Internalise.LiftLambdas (transformProg) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Foldable
import Data.List (partition)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Futhark.Util (nubOrd)
import Language.Futhark
import Language.Futhark.Traversals

-- | Read-only scoping information for the lambda-lifting pass.
data Env = Env
  { -- | Local functions that have already been lifted, mapped to the
    -- expression (a call of the lifted top-level function applied to its
    -- free variables) that replaces every reference to them.
    envReplace :: M.Map VName Exp,
    -- | Types of the locally bound variables currently in scope.  Only
    -- names found here are treated as free variables when a function is
    -- lifted.
    envVtable :: M.Map VName StructType
  }

-- | The environment at the top level: nothing has been lifted yet and no
-- local variables are in scope.
initialEnv :: Env
initialEnv = Env {envReplace = M.empty, envVtable = M.empty}

-- | Mutable state of the lambda-lifting pass.
data LiftState = State
  { -- | Supply of fresh names, exposed through the 'MonadFreshNames'
    -- instance of 'LiftM'.
    stateNameSource :: VNameSource,
    -- | Value bindings emitted so far, most recently emitted first.
    stateValBinds :: [ValBind],
    -- | Intrinsics plus every name bound by an emitted value binding.
    -- NOTE(review): this field is written by 'initialState' and
    -- 'addValBind' but not read anywhere in this module; confirm whether it
    -- is still needed.
    stateGlobal :: S.Set VName
  }

-- | Starting state: the given name source, no emitted bindings, and the
-- names of all intrinsics registered as globals.
initialState :: VNameSource -> LiftState
initialState src =
  State
    { stateNameSource = src,
      stateValBinds = [],
      stateGlobal = M.keysSet intrinsics
    }

-- | The lambda-lifting monad.  Scoping information ('Env') is read-only,
-- while fresh names and emitted bindings ('LiftState') are threaded as
-- state.
newtype LiftM a = LiftM (ReaderT Env (State LiftState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadState LiftState
    )

-- | Fresh names are read from and written back to 'stateNameSource'.
instance MonadFreshNames LiftM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \st -> st {stateNameSource = src}

-- | Run a lifting computation from the top-level environment.  Returns the
-- emitted value bindings in the order they were emitted, together with the
-- final name source.
runLiftM :: VNameSource -> LiftM () -> ([ValBind], VNameSource)
runLiftM src (LiftM m) = (reverse (stateValBinds final), stateNameSource final)
  where
    -- 'addValBind' conses onto the list, hence the 'reverse' above.
    final = execState (runReaderT m initialEnv) (initialState src)

-- | Emit a value binding and record the names it binds as global.
addValBind :: ValBind -> LiftM ()
addValBind vb = modify register
  where
    register s =
      s
        { -- Newest first; 'runLiftM' restores emission order.
          stateValBinds = vb : stateValBinds s,
          stateGlobal = foldl' (\acc v -> S.insert v acc) (stateGlobal s) (valBindBound vb)
        }

-- | Run an action in which references to the variable @v@ are to be
-- rewritten to the expression @e@.  The mapping for @v@ overrides any
-- previous one.
replacing :: VName -> Exp -> LiftM a -> LiftM a
replacing v e = local extend
  where
    extend env = env {envReplace = M.insert v e (envReplace env)}

-- | Run an action with the variables bound by the parameter patterns, and
-- the given size names (typed as @i64@), added to the vtable.  The new
-- bindings shadow any existing entries with the same names.
bindingParams :: [VName] -> [Pat ParamType] -> LiftM a -> LiftM a
bindingParams sizes params =
  local $ \env -> env {envVtable = M.union bound (envVtable env)}
  where
    -- Later list entries win in 'M.fromList', so sizes take precedence over
    -- parameters should a name occur in both.
    bound =
      M.fromList $
        [(v, toStruct t) | (v, t) <- foldMap patternMap params]
          ++ [(v, i64) | v <- sizes]
    i64 = Scalar (Prim (Signed Int64))

-- | Run an action with the variables bound by a let-style pattern, and the
-- given size names (typed as @i64@), added to the vtable.  The new bindings
-- shadow any existing entries with the same names.
bindingLetPat :: [VName] -> Pat StructType -> LiftM a -> LiftM a
bindingLetPat sizes pat =
  local $ \env -> env {envVtable = M.union bound (envVtable env)}
  where
    -- The pattern already carries 'StructType's, so no conversion is needed.
    bound = M.fromList $ patternMap pat ++ zip sizes (repeat i64)
    i64 = Scalar (Prim (Signed Int64))

-- | Run an action with the variables bound by a loop form in scope: the
-- index variable of a @for@ loop, the element pattern of a @for ... in@
-- loop, and nothing for a @while@ loop.
bindingForm :: LoopFormBase Info VName -> LiftM a -> LiftM a
bindingForm form = case form of
  For i _ -> bindingLetPat [] $ Id (identName i) (identType i) mempty
  ForIn p _ -> bindingLetPat [] p
  While {} -> id

-- | Replace the uniqueness component of a type with 'Nonunique' (via the
-- 'Bifunctor' instance of 'TypeBase').
toRet :: TypeBase Size u -> TypeBase Size Uniqueness
toRet = second (\_ -> Nonunique)

-- | Emit a new top-level function named @fname@ with body @funbody@, and
-- return the expression that should replace the original local function.
--
-- Extra leading parameters are added for every variable bound in the
-- current 'envVtable' that occurs free in the body, or in the types of the
-- parameters, the return type, or the free variables themselves.  Free
-- variables whose names occur in any of those types are placed before the
-- rest.  The returned expression is the lifted function applied to exactly
-- these free variables, one argument at a time.
liftFunction :: VName -> [TypeParam] -> [Pat ParamType] -> ResRetType -> Exp -> LiftM Exp
liftFunction fname tparams params (RetType dims ret) funbody = do
  vtable <- asks envVtable
  let -- Keep only the locally bound names, paired with their types.
      locallyBound fv =
        [(v, t) | v <- S.toList (fvVars fv), Just t <- [M.lookup v vtable]]
      -- Variables referenced directly by the body.
      direct = locallyBound (freeInExp funbody)
      -- Variables referenced by the relevant types.
      in_types =
        locallyBound $
          foldMap (freeInType . snd) direct
            <> foldMap freeInPat params
            <> freeInType ret
      free = nubOrd (direct <> in_types)

      -- Those parameters that correspond to sizes must come first
      -- (presumably so they are bound before any parameter whose type
      -- mentions them).
      size_names =
        fvVars . foldMap freeInType $
          toStruct ret : map snd free ++ map patternStructType params
      (free_dims, free_nondims) = partition (\(v, _) -> v `S.member` size_names) free
      free_args = free_dims ++ free_nondims
      free_params = [(v, t `setUniqueness` Nonunique) | (v, t) <- free_args]

  addValBind $
    ValBind
      { valBindName = fname,
        valBindTypeParams = tparams,
        valBindParams = map mkParam free_params ++ params,
        valBindRetDecl = Nothing,
        valBindRetType = Info (RetType dims ret),
        valBindBody = funbody,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty,
        valBindEntryPoint = Nothing
      }

  pure $ applyFree (Var (qualName fname) (Info (augType free_params)) mempty) free_args
  where
    orig_type = funType params (RetType dims ret)

    -- An observing parameter binding a free variable.
    mkParam :: (VName, TypeBase Size u) -> Pat ParamType
    mkParam (v, t) = Id v (Info (toParam Observe t)) mempty

    -- Type of the lifted function once all but the given free variables
    -- have been supplied.
    augType :: [(VName, TypeBase Size u)] -> StructType
    augType rest = funType (map mkParam rest) (RetType [] (toRet orig_type))

    -- Apply the function to each free variable in turn, giving every
    -- intermediate application its precise (partially applied) type.
    applyFree :: Exp -> [(VName, StructType)] -> Exp
    applyFree f [] = f
    applyFree f ((v, t) : rest) =
      let arg = Var (qualName v) (Info t) mempty
          f' = mkApply f [(Observe, Nothing, arg)] (AppRes (augType rest) mempty)
       in applyFree f' rest

-- | An AST mapper that runs 'transformExp' on every immediate
-- subexpression and leaves everything else untouched.
transformSubExps :: ASTMapper LiftM
transformSubExps = identityMapper {mapOnExp = transformExp}

-- | Lift every 'LetFun' and 'Lambda' in an expression to the top level.
-- Binding constructs extend 'envVtable' so that the free variables of lifted
-- functions can be determined.  References to lifted local functions are
-- rewritten via 'envReplace'.
--
-- NOTE(review): any other construct that binds variables is handled by the
-- fallback case without extending 'envVtable'.  This presumably relies on
-- earlier passes having removed such constructs; confirm.
transformExp :: Exp -> LiftM Exp
transformExp (AppExp (LetFun fname (tparams, params, _, Info ret, funbody) body _) _) = do
  let size_params = map typeParamName tparams
  funbody' <- bindingParams size_params params (transformExp funbody)
  lifted_name <- newVName ("lifted_" ++ baseString fname)
  call <- liftFunction lifted_name tparams params ret funbody'
  -- The local definition disappears; its uses become calls of the lifted one.
  replacing fname call (transformExp body)
transformExp (Lambda params body _ (Info ret) _) = do
  body' <- bindingParams [] params (transformExp body)
  lifted_name <- newVName "lifted_lambda"
  liftFunction lifted_name [] params ret body'
transformExp (AppExp (LetPat sizes pat e body loc) appres) = do
  e' <- transformExp e
  body' <- bindingLetPat (map sizeName sizes) pat (transformExp body)
  pure (AppExp (LetPat sizes pat e' body' loc) appres)
transformExp (AppExp (Match scrutinee cases loc) appres) = do
  scrutinee' <- transformExp scrutinee
  cases' <- traverse onCase cases
  pure (AppExp (Match scrutinee' cases' loc) appres)
  where
    onCase (CasePat p case_body case_loc) = do
      case_body' <- bindingLetPat [] p (transformExp case_body)
      pure (CasePat p case_body' case_loc)
transformExp (AppExp (Loop sizes pat args form body loc) appres) = do
  args' <- transformExp args
  bindingParams sizes [pat] $ do
    form' <- astMap transformSubExps form
    body' <- bindingForm form' (transformExp body)
    pure (AppExp (Loop sizes pat args' form' body' loc) appres)
transformExp e@(Var v _ _) = do
  -- Note that function-typed variables can only occur in expressions,
  -- not in other places where VNames/QualNames can occur.
  replacements <- asks envReplace
  pure (M.findWithDefault e (qualLeaf v) replacements)
transformExp e = astMap transformSubExps e

-- | Lambda-lift the body of a top-level value binding, with its type
-- parameter names (as @i64@ sizes) and parameters in scope.  Any functions
-- lifted out of the body are emitted before the rewritten binding itself.
transformValBind :: ValBind -> LiftM ()
transformValBind vb = do
  let size_params = map typeParamName (valBindTypeParams vb)
  body' <- bindingParams size_params (valBindParams vb) (transformExp (valBindBody vb))
  addValBind vb {valBindBody = body'}

{-# NOINLINE transformProg #-}

-- | Perform the transformation.  The value bindings are processed in
-- order.  In the result, each input binding is preceded by the functions
-- lifted out of it.
transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind]
transformProg vbinds =
  modifyNameSource $ flip runLiftM (mapM_ transformValBind vbinds)