-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise (transformProg) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition, sortOn)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Futhark.Util (mapAccumLM, nubOrd)
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types (Subst (..), applySubst)

-- | Information that defunctionalization tracks about the value of an
-- expression, in addition to the residual expression itself.
data StaticVal
  = -- | A dynamic value of the given type.
    Dynamic ParamType
  | -- | A lambda: parameter pattern, return type, body, and the 'Env'
    -- that forms its lexical closure.
    LambdaSV (Pat ParamType) ResRetType Exp Env
  | -- | A record with one static value per field.
    RecordSV [(Name, StaticVal)]
  | -- | A sum value: the constructor that is actually present (with
    -- its static values), followed by the names and payload types of
    -- the constructors that are not present.
    SumSV Name [StaticVal] [(Name, [ParamType])]
  | -- | The pair holds the residual expression and static value of
    -- the function as a whole; the second 'StaticVal' describes its
    -- body.  (The original author cautioned that this understanding
    -- may have holes.)
    DynamicFun (Exp, StaticVal) StaticVal
  | IntrinsicSV
  | HoleSV StructType SrcLoc
  deriving (Show)

-- | What the environment records about a bound variable.
data Binding = Binding
  { -- | 'Just' for a polymorphic binding that must be instantiated
    -- when used; 'Nothing' otherwise.
    bindingType :: Maybe ([VName], StructType),
    -- | The static value bound to the variable.
    bindingSV :: StaticVal
  }
  deriving (Show)

-- | The defunctionalization environment: each variable name is mapped
-- to the 'Binding' that carries its static value.
type Env = M.Map VName Binding

-- | Run a computation with additional bindings in scope.  Because the
-- map union is left-biased, the new bindings shadow any existing
-- bindings of the same name.  The set of globals is left unchanged.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ \(globals, outer) -> (globals, env <> outer)

-- | Run a computation in a fresh environment, such as when evaluating
-- a closure.  The current local bindings are discarded, except that
-- bindings of global names are kept, so the globally defined dynamic
-- functions remain visible.  The given bindings are added on top of
-- them.  Because the union is left-biased, a kept global binding wins
-- over a new binding with the same name.
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local $ \(globals, old_env) ->
  (globals, M.restrictKeys old_env globals <> env)

-- | The current environment, without the set of global names.
askEnv :: DefM Env
askEnv = snd <$> ask

-- | Run a computation with the given names additionally marked as
-- global.  The environment itself is left unchanged.
areGlobal :: [VName] -> DefM a -> DefM a
areGlobal vs = local $ \(globals, env) -> (S.fromList vs <> globals, env)

-- | Apply a size substitution to the sizes of a type.  A size that is a
-- variable with an entry in the substitution is replaced either by the
-- substituted name or by an integer constant.  Every other size is
-- left untouched.
replaceTypeSizes ::
  M.Map VName SizeSubst ->
  TypeBase Size als ->
  TypeBase Size als
replaceTypeSizes substs = first replaceDim
  where
    replaceDim :: Size -> Size
    replaceDim (Var name typ loc)
      | Just subst <- M.lookup (qualLeaf name) substs =
          case subst of
            SubstNamed name' -> Var name' typ loc
            SubstConst n -> sizeFromInteger (toInteger n) loc
    replaceDim dim = dim

-- | Apply a size substitution throughout a static value: to the types
-- it carries, to the expressions it holds, and to the bindings of any
-- closure environment.
--
-- Inside a 'LambdaSV', names bound in the closure environment shadow
-- the substitution.  They are removed from it before the parameter
-- pattern, return type and body are rewritten.  The closure
-- environment itself is rewritten with the full, unshadowed
-- substitution.
--
-- If the substitution is empty, the static value is returned as is.
--
-- NOTE(review): @globals@ is only threaded through recursive calls and
-- is never consulted directly; confirm whether that is intended.
replaceStaticValSizes ::
  S.Set VName ->
  M.Map VName SizeSubst ->
  StaticVal ->
  StaticVal
replaceStaticValSizes globals orig_substs sv =
  case sv of
    _ | M.null orig_substs -> sv
    LambdaSV param (RetType t_dims t) e closure_env ->
      -- Names bound by the closure shadow the substitution.
      let substs = M.withoutKeys orig_substs (M.keysSet closure_env)
       in LambdaSV
            (fmap (replaceTypeSizes substs) param)
            (RetType t_dims (replaceTypeSizes substs t))
            (onExp substs e)
            (onEnv orig_substs closure_env) -- intentional
    Dynamic t ->
      Dynamic $ replaceTypeSizes orig_substs t
    RecordSV fs ->
      RecordSV $ map (fmap recurse) fs
    SumSV c svs ts ->
      SumSV c (map recurse svs) $
        map (fmap $ map $ replaceTypeSizes orig_substs) ts
    DynamicFun (e, sv1) sv2 ->
      DynamicFun (onExp orig_substs e, recurse sv1) $ recurse sv2
    IntrinsicSV ->
      IntrinsicSV
    -- NOTE(review): the hole's type is not substituted; confirm that
    -- this is intended.
    HoleSV t loc ->
      HoleSV t loc
  where
    -- Recurse into nested static values with the unshadowed substitution.
    recurse :: StaticVal -> StaticVal
    recurse = replaceStaticValSizes globals orig_substs

    -- Generic traversal used by 'onExp' for expressions it does not
    -- handle specially.
    tv :: M.Map VName SizeSubst -> ASTMapper Identity
    tv substs =
      ASTMapper
        { mapOnStructType = pure . replaceTypeSizes substs,
          mapOnParamType = pure . replaceTypeSizes substs,
          mapOnResRetType = pure,
          mapOnExp = pure . onExp substs,
          mapOnName = pure . onName substs
        }

    -- Only named substitutions affect plain names; a constant cannot
    -- stand in for a name.
    onName :: M.Map VName SizeSubst -> VName -> VName
    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

    onExp :: M.Map VName SizeSubst -> Exp -> Exp
    onExp substs (Var v t loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') ->
          Var v' t loc
        Just (SubstConst d) ->
          Literal (SignedValue (Int64Value (fromIntegral d))) loc
        Nothing ->
          Var v (replaceTypeSizes substs <$> t) loc
    onExp substs (Coerce e te t loc) =
      Coerce (onExp substs e) te (replaceTypeSizes substs <$> t) loc
    onExp substs (Lambda params e ret (Info (RetType t_dims t)) loc) =
      Lambda
        (map (fmap $ replaceTypeSizes substs) params)
        (onExp substs e)
        ret
        (Info (RetType t_dims (replaceTypeSizes substs t)))
        loc
    onExp substs e = runIdentity $ astMap (tv substs) e

    -- Rewrite every binding of an environment in place.  The keys are
    -- unchanged, so a value-strict map is enough and no rebuild is
    -- needed.
    onEnv :: M.Map VName SizeSubst -> Env -> Env
    onEnv substs = M.map (onBinding substs)

    onBinding :: M.Map VName SizeSubst -> Binding -> Binding
    onBinding substs (Binding t bsv) =
      Binding
        (second (replaceTypeSizes substs) <$> t)
        (replaceStaticValSizes globals substs bsv)

-- | The current defunctionalization environment, restricted to the
-- variables in the given free-variable set.  Bindings of global names
-- are dropped even if they occur free, which keeps closure
-- environments small.
restrictEnvTo :: FV -> DefM Env
restrictEnvTo fv = asks $ \(globals, env) ->
  M.map copyBinding $ M.filterWithKey (wanted globals) env
  where
    vars = fvVars fv

    -- Keep a binding only if its name is not global and occurs free.
    wanted globals k _ = k `S.notMember` globals && k `S.member` vars

    copyBinding (Binding t sv) = Binding t (copySV sv)

    -- NOTE(review): every case rebuilds its input unchanged.  Closure
    -- environments inside 'LambdaSV' are traversed but not filtered, so
    -- this is a structural copy.  Confirm whether restricting nested
    -- environments was intended.
    copySV sv =
      case sv of
        Dynamic t -> Dynamic t
        LambdaSV pat t e env -> LambdaSV pat t e (M.map copyBinding env)
        RecordSV fields -> RecordSV [(name, copySV field) | (name, field) <- fields]
        SumSV c svs fields -> SumSV c (map copySV svs) fields
        DynamicFun (e, sv1) sv2 -> DynamicFun (e, copySV sv1) (copySV sv2)
        IntrinsicSV -> IntrinsicSV
        HoleSV t loc -> HoleSV t loc

-- | The defunctionalization monad.
--
-- The reader component pairs the set of names of globally defined
-- dynamic functions with the current 'Env'.  Tracking the globals
-- separately keeps closure environments from growing unnecessarily
-- large.  The state component holds the lifted value bindings (most
-- recent first) together with the name source.
newtype DefM a
  = DefM (ReaderT (S.Set VName, Env) (State ([ValBind], VNameSource)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (S.Set VName, Env),
      MonadState ([ValBind], VNameSource)
    )

-- | The name source is the second half of the state; the lifted
-- bindings in the first half are left untouched.
instance MonadFreshNames DefM where
  getNameSource = snd <$> get
  putNameSource src = modify $ \(lifted, _) -> (lifted, src)

-- | Run a computation in the defunctionalization monad, starting with no
-- globals, an empty environment and no lifted bindings.  Returns the
-- result of the computation, the updated name source, and the list of
-- lifted function declarations in the order they were added.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, [ValBind])
runDefM src (DefM m) = (result, src', reverse lifted)
  where
    -- Lifted bindings are accumulated most-recent-first, hence 'reverse'.
    (result, (lifted, src')) = runState (runReaderT m mempty) (mempty, src)

-- | Record a lifted value binding.  Bindings are pushed onto the front
-- of the accumulated list; 'runDefM' reverses it at the end.
addValBind :: ValBind -> DefM ()
addValBind vb = modify $ \(lifted, src) -> (vb : lifted, src)

-- | Record a new top-level, non-entry-point value binding named
-- @fname@, with size parameters @dims@, parameter patterns @pats@, and
-- body @body@.
--
-- Every name that occurs free in the return type but is bound neither
-- by @dims@ nor by the parameter patterns is prepended to the return
-- type's dimension list.
liftValDec :: VName -> ResRetType -> [VName] -> [Pat ParamType] -> Exp -> DefM ()
liftValDec fname (RetType ret_dims ret) dims pats body =
  addValBind
    ValBind
      { valBindEntryPoint = Nothing,
        valBindName = fname,
        valBindRetDecl = Nothing,
        valBindRetType = Info $ RetType (unbound_here ++ ret_dims) ret,
        valBindTypeParams = map (`TypeParamDim` mempty) dims,
        valBindParams = pats,
        valBindBody = body,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty
      }
  where
    -- FIXME: this pass is still not correctly size-preserving, so
    -- forget those return sizes that we forgot to propagate along
    -- the way.  Hopefully the internaliser is conservative and
    -- will insert reshapes...
    bound_here = S.fromList dims <> foldMap patNames pats
    -- NOTE(review): if a name in ret_dims also occurs free in ret and is
    -- not bound here, it appears twice in the final list.  Confirm that
    -- freeInType excludes it or that duplicates are harmless.
    unbound_here = S.toList $ fvVars (freeInType ret) `S.difference` bound_here

-- | Look up the static value associated with a name in the environment.
--
-- A binding that carries size parameters and a type is instantiated at
-- the requested type @t@ through 'instStaticVal'. Any other binding is
-- returned as is. Names absent from the environment are treated either
-- as intrinsics (when their tag is at most 'maxIntrinsicTag') or as
-- existential sizes of type @i64@.
lookupVar :: StructType -> VName -> DefM StaticVal
lookupVar t x = maybe unbound bound =<< (M.lookup x <$> askEnv)
  where
    bound :: Binding -> DefM StaticVal
    bound (Binding (Just (dims, sv_t)) sv) = do
      globals <- asks fst
      instStaticVal globals dims t sv_t sv
    bound (Binding Nothing sv) = pure sv

    -- If the variable is unknown, it may refer to the 'intrinsics'
    -- module, which we will have to treat specially.
    unbound :: DefM StaticVal
    unbound
      | baseTag x <= maxIntrinsicTag = pure IntrinsicSV
      -- Anything not in scope is going to be an existential size.
      | otherwise = pure . Dynamic . Scalar . Prim $ Signed Int64

-- | Like 'freeInPat', but ignores sizes that are only found in
-- function types.  Collects the variable names used as array
-- dimensions, and as size arguments to type variables, anywhere
-- outside of an arrow type.
arraySizes :: StructType -> S.Set VName
arraySizes (Scalar st) =
  case st of
    Arrow {} -> mempty
    Prim {} -> mempty
    Record fields -> foldMap arraySizes fields
    Sum cs -> foldMap (foldMap arraySizes) cs
    TypeVar _ _ targs -> foldMap targSizes targs
  where
    targSizes :: TypeArg Size -> S.Set VName
    targSizes (TypeArgDim (Var d _ _)) = S.singleton (qualLeaf d)
    targSizes TypeArgDim {} = mempty
    targSizes (TypeArgType ty) = arraySizes ty
arraySizes (Array _ shape elem_t) =
  arraySizes (Scalar elem_t) <> foldMap sizeVar (shapeDims shape)
  where
    -- Only dimensions that are plain variables contribute a name.
    sizeVar :: Size -> S.Set VName
    sizeVar (Var qn _ _) = S.singleton (qualLeaf qn)
    sizeVar _ = mempty

-- | The 'arraySizes' of the structural type of a pattern.
patternArraySizes :: Pat ParamType -> S.Set VName
patternArraySizes pat = arraySizes (patternStructType pat)

-- | The replacement recorded for a size variable: either another
-- (named) size or a constant.
data SizeSubst
  = -- | Replace with the given named size.
    SubstNamed (QualName VName)
  | -- | Replace with the given constant size.
    SubstConst Int64
  deriving (Eq, Ord, Show)

-- | Walk the sizes of two types in lockstep (via 'matchDims') and record,
-- for every size variable of the first type, what it corresponds to in
-- the second type: another variable ('SubstNamed') or an integer
-- literal ('SubstConst').  A size of the second type that is one of the
-- names bound at that position (the list 'matchDims' passes) is
-- ignored, as are all other combinations of sizes.
dimMapping ::
  Monoid a =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName SizeSubst
dimMapping t1 t2 = flip execState mempty $ matchDims record t1 t2
  where
    record :: [VName] -> Size -> Size -> State (M.Map VName SizeSubst) Size
    record bound d1 d2 =
      case (d1, d2) of
        (_, Var d2' _ _)
          | qualLeaf d2' `elem` bound -> pure d1
        (Var d1' _ _, Var d2' _ _) ->
          d1 <$ modify (M.insert (qualLeaf d1') (SubstNamed d2'))
        (Var d1' _ _, IntLit k _ _) ->
          d1 <$ modify (M.insert (qualLeaf d1') (SubstConst (fromInteger k)))
        _ -> pure d1

-- | Like 'dimMapping', but keep only the entries that map a size
-- variable to another named size, dropping constant substitutions.
dimMapping' ::
  Monoid a =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName VName
dimMapping' t1 t2 =
  -- Filtering an ascending key list keeps it ascending and distinct.
  M.fromDistinctAscList
    [(k, qualLeaf d) | (k, SubstNamed d) <- M.toAscList (dimMapping t1 t2)]

-- | The size names inside a static value that 'instStaticVal' renames
-- when instantiating it: the free variables of every lambda parameter
-- pattern, gathered through records, sums and dynamic functions.
sizesToRename :: StaticVal -> S.Set VName
sizesToRename sv =
  case sv of
    DynamicFun (_, sv1) sv2 -> sizesToRename sv1 <> sizesToRename sv2
    RecordSV fs -> foldMap sizesToRename (map snd fs)
    SumSV _ svs _ -> foldMap sizesToRename svs
    LambdaSV param _ _ _ ->
      -- We used to rename parameters here, but I don't understand why
      -- that was necessary and it caused some problems.
      fvVars (freeInPat param)
    IntrinsicSV -> S.empty
    HoleSV {} -> S.empty
    Dynamic {} -> S.empty

-- | Combine the shape information of types as much as possible. The first
-- argument is the original type and the second is the type of the
-- transformed expression. This is necessary since the original type may
-- contain additional information (e.g., shape restrictions) from the user
-- given annotation.
--
-- Array shapes, function parameter names, diets and existential sizes
-- are taken from the first type; aliases of function types are merged.
-- Records and sums are combined component-wise only when both have the
-- same field/constructor names.  In every case not handled structurally,
-- the second type is returned unchanged.
combineTypeShapes ::
  (Monoid as) =>
  TypeBase Size as ->
  TypeBase Size as ->
  TypeBase Size as
combineTypeShapes (Scalar (Record ts1)) (Scalar (Record ts2))
  | M.keys ts1 == M.keys ts2 =
      Scalar . Record $ M.intersectionWith combineTypeShapes ts1 ts2
combineTypeShapes (Scalar (Sum cs1)) (Scalar (Sum cs2))
  | M.keys cs1 == M.keys cs2 =
      Scalar . Sum $ M.intersectionWith (zipWith combineTypeShapes) cs1 cs2
combineTypeShapes
  (Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1)))
  (Scalar (Arrow als2 _ _ a2 (RetType _ b2))) =
    Scalar $ Arrow (als1 <> als2) p1 d1 param_t (RetType dims1 result_t)
    where
      param_t = combineTypeShapes a1 a2
      result_t = combineTypeShapes b1 b2
combineTypeShapes (Scalar (TypeVar u v targs1)) (Scalar (TypeVar _ _ targs2)) =
  Scalar . TypeVar u v $ zipWith combineArg targs1 targs2
  where
    -- Only type arguments are combined; size arguments come from the
    -- first type.
    combineArg (TypeArgType ty1) (TypeArgType ty2) =
      TypeArgType (combineTypeShapes ty1 ty2)
    combineArg targ _ = targ
combineTypeShapes (Array u shape1 et1) (Array _ _ et2) =
  arrayOfWithAliases u shape1 $ combineTypeShapes (withU et1) (withU et2)
  where
    withU et = setUniqueness (Scalar et) u
combineTypeShapes _ t = t

-- | When we instantiate a polymorphic StaticVal, we rename all the
-- sizes to avoid name conflicts later on.  This is a bit of a hack...
--
-- Every name in @dims@ or in 'sizesToRename' of @sv@ that is not in
-- @globals@ gets a fresh name.  The renamed polymorphic type @sv_t@ is
-- then matched against the instance type @t@ with 'dimMapping'.  The
-- matches found for the renamed @dims@ take precedence over the fresh
-- names.  The combined substitution is applied with
-- 'replaceStaticValSizes'.
instStaticVal ::
  MonadFreshNames m =>
  S.Set VName ->
  [VName] ->
  StructType ->
  StructType ->
  StaticVal ->
  m StaticVal
instStaticVal globals dims t sv_t sv = do
  let to_rename =
        filter (`S.notMember` globals) . S.toList $
          S.fromList dims <> sizesToRename sv
  fresh_names <- mapM newName to_rename

  let fresh_substs =
        M.fromList . zip to_rename $ map (SubstNamed . qualName) fresh_names

      renamed v =
        case M.lookup v fresh_substs of
          Just (SubstNamed v') -> qualLeaf v'
          _ -> v

      renamed_dims = S.fromList (map renamed dims)

      -- What the (renamed) size parameters correspond to in @t@.
      dim_substs =
        M.filterWithKey (\k _ -> k `S.member` renamed_dims) $
          dimMapping (replaceTypeSizes fresh_substs sv_t) t

      -- A fresh name that was itself matched against @t@ is redirected
      -- to what it matched.
      redirect s@(SubstNamed k) = M.findWithDefault s (qualLeaf k) dim_substs
      redirect s = s

      substs = M.map redirect fresh_substs <> dim_substs

  pure $ replaceStaticValSizes globals substs sv

-- | Defunctionalise a lambda given its type parameters, parameter patterns,
-- body, return type and source location.
--
-- Only the first parameter pattern is peeled off.  Any remaining patterns
-- are wrapped back up, together with the original body, in a nested
-- 'Lambda' whose return type is computed with 'funType'.
--
-- The residual expression is a record literal with one field per
-- closed-over variable that is not an array size.  The static value is a
-- 'LambdaSV' holding the peeled-off pattern, the remaining return type, the
-- remaining body and the closure environment.  Closed-over 'DynamicFun's
-- contribute their closure expression rather than a plain variable.
--
-- Calls 'error' if the list of parameter patterns is empty.
defuncFun ::
  [VName] ->
  [Pat ParamType] ->
  Exp ->
  ResRetType ->
  SrcLoc ->
  DefM (Exp, StaticVal)
defuncFun tparams pats e0 ret loc = do
  -- Restrict the environment to the free variables of the whole lambda,
  -- not counting its type parameters.
  closure_env <-
    restrictEnvTo $
      freeInExp (Lambda pats e0 Nothing (Info ret) loc)
        `freeWithout` S.fromList tparams

  let (first_pat, rest_ret, rest_body) = splitFirstParam
      -- Array sizes mentioned by the closure or by the first parameter are
      -- turned into size parameters proactively, so they are not captured
      -- in the closure record.
      size_names =
        foldMap (arraySizes . structTypeFromSV . bindingSV) closure_env
          <> patternArraySizes first_pat
      captured =
        [ binding
          | binding@(vn, _) <- M.toList closure_env,
            not (vn `S.member` size_names)
        ]
      (fields, bindings) = unzip (map closureField captured)

  pure
    ( RecordLit fields loc,
      LambdaSV first_pat rest_ret rest_body (M.fromList bindings)
    )
  where
    -- Split off the first parameter; the rest become a nested lambda.
    splitFirstParam = case pats of
      [] -> error "Received a lambda with no parameters."
      [p] -> (p, ret, e0)
      p : ps ->
        ( p,
          RetType [] $ second (const Nonunique) $ funType ps ret,
          Lambda ps e0 Nothing (Info ret) loc
        )

    -- The record field name used for a captured variable.
    fieldName vn = nameFromString (prettyString vn)

    -- Build one closure-record field and the binding that goes with it.
    closureField :: (VName, Binding) -> (FieldBase Info VName, (VName, Binding))
    closureField (vn, Binding _ (DynamicFun (clsr_env, sv) _)) =
      ( RecordFieldExplicit (fieldName vn) clsr_env mempty,
        (vn, Binding Nothing sv)
      )
    closureField (vn, Binding _ sv) =
      ( RecordFieldExplicit
          (fieldName vn)
          (Var (qualName vn) (Info (structTypeFromSV sv)) mempty)
          mempty,
        (vn, Binding Nothing sv)
      )

-- | Defunctionalization of an expression. Returns the residual expression and
-- the associated static value in the defunctionalization monad.
defuncExp :: Exp -> DefM (Exp, StaticVal)
-- Every kind of literal is returned unchanged, paired with a 'Dynamic'
-- static value of its own (observed) type.  Non-literals fall through to
-- the clauses below.
defuncExp e
  | isLiteral e = pure (e, Dynamic (toParam Observe (typeOf e)))
  where
    isLiteral Literal {} = True
    isLiteral IntLit {} = True
    isLiteral FloatLit {} = True
    isLiteral StringLit {} = True
    isLiteral _ = False
-- Parenthesised expressions: defunctionalise the inner expression, re-wrap
-- it, and keep its static value.
defuncExp (Parens e loc) =
  first (`Parens` loc) <$> defuncExp e
defuncExp (QualParens qn e loc) =
  first (\e' -> QualParens qn e' loc) <$> defuncExp e
-- A tuple becomes a record static value keyed by the tuple field names.
defuncExp (TupLit es loc) = do
  (es', svs) <- unzip <$> mapM defuncExp es
  pure (TupLit es' loc, RecordSV (zip tupleFieldNames svs))
-- A record literal becomes a record static value with one entry per field.
defuncExp (RecordLit fs loc) = do
  (fs', names_svs) <- unzip <$> mapM defuncField fs
  pure (RecordLit fs' loc, RecordSV names_svs)
  where
    defuncField :: FieldBase Info VName -> DefM (FieldBase Info VName, (Name, StaticVal))
    defuncField (RecordFieldExplicit vn e loc') = do
      (e', sv) <- defuncExp e
      pure (RecordFieldExplicit vn e' loc', (vn, sv))
    defuncField (RecordFieldImplicit vn (Info t) loc') = do
      sv <- lookupVar (toStruct t) vn
      pure $ case sv of
        -- If the implicit field refers to a dynamic function, we
        -- convert it to an explicit field with a record closing over
        -- the environment and bind the corresponding static value.
        DynamicFun (clsr_env, sv') _ ->
          (RecordFieldExplicit (baseName vn) clsr_env loc', (baseName vn, sv'))
        -- The field may refer to a functional expression, so we get the
        -- type from the static value and not the one from the AST.
        _ ->
          (RecordFieldImplicit vn (Info (structTypeFromSV sv)) loc', (baseName vn, sv))
-- Array literals: the elements are defunctionalised with 'defuncExp'', and
-- the whole array is a dynamic value of the annotated type.
defuncExp (ArrayLit es info loc) = do
  es' <- traverse defuncExp' es
  pure (ArrayLit es' info loc, Dynamic (toParam Observe (unInfo info)))
-- Ranges: every operand is defunctionalised with 'defuncExp'', and the
-- result is a dynamic value of the application's result type.
defuncExp (AppExp (Range e1 me incl loc) res) = do
  e1' <- defuncExp' e1
  me' <- traverse defuncExp' me
  incl' <- traverse defuncExp' incl
  let range_t = appResType (unInfo res)
  pure (AppExp (Range e1' me' incl' loc) res, Dynamic (toParam Observe range_t))
-- Variables are resolved through the environment; functional variables are
-- eta-expanded rather than referenced directly.
defuncExp e@(Var qn (Info t) loc) = do
  sv <- lookupVar (toStruct t) (qualLeaf qn)
  -- Only run by the two branches that need an eta-expansion.
  let etaExpanded = etaExpand (RetType [] (toRes Nonunique t)) e
  case sv of
    -- If the variable refers to a dynamic function, we eta-expand it
    -- so that we do not have to duplicate its definition.
    DynamicFun {} -> do
      (params, body, ret) <- etaExpanded
      defuncFun [] params body ret mempty
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    IntrinsicSV -> do
      (params, body, ret) <- etaExpanded
      defuncExp (Lambda params body Nothing (Info ret) mempty)
    HoleSV _ hole_loc ->
      pure (Hole (Info t) hole_loc, sv)
    _ ->
      pure (Var qn (Info (structTypeFromSV sv)) loc, sv)
-- A hole stays a hole and records its type and location in the static value.
defuncExp (Hole info@(Info t) loc) =
  pure (Hole info loc, HoleSV t loc)
-- A type ascription is kept only when the ascribed expression has an
-- order-zero type.  Otherwise it is dropped and just the inner expression
-- is defunctionalised.
defuncExp (Ascript e0 tydecl loc)
  | not (orderZero (typeOf e0)) = defuncExp e0
  | otherwise = first (\e0' -> Ascript e0' tydecl loc) <$> defuncExp e0
-- Coercions follow the same rule as ascriptions.
defuncExp (Coerce e0 tydecl t loc)
  | not (orderZero (typeOf e0)) = defuncExp e0
  | otherwise = first (\e0' -> Coerce e0' tydecl t loc) <$> defuncExp e0
-- Pattern bindings: the bound expression's static value is matched against
-- the pattern to extend the environment for the body.  The pattern is
-- updated from that static value.
defuncExp (AppExp (LetPat sizes pat e1 e2 loc) (Info (AppRes t retext))) = do
  (e1', sv1) <- defuncExp e1
  let param_pat = fmap (toParam Observe) pat
  (e2', sv2) <- localEnv (alwaysMatchPatSV param_pat sv1) (defuncExp e2)
  -- To maintain any sizes going out of scope, we need to compute the
  -- old size substitution induced by retext and also apply it to the
  -- newly computed body type.
  let mapping = dimMapping' (typeOf e2) t
      subst v = case M.lookup v mapping of
        Nothing -> Nothing
        Just v' -> Just (ExpSubst (sizeFromName (qualName v') mempty))
      t' = applySubst subst (typeOf e2')
      pat' = fmap toStruct (updatePat param_pat sv1)
  pure (AppExp (LetPat sizes pat' e1' e2' loc) (Info (AppRes t' retext)), sv2)
-- A 'LetFun' is not expected to reach this function; hitting one is an
-- internal error.
defuncExp (AppExp (LetFun vn _ _ _) _) =
  error ("defuncExp: Unexpected LetFun: " <> show vn)
-- Conditionals: all three operands are defunctionalised in order.  Only the
-- static value of the true branch is returned.
defuncExp (AppExp (If e1 e2 e3 loc) res) = do
  (cond', _) <- defuncExp e1
  (e_true, sv) <- defuncExp e2
  (e_false, _) <- defuncExp e3
  pure (AppExp (If cond' e_true e_false loc) res, sv)
-- Applications are handed to 'defuncApply' after unwrapping the 'Info'
-- annotation of each argument.
defuncExp (AppExp (Apply f args loc) (Info appres)) =
  defuncApply f (NE.map (first unInfo) args) appres loc
-- Negation and logical complement are rebuilt around the defunctionalised
-- operand and keep the operand's static value.
defuncExp (Negate e0 loc) =
  first (`Negate` loc) <$> defuncExp e0
defuncExp (Not e0 loc) =
  first (`Not` loc) <$> defuncExp e0
-- A lambda is defunctionalised via 'defuncFun', with no type parameters.
defuncExp (Lambda pats e0 _ (Info ret) loc) =
  defuncFun [] pats e0 ret loc
-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
-- Projection and index sections are likewise unexpected here.  Each kind
-- reports its own name so an internal error points at the right construct.
defuncExp OpSection {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight {} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection {} = error "defuncExp: unexpected projection section."
defuncExp IndexSection {} = error "defuncExp: unexpected index section."
defuncExp (AppExp (DoLoop [VName]
sparams Pat ParamType
pat Exp
e1 LoopFormBase Info VName
form Exp
e3 SrcLoc
loc) Info AppRes
res) = do
  (Exp
e1', StaticVal
sv1) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e1
  let env1 :: Env
env1 = Pat ParamType -> StaticVal -> Env
alwaysMatchPatSV Pat ParamType
pat StaticVal
sv1
  (LoopFormBase Info VName
form', Env
env2) <- case LoopFormBase Info VName
form of
    For IdentBase Info VName StructType
v Exp
e2 -> do
      Exp
e2' <- Exp -> DefM Exp
defuncExp' Exp
e2
      (LoopFormBase Info VName, Env)
-> DefM (LoopFormBase Info VName, Env)
forall a. a -> DefM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IdentBase Info VName StructType -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
IdentBase f vn StructType -> ExpBase f vn -> LoopFormBase f vn
For IdentBase Info VName StructType
v Exp
e2', IdentBase Info VName StructType -> Env
forall {k} {u}. IdentBase Info k (TypeBase Exp u) -> Map k Binding
envFromIdent IdentBase Info VName StructType
v)
    ForIn PatBase Info VName StructType
pat2 Exp
e2 -> do
      Exp
e2' <- Exp -> DefM Exp
defuncExp' Exp
e2
      (LoopFormBase Info VName, Env)
-> DefM (LoopFormBase Info VName, Env)
forall a. a -> DefM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PatBase Info VName StructType -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
PatBase f vn StructType -> ExpBase f vn -> LoopFormBase f vn
ForIn PatBase Info VName StructType
pat2 Exp
e2', Pat ParamType -> Env
envFromPat (Pat ParamType -> Env) -> Pat ParamType -> Env
forall a b. (a -> b) -> a -> b
$ (StructType -> ParamType)
-> PatBase Info VName StructType -> Pat ParamType
forall a b.
(a -> b) -> PatBase Info VName a -> PatBase Info VName b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Diet -> StructType -> ParamType
forall u. Diet -> TypeBase Exp u -> ParamType
toParam Diet
Observe) PatBase Info VName StructType
pat2)
    While Exp
e2 -> do
      Exp
e2' <- Env -> DefM Exp -> DefM Exp
forall a. Env -> DefM a -> DefM a
localEnv Env
env1 (DefM Exp -> DefM Exp) -> DefM Exp -> DefM Exp
forall a b. (a -> b) -> a -> b
$ Exp -> DefM Exp
defuncExp' Exp
e2
      (LoopFormBase Info VName, Env)
-> DefM (LoopFormBase Info VName, Env)
forall a. a -> DefM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn. ExpBase f vn -> LoopFormBase f vn
While Exp
e2', Env
forall a. Monoid a => a
mempty)
  (Exp
e3', StaticVal
sv) <- Env -> DefM (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall a. Env -> DefM a -> DefM a
localEnv (Env
env1 Env -> Env -> Env
forall a. Semigroup a => a -> a -> a
<> Env
env2) (DefM (Exp, StaticVal) -> DefM (Exp, StaticVal))
-> DefM (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall a b. (a -> b) -> a -> b
$ Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e3
  (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall a. a -> DefM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (AppExpBase Info VName -> Info AppRes -> Exp
forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp ([VName]
-> Pat ParamType
-> Exp
-> LoopFormBase Info VName
-> Exp
-> SrcLoc
-> AppExpBase Info VName
forall (f :: * -> *) vn.
[VName]
-> PatBase f vn ParamType
-> ExpBase f vn
-> LoopFormBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
DoLoop [VName]
sparams Pat ParamType
pat Exp
e1' LoopFormBase Info VName
form' Exp
e3' SrcLoc
loc) Info AppRes
res, StaticVal
sv)
  where
    envFromIdent :: IdentBase Info k (TypeBase Exp u) -> Map k Binding
envFromIdent (Ident k
vn (Info TypeBase Exp u
tp) SrcLoc
_) =
      k -> Binding -> Map k Binding
forall k a. k -> a -> Map k a
M.singleton k
vn (Binding -> Map k Binding) -> Binding -> Map k Binding
forall a b. (a -> b) -> a -> b
$ Maybe ([VName], StructType) -> StaticVal -> Binding
Binding Maybe ([VName], StructType)
forall a. Maybe a
Nothing (StaticVal -> Binding) -> StaticVal -> Binding
forall a b. (a -> b) -> a -> b
$ ParamType -> StaticVal
Dynamic (ParamType -> StaticVal) -> ParamType -> StaticVal
forall a b. (a -> b) -> a -> b
$ Diet -> TypeBase Exp u -> ParamType
forall u. Diet -> TypeBase Exp u -> ParamType
toParam Diet
Observe TypeBase Exp u
tp
-- Binary operators are not handled by this pass; meeting one here is
-- treated as an internal compiler error.
defuncExp e@(AppExp BinOp {} _) =
  error ("defuncExp: unexpected binary operator: " <> prettyString e)
-- Record projection.  How the field's static value is found depends on
-- what is known about the record:
--
-- * a 'RecordSV': the field's static value is looked up directly, and the
--   residual projection is re-annotated with the type derived from it;
-- * a 'Dynamic' record: the field is dynamic too, at its annotated type;
-- * a 'HoleSV': the projection is again a hole, at the field's type.
defuncExp (Project vn e0 tp@(Info tp') loc) = do
  (e0', sv0) <- defuncExp e0
  let mkProject t = Project vn e0' t loc
  case sv0 of
    RecordSV svs ->
      maybe
        (error "Invalid record projection.")
        (\field_sv -> pure (mkProject (Info (structTypeFromSV field_sv)), field_sv))
        (lookup vn svs)
    Dynamic _ ->
      pure (mkProject tp, Dynamic (toParam Observe tp'))
    HoleSV _ hloc ->
      pure (mkProject tp, HoleSV tp' hloc)
    _ ->
      error ("Projection of an expression with static value " <> show sv0)
-- In-place update binding.  The value expression and the slice are
-- defunctionalised first.  The body is then processed with @id1@ bound to
-- a 'Dynamic' static value of @id1@'s own annotated type.  @id2@ is
-- carried over unchanged.
defuncExp (AppExp (LetWith id1 id2 idxs e1 body loc) res) = do
  e1' <- defuncExp' e1
  idxs' <- traverse defuncDimIndex idxs
  let id1_t = toParam Observe (unInfo (identType id1))
      id1_env = M.singleton (identName id1) (Binding Nothing (Dynamic id1_t))
  (body', sv) <- localEnv id1_env (defuncExp body)
  let rebuilt = AppExp (LetWith id1 id2 idxs' e1' body' loc) res
  pure (rebuilt, sv)
-- Array indexing.  The result is always 'Dynamic', at the type of the
-- whole (original) indexing expression.
defuncExp expr@(AppExp (Index e0 idxs loc) res) = do
  e0' <- defuncExp' e0
  idxs' <- traverse defuncDimIndex idxs
  let indexed = AppExp (Index e0' idxs' loc) res
      result_sv = Dynamic (toParam Observe (typeOf expr))
  pure (indexed, result_sv)
-- In-place array update expression.  The result keeps the static value
-- of the array being updated (@e1@); the static value of the new element
-- value (@e2@) is discarded.
defuncExp (Update e1 idxs e2 loc) = do
  (e1', arr_sv) <- defuncExp e1
  idxs' <- traverse defuncDimIndex idxs
  e2' <- defuncExp' e2
  let updated = Update e1' idxs' e2' loc
  pure (updated, arr_sv)

-- Record field update.  Note that the type of the updated record field
-- may change here.  The type checker forbids this because of problems
-- with type inference, but it works fine in this pass.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  -- NOTE(review): the residual's type annotation is computed from sv1,
  -- i.e. the static value *before* the update, while the returned static
  -- value is the updated one.  If the field's type changes (see the note
  -- above), these may disagree -- confirm this is intended.
  let updated = RecordUpdate e1' fs e2' (Info (structTypeFromSV sv1)) loc
  pure (updated, setFieldSV sv1 sv2 fs)
  where
    -- Replace the static value found at field path @path@ inside a record
    -- static value with @new@.  A 'Dynamic' record is first expanded with
    -- 'svFromType' so that the path can be followed.  In every other case
    -- (including an empty path) the result is simply @new@.
    setFieldSV :: StaticVal -> StaticVal -> [Name] -> StaticVal
    setFieldSV (RecordSV svs) new (f : rest) =
      case lookup f svs of
        Nothing -> error "Invalid record projection."
        Just old ->
          let untouched = filter ((/= f) . fst) svs
           in RecordSV ((f, setFieldSV old new rest) : untouched)
    setFieldSV (Dynamic t@(Scalar Record {})) new path@(_ : _) =
      setFieldSV (svFromType t) new path
    setFieldSV _ new _ = new
-- Assertion.  The static value of the condition is discarded; the whole
-- expression takes the static value of the guarded expression.
defuncExp (Assert cond e desc loc) = do
  (cond', _) <- defuncExp cond
  (e', sv) <- defuncExp e
  let asserted = Assert cond' e' desc loc
  pure (asserted, sv)
-- Sum-type constructor application.  The resulting 'SumSV' records:
--
-- * the static values of the payload of the constructor actually used;
-- * the payload types of every other constructor, with arrow types
--   replaced by @{}@.
--
-- The annotated sum type is then combined with the type derived from that
-- static value.
defuncExp (Constr name es (Info sum_t@(Scalar (Sum all_fs))) loc) = do
  (es', svs) <- mapAndUnzipM defuncExp es
  let payload_ts = M.map (map (toParam Observe . defuncType)) all_fs
      other_cs = M.toList (M.delete name payload_ts)
      sv = SumSV name svs other_cs
      sum_t' = combineTypeShapes sum_t (structTypeFromSV sv)
  pure (Constr name es' (Info sum_t') loc, sv)
  where
    -- Replace every arrow type, at any depth, by the empty record type.
    -- All other type structure is preserved.
    defuncType ::
      (Monoid als) =>
      TypeBase Size als ->
      TypeBase Size als
    defuncType t = case t of
      Array u shape et -> Array u shape (defuncScalar et)
      Scalar st -> Scalar (defuncScalar st)

    defuncScalar ::
      (Monoid als) =>
      ScalarTypeBase Size als ->
      ScalarTypeBase Size als
    defuncScalar st = case st of
      Record fs -> Record (M.map defuncType fs)
      Sum cs -> Sum (M.map (map defuncType) cs)
      Arrow {} -> Record mempty
      Prim pt -> Prim pt
      TypeVar u tn targs -> TypeVar u tn targs
-- A constructor whose annotated type is not a sum type.  This is an
-- internal error.
defuncExp (Constr name _ (Info t) loc) =
  error . concat $
    [ "Constructor ",
      prettyString name,
      " given type ",
      prettyString t,
      " at ",
      locStr loc
    ]
-- Pattern match.  Each case is defunctionalised against the scrutinee's
-- static value, and cases for which 'defuncCase' yields 'Nothing' are
-- dropped.  The match takes the static value of the first remaining
-- case.  If no case remains, forcing the result raises an error.
defuncExp (AppExp (Match e cs loc) res) = do
  (e', sv) <- defuncExp e
  matched <- mapM (defuncCase sv) (NE.toList cs)
  let noMatch = error $ "No case matches StaticVal\n" <> show sv
      -- Deliberately lazy: the error only fires when the cases are demanded.
      csPairs = fromMaybe noMatch $ NE.nonEmpty $ catMaybes matched
      cs' = NE.map fst csPairs
      sv' = snd (NE.head csPairs)
  pure (AppExp (Match e' cs' loc) res, sv')
-- Attribute annotation.  This is transparent: the static value of the
-- annotated expression is passed through unchanged.
defuncExp (Attr info e loc) = do
  (e', sv) <- defuncExp e
  let attributed = Attr info e' loc
  pure (attributed, sv)

-- | Like 'defuncExp', but keeps only the residual expression and
-- discards its static value.
defuncExp' :: Exp -> DefM Exp
defuncExp' e = fst <$> defuncExp e

-- | Defunctionalise one case of a @match@ against the static value @sv@
-- of the scrutinee.
--
-- Returns 'Nothing' when 'matchPatSV' returns 'Nothing' for the case's
-- pattern and @sv@.  Otherwise it returns the rewritten case, paired with
-- the static value of its body.  In the rewritten case:
--
-- * the pattern has been refined by 'updatePat';
-- * the body has been defunctionalised in the environment produced by
--   'matchPatSV'.
defuncCase :: StaticVal -> Case -> DefM (Maybe (Case, StaticVal))
defuncCase sv (CasePat p e loc) =
  case matchPatSV p_param sv of
    Nothing -> pure Nothing
    Just env -> do
      (e', sv') <- localEnv env $ defuncExp e
      pure $ Just (CasePat (fmap toStruct p') e' loc, sv')
  where
    -- The case pattern viewed as a parameter pattern.  Both 'matchPatSV'
    -- and 'updatePat' take this form.  Binding it once guarantees they
    -- see the same pattern.
    p_param = fmap (toParam Observe) p
    p' = updatePat p_param sv

-- | Defunctionalise an expression used as the function argument of a SOAC.
--
-- * Operator sections and projection sections are returned unchanged.
-- * Parentheses are looked through.
-- * A lambda has its body processed recursively, with its parameters'
--   environment in scope.
-- * Any other expression of function type is eta-expanded into a lambda,
--   whose body is then defunctionalised.
-- * Everything else is handled by 'defuncExp''.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e = case e of
  OpSection {} -> pure e
  OpSectionLeft {} -> pure e
  OpSectionRight {} -> pure e
  ProjectSection {} -> pure e
  Parens inner loc ->
    (\inner' -> Parens inner' loc) <$> defuncSoacExp inner
  Lambda params e0 decl tp loc -> do
    e0' <- localEnv (foldMap envFromPat params) (defuncSoacExp e0)
    pure (Lambda params e0' decl tp loc)
  _
    | Scalar Arrow {} <- typeOf e -> do
        let e_t = RetType [] (toRes Nonunique (typeOf e))
        (pats, body, tp) <- etaExpand e_t e
        body' <- localEnv (foldMap envFromPat pats) (defuncExp' body)
        pure (Lambda pats body' Nothing (Info tp) mempty)
    | otherwise -> defuncExp' e

-- | Eta-expand the expression @e@, whose (possibly curried) function type
-- is @e_t@.
--
-- One parameter pattern is produced for every arrow that can be peeled off
-- @e_t@. The expression returned is @e@ applied to variables bound by those
-- patterns, and the return type is what remains after all arrows have been
-- removed.
--
-- A named arrow parameter keeps its name unless an earlier parameter already
-- used it. Anonymous or clashing parameters get a fresh name based on
-- @eta_p@.
--
-- The existential sizes of the remaining return type are renamed freshly.
-- The renamed type is only recorded on the 'AppRes' of the application; the
-- return type handed back to the caller is the un-renamed one.
etaExpand :: ResRetType -> Exp -> DefM ([Pat ParamType], Exp, ResRetType)
etaExpand e_t e = do
  let (arrows, ret) = peelArrows e_t
  -- Some careful hackery to avoid duplicate names.
  (_, bound) <- mapAccumLM mkParam [] arrows
  let (params, vars) = unzip bound
      diets = map (fst . snd) arrows
      old_ext = retDims ret
  -- Important that we synthesize new existential names and substitute
  -- them into the (body) return type.
  new_ext <- mapM newName old_ext
  let toSize v = ExpSubst $ sizeFromName (qualName v) mempty
      renaming = M.fromList $ zip old_ext $ map toSize new_ext
      ret' = applySubst (`M.lookup` renaming) ret
      app_res = AppRes (toStruct (retType ret')) new_ext
      applied = mkApply e (zip3 diets (repeat Nothing) vars) app_res
  pure (params, applied, ret)
  where
    -- Split a curried function type into its parameters (name, diet, type)
    -- and the final non-arrow return type.
    peelArrows (RetType _ (Scalar (Arrow _ p d t1 t2))) =
      let (rest, r) = peelArrows t2
       in ((p, (d, t1)) : rest, r)
    peelArrows t = ([], t)

    -- Build one parameter pattern and the matching variable reference,
    -- threading the list of names already used so far.
    mkParam seen (p, (d, t)) = do
      let t' = second (const d) t
      x <- case p of
        Named x | x `notElem` seen -> pure x
        _ -> newNameFromString "eta_p"
      let pat = Id x (Info t') mempty
          var = Var (qualName x) (Info $ toStruct t') mempty
      pure (x : seen, (pat, var))

-- | Defunctionalize an indexing of a single array dimension.
--
-- For a fixed index, only the residual expression is kept and its static
-- value is discarded. For a slice, each present component (start, end,
-- stride) is defunctionalized in that order with 'defuncExp''.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex dimidx = case dimidx of
  DimFix i -> do
    (i', _) <- defuncExp i
    pure $ DimFix i'
  DimSlice start end stride -> do
    start' <- traverse defuncExp' start
    end' <- traverse defuncExp' end
    stride' <- traverse defuncExp' stride
    pure $ DimSlice start' end' stride'

-- | Defunctionalize a let-bound function, while preserving parameters
-- that have order 0 types (i.e., non-functional).
--
-- Arguments are the size parameters, the value parameters, the body and the
-- declared return type. The result contains:
--
-- * the size parameters kept on the residual function,
-- * the value parameters kept on it,
-- * the residual body,
-- * the static value of the binding,
-- * the result type computed by 'resTypeFromSV'.
--
-- Parameters are consumed left to right. Each leading order-zero pattern is
-- kept, together with the size parameters that occur free in it. At the
-- first pattern that is not order zero, the whole remaining function is
-- handed to 'defuncFun' instead.
defuncLet ::
  [VName] ->
  [Pat ParamType] ->
  Exp ->
  ResRetType ->
  DefM ([VName], [Pat ParamType], Exp, StaticVal, ResType)
defuncLet _ [] body (RetType _ rettype) = do
  (body', sv) <- defuncExp body
  let refined = imposeType sv $ resToParam rettype
  pure ([], [], body', refined, resTypeFromSV sv)
  where
    -- Refine a static value with the declared type. A 'Dynamic' value takes
    -- the declared type. A record is refined field by field, keeping only
    -- the fields present in both. Everything else is left unchanged.
    imposeType :: StaticVal -> ParamType -> StaticVal
    imposeType sv t = case (sv, t) of
      (Dynamic {}, _) -> Dynamic t
      (RecordSV fs1, Scalar (Record fs2)) ->
        RecordSV . M.toList $ M.intersectionWith imposeType (M.fromList fs1) fs2
      _ -> sv
defuncLet dims ps@(pat : pats) body ret@RetType {}
  | not (patternOrderZero pat) = do
      (e, sv) <- defuncFun dims ps body ret mempty
      pure ([], [], e, sv, resTypeFromSV sv)
  | otherwise = do
      -- Take care to not include more size parameters than necessary.
      let (pat_dims, rest_dims) =
            partition (`S.member` fvVars (freeInPat pat)) dims
          env = envFromPat pat <> envFromDimNames pat_dims
      (rest_dims', pats', body', sv, sv_t) <-
        localEnv env $ defuncLet rest_dims pats body ret
      closure <- defuncFun dims ps body ret mempty
      pure
        ( pat_dims ++ rest_dims',
          pat : pats',
          body',
          DynamicFun closure sv,
          sv_t
        )

-- | Replace every dimension equal to 'anySize' in the given parameter
-- patterns with a reference to a freshly generated size name (base name
-- @size@). All other dimensions and the diet annotations are left as they
-- are.
instAnySizes :: (MonadFreshNames m) => [Pat ParamType] -> m [Pat ParamType]
instAnySizes = mapM (mapM (bitraverse freshIfAny pure))
  where
    freshIfAny d =
      if d == anySize
        then do
          v <- newVName "size"
          pure $ sizeFromName (qualName v) mempty
        else pure d

-- | Collect the size names used as dimensions in the types of @params@
-- that are bound neither by @bound_sizes@ nor by the parameter patterns
-- themselves.
--
-- Only dimensions that are plain variable references are considered.
-- Duplicates are removed. Names are accumulated by prepending and then
-- deduplicated with 'nubOrd', so the result order is the reverse of the
-- order in which the traversal last encountered each name.
unboundSizes :: S.Set VName -> [Pat ParamType] -> [VName]
unboundSizes bound_sizes params =
  nubOrd . flip execState [] $ traverse (traverse (bitraverse note pure)) params
  where
    bound = bound_sizes <> S.fromList (concatMap patNames params)
    note d@(Var qn _ _) = do
      let v = qualLeaf qn
      unless (v `S.member` bound) $ modify (v :)
      pure d
    note d = pure d

-- | Convert a return type into an 'AppRes'.
--
-- Each existentially bound size of the return type is renamed to a fresh
-- name. The fresh names are substituted into the (struct) type and recorded
-- as the existentials of the 'AppRes'. Without existentials, no fresh names
-- are generated and the type is only converted with 'toStruct'.
unRetType :: ResRetType -> DefM AppRes
unRetType (RetType ext t)
  | null ext = pure $ AppRes (toStruct t) []
  | otherwise = do
      fresh <- traverse newName ext
      let sizeOf v = ExpSubst $ sizeFromName (qualName v) mempty
          renaming = M.fromList (zip ext (map sizeOf fresh))
          t' = applySubst (`M.lookup` renaming) (toStruct t)
      pure $ AppRes t' fresh

-- | Defunctionalize the function position of an application with
-- @num_args@ arguments.
--
-- Only variables are handled specially; any other expression is passed to
-- 'defuncExp'. For a variable, the result depends on its static value:
--
-- * A 'DynamicFun' that 'fullyApplied' accepts keeps the variable. Its type
--   is updated with 'dynamicFunType', because the dynamic function may
--   still return a higher-order term.
--
-- * A 'DynamicFun' whose argument and result types are all order zero is
--   eta-expanded, and the resulting lambda is passed to 'defuncFun'.
--
-- * Any other 'DynamicFun' is lifted to a new top-level function named
--   @dyn_<name>@ (via 'liftDynFun' and 'liftValDec'), and a reference to
--   that function is returned.
--
-- * 'IntrinsicSV' returns the expression unchanged.
--
-- * Anything else keeps the variable, with its type recomputed from the
--   static value.
defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
defuncApplyFunction e@(Var qn (Info t) loc) num_args = do
  let (argtypes, rettype) = unfoldFunType t
  -- 't' is already a 'StructType', so no 'toStruct' conversion is needed.
  sv <- lookupVar t (qualLeaf qn)

  case sv of
    DynamicFun _ _
      | fullyApplied sv num_args -> do
          -- We still need to update the types in case the dynamic
          -- function returns a higher-order term.
          let (argtypes', rettype') = dynamicFunType sv argtypes
          pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype')) loc, sv)
      | all orderZero argtypes,
        orderZero rettype -> do
          (params, body, ret) <- etaExpand (RetType [] $ toRes Nonunique t) e
          defuncFun [] params body ret mempty
      | otherwise -> do
          fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
          let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args
              (argtypes', rettype') = dynamicFunType sv' argtypes

          -- Ensure that no parameter sizes are AnySize.  The internaliser
          -- expects this.  This is easy, because they are all
          -- first-order.
          globals <- asks fst
          pats' <- instAnySizes pats

          -- The lifted function declares no explicit size parameters of
          -- its own. Its size parameters are exactly the sizes its
          -- parameters mention that are neither global nor bound by the
          -- parameter patterns.
          liftValDec fname (RetType [] rettype') (unboundSizes globals pats') pats' e0
          pure
            ( Var
                (qualName fname)
                (Info (foldFunType argtypes' $ RetType [] rettype'))
                loc,
              sv'
            )
    IntrinsicSV -> pure (e, IntrinsicSV)
    _ -> pure (Var qn (Info (structTypeFromSV sv)) loc, sv)
defuncApplyFunction e _ = defuncExp e

-- | Build a base name for a lifted function that embeds some
-- information about the original function, to make the result
-- slightly more human-readable.  When the applied expression is a
-- variable (possibly under nested applications), its base name and
-- the nesting depth are embedded; otherwise the generic name
-- @"defunc"@ is used.
liftedName :: Int -> Exp -> String
liftedName depth fun =
  case fun of
    Var qn _ _ ->
      concat ["defunc_", show depth, "_", baseString (qualLeaf qn)]
    AppExp (Apply inner _ _) _ ->
      liftedName (depth + 1) inner
    _ -> "defunc"

-- | Defunctionalise the application of an already-defunctionalised
-- function, given as its residual expression and static value, to a
-- single argument.
--
-- The argument tuple holds the argument's diet, an optional
-- existential size name, the argument expression, and a list of
-- parameter types.  That list is only consulted in the 'DynamicFun'
-- case.
--
-- * For a 'LambdaSV', the lambda body is defunctionalised in its
--   closure environment extended with the bindings of the argument.
--   It is then lifted to a new top-level function, named from the
--   'String' argument, that takes the closure record and the argument.
--   The result applies that lifted function to the closure and the
--   argument.
--
-- * For a 'DynamicFun', the application is left in place, but its
--   result type is recomputed via 'dynamicFunType'.
--
-- * Any other static value is an internal error.
defuncApplyArg ::
  String ->
  (Exp, StaticVal) ->
  (((Diet, Maybe VName), Exp), [ParamType]) ->
  DefM (Exp, StaticVal)
defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((d, argext), arg), _) = do
  (arg', arg_sv) <- defuncExp arg
  -- No extra size parameters are introduced for the closure here.
  let dims = [] :: [VName]
  (body', body_sv) <-
    localNewEnv (alwaysMatchPatSV pat arg_sv <> closure_env) (defuncExp lam_e)

  globals <- asks fst

  -- Lift lambda to top-level function definition.  We put in a lot
  -- of effort to try to infer the uniqueness attributes of the lifted
  -- function, but this is ultimately all a sham and a hack.  There is
  -- some piece we're missing.
  let params = [buildEnvPat dims closure_env, updatePat pat arg_sv]
      lifted_rettype =
        RetType (retDims lam_e_t) $
          combineTypeShapes (retType lam_e_t) (resTypeFromSV body_sv)
      already_bound = globals <> S.fromList (dims <> foldMap patNames params)
      more_dims =
        S.toList . S.filter (`S.notMember` already_bound) $
          foldMap patternArraySizes params
      -- Ensure that no parameter sizes are AnySize.  The internaliser
      -- expects this.  This is easy, because they are all first-order.
      bound_sizes = S.fromList (dims <> more_dims) <> globals

  params' <- instAnySizes params
  fname <- newNameFromString fname_s
  liftValDec
    fname
    lifted_rettype
    (dims ++ more_dims ++ unboundSizes bound_sizes params')
    params'
    body'

  let fname_t =
        foldFunType
          [ toParam Observe (toStruct (typeOf f')),
            toParam d (toStruct (typeOf arg'))
          ]
          lifted_rettype
      fname' = Var (qualName fname) (Info fname_t) (srclocOf arg)
  callret <- unRetType lifted_rettype
  pure
    ( mkApply fname' [(Observe, Nothing, f'), (Observe, argext, arg')] callret,
      body_sv
    )
-- If 'f' is a dynamic function, we just leave the application in
-- place, but we update the types since it may be partially
-- applied or return a higher-order value.
defuncApplyArg _ (f', DynamicFun _ sv) (((d, argext), arg), argtypes) = do
  (arg', _) <- defuncExp arg
  let (argtypes', rettype) = dynamicFunType sv argtypes
      callret = AppRes (foldFunType argtypes' (RetType [] rettype)) []
  pure (mkApply f' [(d, argext, arg')] callret, sv)
-- Anything else cannot be applied.
defuncApplyArg fname_s (_, sv) _ =
  error $
    "defuncApplyArg: cannot apply StaticVal\n"
      <> show sv
      <> "\nFunction name: "
      <> prettyString fname_s

-- | Refine the return information recorded on an application node.
-- The node's return type is merged with the given one via
-- 'combineTypeShapes' (given type first), and the given existential
-- sizes are prepended to the node's.  Expressions that are not an
-- 'AppExp' are returned unchanged.
updateReturn :: AppRes -> Exp -> Exp
updateReturn (AppRes ret1 ext1) e =
  case e of
    AppExp app (Info (AppRes ret2 ext2)) ->
      let merged = AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2)
       in AppExp app (Info merged)
    _ -> e

-- | Defunctionalise an application of @f@ to one or more arguments.
--
-- Applications of intrinsics and holes are kept as direct
-- applications, with each argument defunctionalised.  For any other
-- function, the arguments are applied one at a time with
-- 'defuncApplyArg'.  Each argument is paired with the successive
-- suffixes of the original function's parameter types.  The return
-- information of the final application is then refined with the given
-- 'AppRes'.
defuncApply :: Exp -> NE.NonEmpty ((Diet, Maybe VName), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal)
defuncApply f args appres loc = do
  (f', f_sv) <- defuncApplyFunction f (length args)
  case f_sv of
    IntrinsicSV -> applyInPlace f' defuncSoacExp
    HoleSV {} -> applyInPlace f' (fmap fst . defuncExp)
    _ -> do
      let fname = liftedName 0 f
          argtypes = fst (unfoldFunType (typeOf f))
          steps = NE.zip args (NE.tails argtypes)
      first (updateReturn appres)
        <$> foldM (defuncApplyArg fname) (f', f_sv) steps
  where
    -- Keep the application as-is, defunctionalising every argument
    -- with the supplied function.
    applyInPlace :: Exp -> (Exp -> DefM Exp) -> DefM (Exp, StaticVal)
    applyInPlace fun defuncArg = do
      args' <- fmap (first Info) <$> traverse (traverse defuncArg) args
      intrinsicOrHole $ AppExp (Apply fun args' loc) (Info appres)

    -- If the intrinsic is fully applied, then we are done.
    -- Otherwise we need to eta-expand it and recursively
    -- defunctionalise. XXX: might it be better to simply eta-expand
    -- immediately any time we encounter a non-fully-applied
    -- intrinsic?
    intrinsicOrHole :: Exp -> DefM (Exp, StaticVal)
    intrinsicOrHole e'
      | null (fst (unfoldFunType (appResType appres))) =
          pure (e', Dynamic (toParam Observe (appResType appres)))
      | otherwise = do
          (pats, body, tp) <-
            etaExpand (RetType [] (toRes Nonunique (typeOf e'))) e'
          defuncExp (Lambda pats body Nothing (Info tp) mempty)

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.  Each 'DynamicFun' layer
-- consumes one unit of depth.  Reaching depth zero while still at a
-- 'DynamicFun' means the function is only partially applied.  Any
-- other static value, and a 'DynamicFun' reached with a negative
-- depth, counts as fully applied.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied sv depth =
  case sv of
    DynamicFun _ body
      | depth == 0 -> False
      | depth > 0 -> fullyApplied body (depth - 1)
    -- Reached directly for non-'DynamicFun' values, and by
    -- fall-through when neither guard above holds.
    _ -> True

-- | Converts a dynamic function 'StaticVal' into a list of
-- parameters, a function body, and the appropriate static value for
-- applying the function at the given depth of partial application.
--
-- At depth 0, the residual expression and static value stored in the
-- 'DynamicFun' are returned with no parameters.  At a positive depth,
-- the closure must be a 'LambdaSV'.  Its parameter pattern is
-- collected, and the function recurses into the body 'StaticVal' with
-- the depth reduced by one.  Each level of the returned static value
-- is rewrapped in a 'DynamicFun' with the same closure.  Any other
-- input is an internal error, and the 'String' argument prefixes the
-- error message.
liftDynFun :: String -> StaticVal -> Int -> ([Pat ParamType], Exp, StaticVal)
liftDynFun _ (DynamicFun (e, sv) _) 0 = ([], e, sv)
liftDynFun s (DynamicFun clsr@(_, LambdaSV pat _ _ _) sv) d
  | d > 0 =
      let (pats, e', sv') = liftDynFun s sv (d - 1)
       in (pat : pats, e', DynamicFun clsr sv')
liftDynFun s sv d =
  error $
    s
      ++ " Tried to lift a StaticVal "
      ++ take 100 (show sv)
      ++ ", but expected a dynamic function.\n"
      ++ prettyString d

-- | Converts a pattern to an environment that binds the individual
-- names of the pattern to their corresponding types wrapped in a
-- 'Dynamic' static value.  Wildcards and literal patterns bind
-- nothing; every other pattern form contributes the bindings of its
-- sub-patterns.
envFromPat :: Pat ParamType -> Env
envFromPat (Id vn (Info t) _) = M.singleton vn (Binding Nothing (Dynamic t))
envFromPat (TuplePat ps _) = foldMap envFromPat ps
envFromPat (RecordPat fs _) = foldMap envFromPat (map snd fs)
envFromPat (PatConstr _ _ ps _) = foldMap envFromPat ps
envFromPat (PatParens p _) = envFromPat p
envFromPat (PatAttr _ p _) = envFromPat p
envFromPat (PatAscription p _ _) = envFromPat p
envFromPat Wildcard {} = mempty
envFromPat PatLit {} = mempty

-- | Build an environment that binds every given name to a 'Dynamic'
-- static value of type @i64@ (signed 64-bit integer).  The name
-- suggests these are size (dimension) names; this is not verified
-- here.
envFromDimNames :: [VName] -> Env
envFromDimNames names = M.fromList [(name, sizeBinding) | name <- names]
  where
    sizeBinding :: Binding
    sizeBinding = Binding Nothing (Dynamic (Scalar (Prim (Signed Int64))))

-- | Given a closure environment, construct a record pattern that
-- binds the closed over variables.  Each field is named after the
-- pretty-printed variable name, and fields appear in ascending key
-- order of the environment.  A wildcard is inserted for any variable
-- that appears in the given size list, so it does not clash with a
-- size parameter.
buildEnvPat :: [VName] -> Env -> Pat ParamType
buildEnvPat sizes env =
  RecordPat [closureField vn sv | (vn, Binding _ sv) <- M.toList env] mempty
  where
    closureField :: VName -> StaticVal -> (Name, Pat ParamType)
    closureField vn sv =
      let t = Info (paramTypeFromSV sv)
          fieldPat
            | vn `elem` sizes = Wildcard t mempty
            | otherwise = Id vn t mempty
       in (nameFromString (prettyString vn), fieldPat)

-- | Compute the type of the /representation/ of a static value, i.e.
-- the type the value has after defunctionalisation, not the type of the
-- original (possibly higher-order) value.
--
-- * A lambda is represented by a record of its closure environment,
--   one field per environment entry, named after the pretty-printed
--   variable name.
-- * Records are converted field by field.
-- * For a sum, the payload of the constructor that is actually present
--   is converted and inserted alongside the recorded types of the other
--   constructors.
-- * A 'DynamicFun' is represented like its whole-function static value.
--
-- Calling this on an 'IntrinsicSV' is an internal error.
typeFromSV :: StaticVal -> ParamType
typeFromSV sv = case sv of
  Dynamic tp -> tp
  LambdaSV _ _ _ env ->
    Scalar . Record $
      M.fromList
        [ (nameFromString (prettyString v), typeFromSV (bindingSV b))
        | (v, b) <- M.toList env
        ]
  RecordSV fields ->
    Scalar . Record $
      M.fromList [(k, typeFromSV fieldSV) | (k, fieldSV) <- fields]
  DynamicFun (_, whole) _ -> typeFromSV whole
  SumSV present payload others ->
    Scalar . Sum $
      M.insert present (map typeFromSV payload) (M.fromList others)
  HoleSV t _ -> toParam Observe t
  IntrinsicSV ->
    error "Tried to get the type from the static value of an intrinsic."

-- | The representation type of a static value (see 'typeFromSV'),
-- converted from a parameter type to a result type with 'paramToRes'.
resTypeFromSV :: StaticVal -> ResType
resTypeFromSV sv = paramToRes (typeFromSV sv)

-- | The representation type of a static value (see 'typeFromSV'),
-- converted to a 'StructType' with 'toStruct'.
structTypeFromSV :: StaticVal -> StructType
structTypeFromSV sv = toStruct (typeFromSV sv)

-- | The representation type of a static value as a parameter type.
-- This is exactly 'typeFromSV', whose result is already a 'ParamType'.
paramTypeFromSV :: StaticVal -> ParamType
paramTypeFromSV sv = typeFromSV sv

-- | Construct the type for a fully-applied dynamic function from its
-- static value and the original types of its arguments.
--
-- One argument type is kept for each nested 'DynamicFun' layer, for as
-- long as both layers and arguments remain. The result type is the
-- representation type ('resTypeFromSV') of the static value left over
-- once either runs out.
dynamicFunType :: StaticVal -> [ParamType] -> ([ParamType], ResType)
dynamicFunType sv args = case (sv, args) of
  (DynamicFun _ body, p : ps) -> first (p :) (dynamicFunType body ps)
  _ -> ([], resTypeFromSV sv)

-- | Match a pattern with its static value. Returns an environment
-- with the identifier components of the pattern mapped to the
-- corresponding subcomponents of the static value.  If this function
-- returns 'Nothing', then it corresponds to an unmatchable case.
-- These should only occur for 'Match' expressions.
--
-- A 'Dynamic' or 'HoleSV' static value of record type is first
-- unwrapped into a 'RecordSV' (via 'svFromType') and then matched.
-- Any other combination of pattern and static value that is not
-- handled explicitly is reported as an internal error.
matchPatSV :: Pat ParamType -> StaticVal -> Maybe Env
matchPatSV (TuplePat ps _) (RecordSV ls) =
  mconcat <$> zipWithM (\p (_, sv) -> matchPatSV p sv) ps ls
matchPatSV (RecordPat ps _) (RecordSV ls)
  -- Compare fields in sorted order, so that the order in which the
  -- pattern lists its fields does not matter.
  | ps' <- sortOn fst ps,
    ls' <- sortOn fst ls,
    map fst ps' == map fst ls' =
      mconcat <$> zipWithM (\(_, p) (_, sv) -> matchPatSV p sv) ps' ls'
matchPatSV (PatParens pat _) sv = matchPatSV pat sv
matchPatSV (PatAttr _ pat _) sv = matchPatSV pat sv
matchPatSV (Id vn (Info t) _) sv =
  -- When matching a zero-order pattern with a StaticVal, the type of
  -- the pattern wins out.  This is important for propagating sizes
  -- (but probably reveals a flaw in our bookkeeping).
  pure $
    if orderZero t
      then dim_env <> M.singleton vn (Binding Nothing $ Dynamic t)
      else dim_env <> M.singleton vn (Binding Nothing sv)
  where
    -- Every variable free in the pattern's type (its sizes) is bound
    -- as a dynamic i64 value.
    dim_env = M.fromList $ map (,i64) $ S.toList $ fvVars $ freeInType t
    i64 = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64
matchPatSV (Wildcard _ _) _ = pure mempty
matchPatSV (PatAscription pat _ _) sv = matchPatSV pat sv
matchPatSV PatLit {} _ = pure mempty
matchPatSV (PatConstr c1 _ ps _) (SumSV c2 ls fs)
  | c1 == c2 =
      mconcat <$> zipWithM matchPatSV ps ls
  -- A different, but known, constructor: this case cannot match.
  | Just _ <- lookup c1 fs =
      Nothing
  | otherwise =
      error $ "matchPatSV: missing constructor in type: " ++ prettyString c1
matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs)))
  | Just ts <- M.lookup c1 fs =
      -- A higher-order pattern can only match an appropriate SumSV.
      if all orderZero ts
        then mconcat <$> zipWithM matchPatSV ps (map svFromType ts)
        else Nothing
  | otherwise =
      error $ "matchPatSV: missing constructor in type: " ++ prettyString c1
matchPatSV pat (Dynamic t)
  -- 'svFromType' only changes record types.  For any other type it
  -- would hand back the very same 'Dynamic' value and this equation
  -- would recurse forever, so such cases fall through to the error.
  | Scalar Record {} <- t = matchPatSV pat $ svFromType t
matchPatSV pat (HoleSV t _)
  | Scalar Record {} <- t = matchPatSV pat $ svFromType $ toParam Observe t
matchPatSV pat sv =
  error $
    "Tried to match pattern\n"
      ++ prettyString pat
      ++ "\n with static value\n"
      ++ show sv

-- | Like 'matchPatSV', for call sites where the match is expected to
-- always succeed. A failed match is reported as an internal error that
-- shows both the pattern and the static value.
alwaysMatchPatSV :: Pat ParamType -> StaticVal -> Env
alwaysMatchPatSV pat sv =
  case matchPatSV pat sv of
    Just env -> env
    Nothing ->
      error $ unlines [prettyString pat, "cannot match StaticVal", show sv]

-- | Given a pattern and the static value for the defunctionalized argument,
-- update the pattern to reflect the changes in the types.
--
-- Type annotations are treated as follows:
--
-- * Wildcards and constructor patterns whose type is of order zero are
--   left untouched.
-- * Identifiers keep the zeroth-order parts of their original type.
-- * Function-typed parts take the representation type computed from the
--   static value.
--
-- A 'Dynamic' or 'HoleSV' static value of record type is first
-- unwrapped into a 'RecordSV' via 'svFromType'.  Any other unhandled
-- combination is reported as an internal error.
updatePat :: Pat ParamType -> StaticVal -> Pat ParamType
updatePat (TuplePat ps loc) (RecordSV svs) =
  TuplePat (zipWith updatePat ps $ map snd svs) loc
updatePat (RecordPat ps loc) (RecordSV svs)
  -- Pair up pattern fields and static-value fields by sorted name.
  | ps' <- sortOn fst ps,
    svs' <- sortOn fst svs =
      RecordPat
        (zipWith (\(n, p) (_, sv) -> (n, updatePat p sv)) ps' svs')
        loc
updatePat (PatParens pat loc) sv =
  PatParens (updatePat pat sv) loc
updatePat (PatAttr attr pat loc) sv =
  PatAttr attr (updatePat pat sv) loc
updatePat (Id vn (Info tp) loc) sv =
  Id vn (Info $ comb tp $ paramTypeFromSV sv) loc
  where
    -- Preserve any original zeroth-order types.
    comb (Scalar Arrow {}) t2 = t2
    comb (Scalar (Record m1)) (Scalar (Record m2)) =
      Scalar $ Record $ M.intersectionWith comb m1 m2
    comb (Scalar (Sum m1)) (Scalar (Sum m2)) =
      Scalar $ Sum $ M.intersectionWith (zipWith comb) m1 m2
    comb t1 _ = t1 -- t1 must be array or prim.
updatePat pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ paramTypeFromSV sv) loc
updatePat (PatAscription pat _ _) sv =
  updatePat pat sv
updatePat p@PatLit {} _ = p
updatePat pat@(PatConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatConstr c1 (Info $ toParam Observe t') ps' loc
  where
    t' = resTypeFromSV sv
    ps' = zipWith updatePat ps svs
updatePat (PatConstr c1 _ ps loc) (Dynamic t) =
  PatConstr c1 (Info $ toParam Observe t) ps loc
updatePat pat (Dynamic t)
  -- 'svFromType' only changes record types.  For any other type it
  -- would hand back the very same 'Dynamic' value and this equation
  -- would recurse forever, so such cases fall through to the error.
  | Scalar Record {} <- t = updatePat pat (svFromType t)
updatePat pat (HoleSV t _)
  | Scalar Record {} <- t = updatePat pat (svFromType $ toParam Observe t)
updatePat pat sv =
  error $
    "Tried to update pattern\n"
      ++ prettyString pat
      ++ "\nto reflect the static value\n"
      ++ show sv

-- | Convert a record (or tuple) type to a record static value, recursing
-- into the field types. This is used for "unwrapping" tuples and
-- records that are nested in 'Dynamic' static values. Any type that is
-- not a record becomes a 'Dynamic' static value unchanged.
svFromType :: ParamType -> StaticVal
svFromType t = case t of
  Scalar (Record fs) ->
    RecordSV [(k, svFromType ft) | (k, ft) <- M.toList fs]
  _ -> Dynamic t

-- | Defunctionalize a top-level value binding. Returns the
-- transformed result as well as an environment that binds the name of
-- the value binding to the static value of the transformed body.
defuncValBind :: ValBind -> DefM (ValBind, Env)
-- Eta-expand entry points with a functional return type.
defuncValBind :: ValBind -> DefM (ValBind, Env)
defuncValBind (ValBind Maybe (Info EntryPoint)
entry VName
name Maybe (TypeExp Info VName)
_ (Info ResRetType
rettype) [TypeParamBase VName]
tparams [Pat ParamType]
params Exp
body Maybe DocComment
_ [AttrInfo VName]
attrs SrcLoc
loc)
  | Scalar Arrow {} <- ResRetType -> TypeBase Exp Uniqueness
forall dim as. RetTypeBase dim as -> TypeBase dim as
retType ResRetType
rettype = do
      ([Pat ParamType]
body_pats, Exp
body', ResRetType
rettype') <- ResRetType -> Exp -> DefM ([Pat ParamType], Exp, ResRetType)
etaExpand ((Uniqueness -> Uniqueness) -> ResRetType -> ResRetType
forall b c a. (b -> c) -> RetTypeBase a b -> RetTypeBase a c
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Uniqueness -> Uniqueness -> Uniqueness
forall a b. a -> b -> a
const Uniqueness
forall a. Monoid a => a
mempty) ResRetType
rettype) Exp
body
      ValBind -> DefM (ValBind, Env)
defuncValBind (ValBind -> DefM (ValBind, Env)) -> ValBind -> DefM (ValBind, Env)
forall a b. (a -> b) -> a -> b
$
        Maybe (Info EntryPoint)
-> VName
-> Maybe (TypeExp Info VName)
-> Info ResRetType
-> [TypeParamBase VName]
-> [Pat ParamType]
-> Exp
-> Maybe DocComment
-> [AttrInfo VName]
-> SrcLoc
-> ValBind
forall (f :: * -> *) vn.
Maybe (f EntryPoint)
-> vn
-> Maybe (TypeExp f vn)
-> f ResRetType
-> [TypeParamBase vn]
-> [PatBase f vn ParamType]
-> ExpBase f vn
-> Maybe DocComment
-> [AttrInfo vn]
-> SrcLoc
-> ValBindBase f vn
ValBind
          Maybe (Info EntryPoint)
entry
          VName
name
          Maybe (TypeExp Info VName)
forall a. Maybe a
Nothing
          (ResRetType -> Info ResRetType
forall a. a -> Info a
Info ResRetType
rettype')
          [TypeParamBase VName]
tparams
          ([Pat ParamType]
params [Pat ParamType] -> [Pat ParamType] -> [Pat ParamType]
forall a. Semigroup a => a -> a -> a
<> [Pat ParamType]
body_pats)
          Exp
body'
          Maybe DocComment
forall a. Maybe a
Nothing
          [AttrInfo VName]
attrs
          SrcLoc
loc
defuncValBind valbind :: ValBind
valbind@(ValBind Maybe (Info EntryPoint)
_ VName
name Maybe (TypeExp Info VName)
retdecl (Info (RetType [VName]
ret_dims TypeBase Exp Uniqueness
rettype)) [TypeParamBase VName]
tparams [Pat ParamType]
params Exp
body Maybe DocComment
_ [AttrInfo VName]
_ SrcLoc
_) = do
  Bool -> DefM () -> DefM ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ((TypeParamBase VName -> Bool) -> [TypeParamBase VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any TypeParamBase VName -> Bool
forall vn. TypeParamBase vn -> Bool
isTypeParam [TypeParamBase VName]
tparams) (DefM () -> DefM ()) -> DefM () -> DefM ()
forall a b. (a -> b) -> a -> b
$
    [Char] -> DefM ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> DefM ()) -> [Char] -> DefM ()
forall a b. (a -> b) -> a -> b
$
      VName -> [Char]
forall a. Show a => a -> [Char]
show VName
name
        [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" has type parameters, "
        [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"but the defunctionaliser expects a monomorphic input program."
  ([VName]
tparams', [Pat ParamType]
params', Exp
body', StaticVal
sv, TypeBase Exp Uniqueness
sv_t) <-
    [VName]
-> [Pat ParamType]
-> Exp
-> ResRetType
-> DefM
     ([VName], [Pat ParamType], Exp, StaticVal, TypeBase Exp Uniqueness)
defuncLet ((TypeParamBase VName -> VName) -> [TypeParamBase VName] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map TypeParamBase VName -> VName
forall vn. TypeParamBase vn -> vn
typeParamName [TypeParamBase VName]
tparams) [Pat ParamType]
params Exp
body (ResRetType
 -> DefM
      ([VName], [Pat ParamType], Exp, StaticVal,
       TypeBase Exp Uniqueness))
-> ResRetType
-> DefM
     ([VName], [Pat ParamType], Exp, StaticVal, TypeBase Exp Uniqueness)
forall a b. (a -> b) -> a -> b
$ [VName] -> TypeBase Exp Uniqueness -> ResRetType
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [VName]
ret_dims TypeBase Exp Uniqueness
rettype
  Set VName
globals <- ((Set VName, Env) -> Set VName) -> DefM (Set VName)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Set VName, Env) -> Set VName
forall a b. (a, b) -> a
fst
  let bound_sizes :: Set VName
bound_sizes = [VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList ((Pat ParamType -> [VName]) -> [Pat ParamType] -> [VName]
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Pat ParamType -> [VName]
forall t. Pat t -> [VName]
patNames [Pat ParamType]
params') Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<> [VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
tparams' Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<> Set VName
globals
  [Pat ParamType]
params'' <- [Pat ParamType] -> DefM [Pat ParamType]
forall (m :: * -> *).
MonadFreshNames m =>
[Pat ParamType] -> m [Pat ParamType]
instAnySizes [Pat ParamType]
params'
  let rettype' :: TypeBase Exp Uniqueness
rettype' = TypeBase Exp Uniqueness
-> TypeBase Exp Uniqueness -> TypeBase Exp Uniqueness
forall as.
Monoid as =>
TypeBase Exp as -> TypeBase Exp as -> TypeBase Exp as
combineTypeShapes TypeBase Exp Uniqueness
rettype TypeBase Exp Uniqueness
sv_t
      tparams'' :: [VName]
tparams'' = [VName]
tparams' [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Set VName -> [Pat ParamType] -> [VName]
unboundSizes Set VName
bound_sizes [Pat ParamType]
params''
      ret_dims' :: [VName]
ret_dims' = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Set VName -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` Set VName
bound_sizes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ Set VName -> [VName]
forall a. Set a -> [a]
S.toList (Set VName -> [VName]) -> Set VName -> [VName]
forall a b. (a -> b) -> a -> b
$ FV -> Set VName
fvVars (FV -> Set VName) -> FV -> Set VName
forall a b. (a -> b) -> a -> b
$ TypeBase Exp Uniqueness -> FV
forall u. TypeBase Exp u -> FV
freeInType TypeBase Exp Uniqueness
rettype'

  (ValBind, Env) -> DefM (ValBind, Env)
forall a. a -> DefM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( ValBind
valbind
        { valBindRetDecl :: Maybe (TypeExp Info VName)
valBindRetDecl = Maybe (TypeExp Info VName)
retdecl,
          valBindRetType :: Info ResRetType
valBindRetType =
            ResRetType -> Info ResRetType
forall a. a -> Info a
Info (ResRetType -> Info ResRetType) -> ResRetType -> Info ResRetType
forall a b. (a -> b) -> a -> b
$
              if [Pat ParamType] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Pat ParamType]
params'
                then [VName] -> TypeBase Exp Uniqueness -> ResRetType
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [VName]
ret_dims' (TypeBase Exp Uniqueness -> ResRetType)
-> TypeBase Exp Uniqueness -> ResRetType
forall a b. (a -> b) -> a -> b
$ TypeBase Exp Uniqueness
rettype' TypeBase Exp Uniqueness -> Uniqueness -> TypeBase Exp Uniqueness
forall dim u1 u2. TypeBase dim u1 -> u2 -> TypeBase dim u2
`setUniqueness` Uniqueness
Nonunique
                else [VName] -> TypeBase Exp Uniqueness -> ResRetType
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [VName]
ret_dims' TypeBase Exp Uniqueness
rettype',
          valBindTypeParams :: [TypeParamBase VName]
valBindTypeParams = (VName -> TypeParamBase VName) -> [VName] -> [TypeParamBase VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SrcLoc -> TypeParamBase VName
forall vn. vn -> SrcLoc -> TypeParamBase vn
`TypeParamDim` SrcLoc
forall a. Monoid a => a
mempty) [VName]
tparams'',
          valBindParams :: [Pat ParamType]
valBindParams = [Pat ParamType]
params'',
          valBindBody :: Exp
valBindBody = Exp
body'
        },
      VName -> Binding -> Env
forall k a. k -> a -> Map k a
M.singleton VName
name (Binding -> Env) -> Binding -> Env
forall a b. (a -> b) -> a -> b
$
        Maybe ([VName], StructType) -> StaticVal -> Binding
Binding
          (([VName], StructType) -> Maybe ([VName], StructType)
forall a. a -> Maybe a
Just (([TypeParamBase VName] -> [VName])
-> ([TypeParamBase VName], StructType) -> ([VName], StructType)
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first ((TypeParamBase VName -> VName) -> [TypeParamBase VName] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map TypeParamBase VName -> VName
forall vn. TypeParamBase vn -> vn
typeParamName) (ValBind -> ([TypeParamBase VName], StructType)
valBindTypeScheme ValBind
valbind)))
          StaticVal
sv
    )

-- | Defunctionalize a list of top-level declarations, in order.
--
-- Each binding is transformed by 'defuncValBind' and emitted with
-- 'addValBind'. The remaining bindings are processed with the resulting
-- 'Env' in scope (via 'localEnv') and with the transformed binding's
-- bound names registered through 'areGlobal'.
defuncVals :: [ValBind] -> DefM ()
defuncVals = foldr step (pure ())
  where
    -- 'rest' is the action that handles every later binding; it runs
    -- inside the scope extended by the current one.
    step valbind rest = do
      (valbind', env) <- defuncValBind valbind
      addValBind valbind'
      localEnv env $ areGlobal (valBindBound valbind') rest

{-# NOINLINE transformProg #-}

-- | Transform a list of top-level value bindings. May produce new
-- lifted function definitions, which end up in the resulting list of
-- declarations. NOTE(review): their position (stated originally as "in
-- front") is decided by 'runDefM'/'addValBind', which are not visible
-- here.
--
-- The fresh-name source is threaded through the transformation via
-- 'modifyNameSource'.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource run
  where
    run namesrc = (decs', namesrc')
      where
        -- A lazy pattern binding, matching the original 'let'.
        ((), namesrc', decs') = runDefM namesrc (defuncVals decs)