-- | This monomorphization module converts a well-typed, polymorphic,
-- module-free Futhark program into an equivalent monomorphic program.
--
-- This pass also does a few other simplifications to make the job of
-- subsequent passes easier.  Specifically, it does the following:
--
-- * Turn operator sections into explicit lambdas.
--
-- * Converts applications of intrinsic SOACs into SOAC AST nodes
--   (Map, Reduce, etc).
--
-- * Elide functions that are not reachable from an entry point (this
--   is a side effect of the monomorphisation algorithm, which uses
--   the entry points as roots).
--
-- * Turns implicit record fields into explicit record fields.
--
-- * Rewrite BinOp nodes to Apply nodes.
--
-- * Replace all size expressions with constants or variables.
--   Complex size expressions are replaced by variables, which are
--   either computed in a let binding or, when they occur in function
--   arguments, turned into size parameters.
--
-- Note that these changes are unfortunately not visible in the AST
-- representation.
module Futhark.Internalise.Monomorphise (transformProg) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.RWS (MonadReader (..), MonadWriter (..), RWST, asks, runRWST)
import Control.Monad.State
import Control.Monad.Writer (Writer, runWriter, runWriterT)
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Data.Set qualified as S
import Futhark.MonadFreshNames
import Futhark.Util (nubOrd, topologicalSort)
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Semantic (TypeBinding (..))
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types

-- | The @i64@ type, polymorphic in its dimension and
-- uniqueness/aliasing parameters.  Used in this module as the type of
-- variables that stand for sizes.
i64 :: TypeBase dim als
i64 = Scalar (Prim (Signed Int64))

-- | A polymorphic function binding as kept in the environment.  The
-- components are, in order: the function name, its type parameters,
-- its parameter patterns, its return type, its body, its attributes
-- and its source location.
--
-- The monomorphization monad reads 'PolyBinding's and writes
-- 'ValBind's; the 'TypeParam's in the emitted 'ValBind's can only be
-- size parameters.
newtype PolyBinding
  = PolyBinding
      ( VName, -- function name
        [TypeParam], -- type parameters
        [Pat ParamType], -- parameters
        ResRetType, -- return type
        Exp, -- body
        [AttrInfo VName], -- attributes
        SrcLoc
      )

-- | A size expression compared up to a looser notion of equality.
--
-- To deduplicate size expressions we need something weaker than the
-- strict syntactic equality provided by the 'Eq' instance on 'Exp'.
-- This newtype wrapper carries an 'Eq' instance (based on
-- 'similarExps') that provides such a looser notion of equality.
newtype ReplacedExp = ReplacedExp {unReplaced :: Exp}
  deriving (Show)

-- | Pretty-prints the wrapped expression as-is.
instance Pretty ReplacedExp where
  pretty = pretty . unReplaced

-- | Two wrapped expressions are equal when 'similarExps' accepts them
-- (returns 'Just') and every pair of sub-expressions it returns is in
-- turn equal under this same relation.  If 'similarExps' returns
-- 'Nothing', they are unequal.
instance Eq ReplacedExp where
  ReplacedExp lhs == ReplacedExp rhs =
    case similarExps lhs rhs of
      Just pairs -> all (\(a, b) -> ReplacedExp a == ReplacedExp b) pairs
      Nothing -> False

-- | Size expressions paired with the name of the variable that replaces
-- them (see 'expReplace').
type ExpReplacements = [(ReplacedExp, VName)]

-- | Keep only the replacements whose expression can be computed in
-- @scope@.  An expression is computable when every free variable that
-- is not an intrinsic is a member of @scope@.
canCalculate :: S.Set VName -> ExpReplacements -> ExpReplacements
canCalculate scope = filter (computable . unReplaced . fst)
  where
    computable e =
      S.filter isUserName (fvVars (freeInExp e)) `S.isSubsetOf` scope
    -- Names whose tag is at most 'maxIntrinsicTag' are intrinsics; they
    -- are ignored by the scope check.
    isUserName vn = baseTag vn > maxIntrinsicTag

-- | Replace every sub-expression found in the mapping (compared with
-- 'ReplacedExp' equality) by a reference to its replacement variable.
-- The type and location of the replaced expression are kept.  An
-- expression that is replaced is not traversed further; otherwise the
-- replacement recurses into its immediate sub-expressions.
expReplace :: ExpReplacements -> Exp -> Exp
expReplace mapping e =
  case lookup (ReplacedExp e) mapping of
    Just vn -> Var (qualName vn) (Info (typeOf e)) (srclocOf e)
    Nothing -> runIdentity (astMap replacer e)
  where
    replacer :: ASTMapper Identity
    replacer = identityMapper {mapOnExp = pure . expReplace mapping}

-- | Wrap @body@ in an 'Assert' checking that every replacement variable
-- in the mapping has the same value as the size expression it stands
-- for.  The individual equality checks are combined with @&&@.  With
-- an empty mapping the body is returned unchanged.
--
-- This is injected into entry points, where we cannot otherwise trust
-- the input.  XXX: the error message generated from this is not great;
-- we should rework it eventually.
entryAssert :: ExpReplacements -> Exp -> Exp
entryAssert [] body = body
entryAssert (x : xs) body =
  Assert (foldl logAnd (cmpExp x) $ map cmpExp xs) body errmsg (srclocOf body)
  where
    errmsg = Info "entry point arguments have invalid sizes."

    bool :: TypeBase dim u
    bool = Scalar $ Prim Bool

    -- '&&' combines two booleans.
    andType = foldFunType [bool, bool] $ RetType [] bool
    -- '==' is applied here to two sizes, which are i64 values; annotate
    -- it accordingly rather than as a boolean operator.
    eqType = foldFunType [i64, i64] $ RetType [] bool

    andop = Var (qualName (intrinsicVar "&&")) (Info andType) mempty
    eqop = Var (qualName (intrinsicVar "==")) (Info eqType) mempty

    logAnd x' y =
      mkApply andop [(Observe, Nothing, x'), (Observe, Nothing, y)] $
        AppRes bool []

    -- Compare the replaced size expression with its replacement variable.
    cmpExp (ReplacedExp x', y) =
      mkApply eqop [(Observe, Nothing, x'), (Observe, Nothing, y')] $
        AppRes bool []
      where
        y' = Var (qualName y) (Info i64) mempty

-- | The monomorphization environment.
data Env = Env
  { -- | Polymorphic function bindings, keyed by function name.
    envPolyBindings :: M.Map VName PolyBinding,
    -- | Type bindings, keyed by type name.
    envTypeBindings :: M.Map VName TypeBinding,
    -- | Local names in scope (extended by 'withArgs', cleared by
    -- 'isolateNormalisation').
    envScope :: S.Set VName,
    -- | Names in scope everywhere (presumably top-level bindings;
    -- confirm at the place where it is populated).
    envGlobalScope :: S.Set VName,
    -- | Size expressions turned into parameters (extended by
    -- 'withParams', cleared by 'isolateNormalisation').
    envParametrized :: ExpReplacements
  }

-- | Field-wise combination.  For the maps this is 'M.union', so the
-- bindings of the left operand take precedence.
instance Semigroup Env where
  Env polys1 types1 scope1 global1 params1 <> Env polys2 types2 scope2 global2 params2 =
    Env
      { envPolyBindings = polys1 <> polys2,
        envTypeBindings = types1 <> types2,
        envScope = scope1 <> scope2,
        envGlobalScope = global1 <> global2,
        envParametrized = params1 <> params2
      }

-- | The empty environment: every field is empty.
instance Monoid Env where
  mempty =
    Env
      { envPolyBindings = M.empty,
        envTypeBindings = M.empty,
        envScope = S.empty,
        envGlobalScope = S.empty,
        envParametrized = []
      }

-- | Run an action with @extra@ prepended to the current environment.
-- Because of the 'Semigroup' instance, bindings in @extra@ shadow
-- existing ones.
localEnv :: Env -> MonoM a -> MonoM a
localEnv extra = local (\current -> extra <> current)

-- | Run an action with one additional polymorphic binding in scope.
extendEnv :: VName -> PolyBinding -> MonoM a -> MonoM a
extendEnv vn binding = localEnv extra
  where
    extra = mempty {envPolyBindings = M.singleton vn binding}

-- | Run an action starting with no pending replacements, an empty local
-- scope and no parametrized sizes.  Afterwards, restore the pending
-- replacements that existed before the call.
isolateNormalisation :: MonoM a -> MonoM a
isolateNormalisation action = do
  saved <- get
  put []
  result <- local clearLocals action
  put saved
  pure result
  where
    clearLocals env = env {envScope = S.empty, envParametrized = []}

-- | Run an action with the given names added to the local scope.
withArgs :: S.Set VName -> MonoM a -> MonoM a
withArgs args action = localEnv (mempty {envScope = args}) action

-- | Run an action with the given replacements added to the parametrized
-- sizes.
withParams :: ExpReplacements -> MonoM a -> MonoM a
withParams params action = localEnv (mempty {envParametrized = params}) action

-- | The monomorphization monad.
--
-- * Reader: the 'Env'.
-- * Writer: the generated value bindings.
-- * State: the pending size replacements together with the name source.
-- * Inner 'State': the table of already-instantiated functions ('Lifts').
newtype MonoM a
  = MonoM
      ( RWST
          Env
          (Seq.Seq (VName, ValBind))
          (ExpReplacements, VNameSource)
          (State Lifts)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadWriter (Seq.Seq (VName, ValBind))
    )

-- | The name source is the second component of the 'RWST' state.
instance MonadFreshNames MonoM where
  getNameSource = MonoM (gets snd)
  putNameSource src = MonoM (modify (\ ~(repl, _) -> (repl, src)))

-- | The pending size replacements are the first component of the
-- 'RWST' state.
instance MonadState ExpReplacements MonoM where
  get = MonoM (gets fst)
  put repl = MonoM (modify (\ ~(_, src) -> (repl, src)))

-- | Run a 'MonoM' action from an empty environment, with no pending
-- replacements and no lifted functions.  Return its result, the
-- emitted bindings and the final name source.
runMonoM :: VNameSource -> MonoM a -> ((a, Seq.Seq (VName, ValBind)), VNameSource)
runMonoM src (MonoM m) =
  let (result, (_, src'), defs) = evalState (runRWST m mempty ([], src)) []
   in ((result, defs), src')

-- | Look up a polymorphic function binding by name.
lookupFun :: VName -> MonoM (Maybe PolyBinding)
lookupFun vn = asks (M.lookup vn . envPolyBindings)

-- | All names currently in scope.  These are the local scope, the global
-- scope, the names of the polymorphic functions in the environment and
-- the names of the already-lifted function instances.
askScope :: MonoM (S.Set VName)
askScope = do
  env <- ask
  lifted <- getLifts
  pure $
    S.unions
      [ envScope env,
        envGlobalScope env,
        M.keysSet (envPolyBindings env),
        S.fromList (map (fst . snd) lifted)
      ]

-- | The names in @argset@ that are newly introduced, i.e. not currently
-- in scope.  Intrinsic names are never reported.
askIntros :: S.Set VName -> MonoM (S.Set VName)
askIntros argset = do
  scope <- askScope
  pure $ S.filter isUserName argset S.\\ scope
  where
    isUserName vn = baseTag vn > maxIntrinsicTag

-- | Remove and return the pending replacements whose expression mentions
-- a name newly introduced by @argset@.  Such expressions cannot be
-- calculated once those arguments go out of scope.
--
-- This should be called without @argset@ in scope, so that
-- 'askIntros' correctly detects the introduced names.
parametrizing :: S.Set VName -> MonoM ExpReplacements
parametrizing argset = do
  intros <- askIntros argset
  pending <- get
  let mentionsIntro =
        not . S.disjoint intros . fvVars . freeInExp . unReplaced . fst
      (params, rest) = partition mentionsIntro pending
  put rest
  pure params

-- | Bind the replaced size expressions in @repl@ around @body@.
--
-- The replacements are first topologically sorted by sub-expression
-- containment ('depends').  Every occurrence of a replaced expression
-- in @body@ is then substituted by its variable.  Finally, one @let@
-- binding per replacement is wrapped around the result.
calculateDims :: Exp -> ExpReplacements -> MonoM Exp
calculateDims body repl =
  foldCalc sortedRepl (expReplace sortedRepl body)
  where
    -- The strict sub-expressions of an expression, looking through
    -- whatever 'stripExp' removes.
    subExps :: Exp -> [ReplacedExp]
    subExps e
      | Just e' <- stripExp e = subExps e'
      | otherwise = execState (astMap collector e) []
      where
        collect :: Exp -> State [ReplacedExp] Exp
        collect e'
          | Just e'' <- stripExp e' = collect e''
          | otherwise = do
              modify (ReplacedExp e' :)
              astMap collector e'
        collector :: ASTMapper (State [ReplacedExp])
        collector = identityMapper {mapOnExp = collect}

    -- @depends a b@ holds when the expression of @b@ is a strict
    -- sub-expression of the expression of @a@.
    depends (a, _) (b, _) = b `elem` subExps (unReplaced a)

    sortedRepl = topologicalSort depends repl

    -- Wrap one @let@ per replacement around the body.  The first
    -- replacement becomes the innermost binding.  Its expression has the
    -- remaining replacements substituted in, because those are bound
    -- further out.  The bound name goes out of scope after the @let@, so
    -- occurrences of it in the result type are renamed to a fresh name.
    -- That fresh name is recorded as an existential size of the 'AppRes'.
    foldCalc :: ExpReplacements -> Exp -> MonoM Exp
    foldCalc [] body' = pure body'
    foldCalc ((dim, vn) : repls) body' = do
      fresh <- newName vn
      let expr = expReplace repls (unReplaced dim)
          subst v
            | v == vn = Just $ ExpSubst $ sizeFromName (qualName fresh) mempty
            | otherwise = Nothing
          appRes = Info $ case body' of
            AppExp _ (Info (AppRes ty ext)) ->
              AppRes (applySubst subst ty) (fresh : ext)
            e -> AppRes (applySubst subst (typeOf e)) [fresh]
          binding =
            LetPat [] (Id vn (Info i64) (srclocOf expr)) expr body' mempty
      foldCalc repls (AppExp binding appRes)

-- | Take out the pending replacements that depend on @argset@ (see
-- 'parametrizing').  Keep those computable with @argset@ plus the
-- current scope, and bind them around @body@ (see 'calculateDims').
unscoping :: S.Set VName -> Exp -> MonoM Exp
unscoping argset body = do
  localDims <- parametrizing argset
  scope <- askScope
  calculateDims body $ canCalculate (argset `S.union` scope) localDims

-- | Run an action with @argset@ in scope, then bind the size
-- expressions that depend on @argset@ around its result.
scoping :: S.Set VName -> MonoM Exp -> MonoM Exp
scoping argset m = do
  e <- withArgs argset m
  unscoping argset e

-- | Given the instantiated type of a function, produce its size
-- arguments.
type InferSizeArgs = StructType -> MonoM [Exp]

-- | A size as seen by monomorphization.
data MonoSize
  = -- | The integer encodes an equivalence class, so we can keep
    -- track of sizes that are statically identical.
    MonoKnown Int
  | -- | A size that is anonymous or refers to a locally bound name.
    MonoAnon
  deriving (Show)

-- | All 'MonoAnon' sizes are treated as identical.  Known sizes are
-- equal when they belong to the same equivalence class.
instance Eq MonoSize where
  a == b = case (a, b) of
    (MonoKnown x, MonoKnown y) -> x == y
    (MonoAnon, MonoAnon) -> True
    _ -> False

-- | Anonymous sizes print as @?@; known sizes print as @?@ followed by
-- their class number.
instance Pretty MonoSize where
  pretty MonoAnon = "?"
  pretty (MonoKnown i) = "?" <> pretty i

-- | Each dimension is printed in its own pair of brackets.
instance Pretty (Shape MonoSize) where
  pretty (Shape ds) = mconcat [brackets (pretty d) | d <- ds]

-- | The kind of type relative to which we monomorphise.  What matters
-- most is not the specific dimensions, but merely whether they are
-- known or anonymous/local.
type MonoType = TypeBase MonoSize NoUniqueness

-- | Convert a type into a 'MonoType'.
--
-- * Uniqueness information is dropped ('toStruct').
-- * A size mentioning a name bound locally in the type becomes
--   'MonoAnon'.
-- * Other sizes are numbered, in traversal order starting from 0, so
--   that equal size expressions (by 'Exp''s own 'Ord') share a
--   'MonoKnown' class.
-- * Existential sizes in function return types are removed.
monoType :: TypeBase Size als -> MonoType
monoType t =
  noExts $ evalState (traverseDims onDim (toStruct t)) (0, M.empty)
  where
    -- Remove exts from return types because we don't use them anymore.
    noExts :: TypeBase MonoSize u -> TypeBase MonoSize u
    noExts (Array u shape st) = Array u shape (noExtsScalar st)
    noExts (Scalar st) = Scalar (noExtsScalar st)

    noExtsScalar :: ScalarTypeBase MonoSize u -> ScalarTypeBase MonoSize u
    noExtsScalar (Record fs) = Record (M.map noExts fs)
    noExtsScalar (Sum cs) = Sum (M.map (map noExts) cs)
    noExtsScalar (Arrow u p d t1 (RetType _ t2)) =
      Arrow u p d (noExts t1) (RetType [] (noExts t2))
    noExtsScalar st = st

    -- The state is the next free class number and the classes seen so
    -- far.
    onDim bound _ e
      | any (`S.member` bound) (fvVars (freeInExp e)) = pure MonoAnon
      | otherwise = do
          (next, seen) <- get
          case M.lookup e seen of
            Just cls -> pure (MonoKnown cls)
            Nothing -> do
              put (next + 1, M.insert e next seen)
              pure (MonoKnown next)

-- | Instances created so far.  Each entry maps a function name and the
-- 'MonoType' it was instantiated at to the name of the instance and its
-- size-argument inference function.  This lets an instantiation be
-- reused instead of generated again.
type Lifts = [((VName, MonoType), (VName, InferSizeArgs))]

-- | Read the current cache of monomorphised function instances, which is
-- kept in the 'State' layer underneath the 'RWST' of 'MonoM'.
getLifts :: MonoM Lifts
getLifts = MonoM (lift get)

-- | Apply a function to the cache of monomorphised function instances,
-- which is kept in the 'State' layer underneath the 'RWST' of 'MonoM'.
modifyLifts :: (Lifts -> Lifts) -> MonoM ()
modifyLifts f = MonoM $ lift $ modify f

-- | Record that the function @fname@, instantiated at the monomorphic type
-- @il@, has been turned into @liftf@ (the name of the monomorphic copy
-- and its size-argument inference).  The entry is placed at the front of
-- the cache.
addLifted :: VName -> MonoType -> (VName, InferSizeArgs) -> MonoM ()
addLifted fname il liftf =
  modifyLifts $ \lifts -> ((fname, il), liftf) : lifts

-- | Find a previously recorded monomorphic instance of @fname@ at the
-- type @t@, if there is one.  Because entries are prepended, the first
-- match is the most recently added one.
lookupLifted :: VName -> MonoType -> MonoM (Maybe (VName, InferSizeArgs))
lookupLifted fname t = do
  lifts <- getLifts
  pure $ lookup (fname, t) lifts

-- | Turn a size expression into a normalised size, i.e. a single variable
-- or an integer constant, introducing a replacement name when needed.
--
-- A size that is already a variable or an integer literal (possibly under
-- wrappers removed by 'stripExp', such as parentheses) is returned
-- directly.  Otherwise the expression is looked up among the parametrized
-- replacements of the environment (which take priority) and among the
-- replacements already recorded in the state.  If neither has it, a fresh
-- name is generated and recorded as a new replacement.  In all of these
-- cases the result is a variable carrying the location of the input.
replaceExp :: Exp -> MonoM Exp
replaceExp e
  | Just normalised <- maybeNormalisedSize e = pure normalised
  | otherwise = do
      let key = ReplacedExp e
      known <- gets (lookup key)
      param <- asks (lookup key . envParametrized)
      -- 'mplus' on 'Maybe' is left-biased: a parametrized name wins.
      vn <- case param `mplus` known of
        Just existing -> pure existing
        Nothing -> do
          fresh <- newNameFromString $ "d<{" ++ prettyString (bareExp e) ++ "}>"
          modify ((key, fresh) :)
          pure fresh
      pure $ sizeFromName (qualName vn) (srclocOf e)
  where
    -- Recognise sizes that are already normalised, looking through
    -- wrappers such as parentheses so they are not replaced needlessly.
    maybeNormalisedSize :: Exp -> Maybe Exp
    maybeNormalisedSize ex
      | Just inner <- stripExp ex = maybeNormalisedSize inner
    maybeNormalisedSize (Var qn _ loc) = Just $ sizeFromName qn loc
    maybeNormalisedSize (IntLit v _ loc) = Just $ IntLit v (Info i64) loc
    maybeNormalisedSize _ = Nothing

-- | Transform a reference to the name @fname@, used at type @t@.
--
-- The type has its type variables removed and its sizes normalised.
-- Names whose tag is at most 'maxIntrinsicTag' are kept as plain
-- variables.  Otherwise, if a monomorphic instance of the name at this
-- type is cached it is reused.  If the name is bound to a polymorphic
-- function, a monomorphic instance is generated, emitted with 'tell', and
-- cached with 'addLifted'.  In both of those cases the instance's inferred
-- size arguments are applied to it.  Any other name is kept as a plain
-- variable.
transformFName :: SrcLoc -> QualName VName -> StructType -> MonoM Exp
transformFName loc fname t = do
  t' <- removeTypeVariablesInType t
  t'' <- transformType t'
  let mono_t = monoType t'
      leaf = qualLeaf fname
      plainVar = var fname t''
      instantiate fname' infer =
        applySizeArgs fname' (toRes Nonunique t'') <$> infer t''
  if baseTag leaf <= maxIntrinsicTag
    then pure plainVar
    else do
      maybe_fname <- lookupLifted leaf mono_t
      maybe_funbind <- lookupFun leaf
      case (maybe_fname, maybe_funbind) of
        (Just (fname', infer), _) ->
          -- Already monomorphised at this type: reuse the cached instance.
          instantiate fname' infer
        (Nothing, Just funbind) -> do
          -- A polymorphic function not yet instantiated at this type.
          (fname', infer, funbind') <- monomorphiseBinding False funbind mono_t
          tell $ Seq.singleton (leaf, funbind')
          addLifted leaf mono_t (fname', infer)
          instantiate fname' infer
        (Nothing, Nothing) ->
          -- Neither cached nor a known function binding: keep the name.
          -- NOTE(review): intrinsics are already handled by the tag check
          -- above, so this presumably covers other names (e.g. local
          -- variables) -- confirm.
          pure plainVar
  where
    var fname' ty = Var fname' (Info ty) loc

    -- Apply one size argument; @i@ counts the i64 parameters that remain
    -- after this application, and is used to type the partial result.
    applySizeArg ret_t (i, f) size_arg =
      ( i - 1,
        mkApply
          f
          [(Observe, Nothing, size_arg)]
          (AppRes (foldFunType (replicate i i64) (RetType [] ret_t)) [])
      )

    -- Apply all size arguments, one at a time, to the function named
    -- @fname'@, whose type after the last application is @ret_t@.
    applySizeArgs fname' ret_t size_args =
      snd $ foldl' (applySizeArg ret_t) (length size_args - 1, fn) size_args
      where
        fn =
          Var
            (qualName fname')
            (Info (foldFunType (map (const i64) size_args) (RetType [] ret_t)))
            loc
size_args

-- | Normalise every size appearing in a type.
--
-- Each size expression other than 'anySize' is first transformed with
-- 'transformExp' and then turned into a variable or constant by
-- 'replaceExp'.  The return type of a function type is handled by
-- 'transformRetTypeSizes', which receives the free variables of the
-- parameter type together with the parameter's name, if it has one.
transformType :: TypeBase Size u -> MonoM (TypeBase Size u)
transformType typ =
  case typ of
    Scalar scalar ->
      Scalar <$> transformScalarSizes scalar
    Array u shape scalar -> do
      shape' <- traverse onDim shape
      scalar' <- transformScalarSizes scalar
      pure $ Array u shape' scalar'
  where
    transformScalarSizes :: ScalarTypeBase Size u -> MonoM (ScalarTypeBase Size u)
    transformScalarSizes ty =
      case ty of
        Prim {} ->
          pure ty
        Record fields ->
          Record <$> traverse transformType fields
        Sum constrs ->
          Sum <$> traverse (traverse transformType) constrs
        Arrow as argName d argT retT -> do
          -- The return type is processed before the parameter type.
          retT' <- transformRetTypeSizes (arrowArgSet argName argT) retT
          argT' <- transformType argT
          pure $ Arrow as argName d argT' retT'
        TypeVar u qn args ->
          TypeVar u qn <$> traverse onArg args

    -- Names passed to 'transformRetTypeSizes' for a function type: the free
    -- variables of the parameter type plus the parameter name, if any.
    arrowArgSet argName argT = fvVars (freeInType argT) <> nameSet argName

    nameSet Unnamed = mempty
    nameSet (Named vn) = S.singleton vn

    onArg (TypeArgDim dim) = TypeArgDim <$> onDim dim
    onArg (TypeArgType argTy) = TypeArgType <$> transformType argTy

    onDim :: Size -> MonoM Size
    onDim e
      | e == anySize = pure e
      | otherwise = do
          e' <- transformExp e
          replaceExp e'

-- | Normalise the sizes of a return type.
--
-- The type is transformed while @argset@ is registered with 'withArgs'.
-- The replacement names that 'parametrizing' then reports for @argset@
-- are appended to the size names already bound by the 'RetType'.
transformRetTypeSizes :: S.Set VName -> RetTypeBase Size as -> MonoM (RetTypeBase Size as)
transformRetTypeSizes argset (RetType dims ty) = do
  ty' <- withArgs argset (transformType ty)
  replacements <- parametrizing argset
  pure $ RetType (dims <> map snd replacements) ty'

-- | Give a fresh name (based on @"size"@) to every 'anySize' dimension in
-- the types of a pattern.
--
-- Returns the new names together with the updated pattern.  The names are
-- collected by consing, so the list is in reverse order of traversal.  The
-- new size expressions carry an empty source location.
sizesForPat :: MonadFreshNames m => Pat ParamType -> m ([VName], Pat ParamType)
sizesForPat pat = do
  (pat', fresh_sizes) <- runStateT (traverse (bitraverse onDim pure) pat) []
  pure (fresh_sizes, pat')
  where
    onDim d
      | d == anySize = do
          v <- lift (newVName "size")
          modify (v :)
          pure (sizeFromName (qualName v) mempty)
      | otherwise = pure d

-- | Normalise the sizes of the result type of an application.  The list of
-- existential size names is passed through unchanged.
transformAppRes :: AppRes -> MonoM AppRes
transformAppRes (AppRes t ext) = do
  t' <- transformType t
  pure $ AppRes t' ext

transformAppExp :: AppExp -> AppRes -> MonoM Exp
transformAppExp :: AppExpBase Info VName -> AppRes -> MonoM Exp
transformAppExp (Range Exp
e1 Maybe Exp
me Inclusiveness Exp
incl SrcLoc
loc) AppRes
res = do
  Exp
e1' <- Exp -> MonoM Exp
transformExp Exp
e1
  Maybe Exp
me' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Exp -> MonoM Exp
transformExp Maybe Exp
me
  Inclusiveness Exp
incl' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Exp -> MonoM Exp
transformExp Inclusiveness Exp
incl
  AppRes
res' <- AppRes -> MonoM AppRes
transformAppRes AppRes
res
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
ExpBase f vn
-> Maybe (ExpBase f vn)
-> Inclusiveness (ExpBase f vn)
-> SrcLoc
-> AppExpBase f vn
Range Exp
e1' Maybe Exp
me' Inclusiveness Exp
incl' SrcLoc
loc) (forall a. a -> Info a
Info AppRes
res')
transformAppExp (LetPat [SizeBinder VName]
sizes PatBase Info VName StructType
pat Exp
e Exp
body SrcLoc
loc) AppRes
res = do
  Exp
e' <- Exp -> MonoM Exp
transformExp Exp
e
  let dimArgs :: Set VName
dimArgs = forall a. Ord a => [a] -> Set a
S.fromList (forall a b. (a -> b) -> [a] -> [b]
map forall vn. SizeBinder vn -> vn
sizeName [SizeBinder VName]
sizes)
  Set VName
implicitDims <- forall a. Set VName -> MonoM a -> MonoM a
withArgs Set VName
dimArgs forall a b. (a -> b) -> a -> b
$ Set VName -> MonoM (Set VName)
askIntros forall a b. (a -> b) -> a -> b
$ FV -> Set VName
fvVars forall a b. (a -> b) -> a -> b
$ forall u. Pat (TypeBase Exp u) -> FV
freeInPat PatBase Info VName StructType
pat
  let dimArgs' :: Set VName
dimArgs' = Set VName
dimArgs forall a. Semigroup a => a -> a -> a
<> Set VName
implicitDims
      letArgs :: Set VName
letArgs = forall a. Ord a => [a] -> Set a
S.fromList forall a b. (a -> b) -> a -> b
$ forall t. Pat t -> [VName]
patNames PatBase Info VName StructType
pat
      argset :: Set VName
argset = Set VName
dimArgs' forall a. Ord a => Set a -> Set a -> Set a
`S.union` Set VName
letArgs
  PatBase Info VName StructType
pat' <- forall a. Set VName -> MonoM a -> MonoM a
withArgs Set VName
dimArgs' forall a b. (a -> b) -> a -> b
$ forall u. Pat (TypeBase Exp u) -> MonoM (Pat (TypeBase Exp u))
transformPat PatBase Info VName StructType
pat
  ExpReplacements
params <- Set VName -> MonoM ExpReplacements
parametrizing Set VName
dimArgs'
  let sizes' :: [SizeBinder VName]
sizes' = [SizeBinder VName]
sizes forall a. Semigroup a => a -> a -> a
<> forall a b. (a -> b) -> [a] -> [b]
map (forall vn. vn -> SrcLoc -> SizeBinder vn
`SizeBinder` forall a. Monoid a => a
mempty) (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd ExpReplacements
params forall a. Semigroup a => a -> a -> a
<> forall a. Set a -> [a]
S.toList Set VName
implicitDims)
  Exp
body' <- forall a. ExpReplacements -> MonoM a -> MonoM a
withParams ExpReplacements
params forall a b. (a -> b) -> a -> b
$ Set VName -> MonoM Exp -> MonoM Exp
scoping Set VName
argset forall a b. (a -> b) -> a -> b
$ Exp -> MonoM Exp
transformExp Exp
body
  AppRes
res' <- AppRes -> MonoM AppRes
transformAppRes AppRes
res
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
[SizeBinder vn]
-> PatBase f vn StructType
-> ExpBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
LetPat [SizeBinder VName]
sizes' PatBase Info VName StructType
pat' Exp
e' Exp
body' SrcLoc
loc) (forall a. a -> Info a
Info AppRes
res')
transformAppExp (LetFun VName
fname ([TypeParamBase VName]
tparams, [Pat ParamType]
params, Maybe (TypeExp Info VName)
retdecl, Info ResRetType
ret, Exp
body) Exp
e SrcLoc
loc) AppRes
res
  | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [TypeParamBase VName]
tparams = do
      -- Retrieve the lifted monomorphic function bindings that are
      -- produced, filter those that are monomorphic versions of the
      -- current let-bound function and insert them at this point, and
      -- propagate the rest.
      let funbind :: PolyBinding
funbind = (VName, [TypeParamBase VName], [Pat ParamType], ResRetType, Exp,
 [AttrInfo VName], SrcLoc)
-> PolyBinding
PolyBinding (VName
fname, [TypeParamBase VName]
tparams, [Pat ParamType]
params, ResRetType
ret, Exp
body, forall a. Monoid a => a
mempty, SrcLoc
loc)
      forall w (m :: * -> *) a. MonadWriter w m => m (a, w -> w) -> m a
pass forall a b. (a -> b) -> a -> b
$ do
        (Exp
e', Seq (VName, ValBind)
bs) <- forall w (m :: * -> *) a. MonadWriter w m => m a -> m (a, w)
listen forall a b. (a -> b) -> a -> b
$ forall a. VName -> PolyBinding -> MonoM a -> MonoM a
extendEnv VName
fname PolyBinding
funbind forall a b. (a -> b) -> a -> b
$ Set VName -> MonoM Exp -> MonoM Exp
scoping (forall a. a -> Set a
S.singleton VName
fname) forall a b. (a -> b) -> a -> b
$ Exp -> MonoM Exp
transformExp Exp
e
        -- Do not remember this one for next time we monomorphise this
        -- function.
        (Lifts -> Lifts) -> MonoM ()
modifyLifts forall a b. (a -> b) -> a -> b
$ forall a. (a -> Bool) -> [a] -> [a]
filter ((forall a. Eq a => a -> a -> Bool
/= VName
fname) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst)
        let (Seq (VName, ValBind)
bs_local, Seq (VName, ValBind)
bs_prop) = forall a. (a -> Bool) -> Seq a -> (Seq a, Seq a)
Seq.partition ((forall a. Eq a => a -> a -> Bool
== VName
fname) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) Seq (VName, ValBind)
bs
        forall (f :: * -> *) a. Applicative f => a -> f a
pure ([ValBind] -> Exp -> Exp
unfoldLetFuns (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> [a]
toList Seq (VName, ValBind)
bs_local) Exp
e', forall a b. a -> b -> a
const Seq (VName, ValBind)
bs_prop)
  | Bool
otherwise = do
      Exp
body' <- Set VName -> MonoM Exp -> MonoM Exp
scoping (forall a. Ord a => [a] -> Set a
S.fromList (forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap forall t. Pat t -> [VName]
patNames [Pat ParamType]
params)) forall a b. (a -> b) -> a -> b
$ Exp -> MonoM Exp
transformExp Exp
body
      ResRetType
ret' <- forall as.
Set VName -> RetTypeBase Exp as -> MonoM (RetTypeBase Exp as)
transformRetTypeSizes (forall a. Ord a => [a] -> Set a
S.fromList (forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap forall t. Pat t -> [VName]
patNames [Pat ParamType]
params)) ResRetType
ret
      forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ( forall (f :: * -> *) vn.
vn
-> ([TypeParamBase vn], [PatBase f vn ParamType],
    Maybe (TypeExp f vn), f ResRetType, ExpBase f vn)
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
LetFun VName
fname ([TypeParamBase VName]
tparams, [Pat ParamType]
params, Maybe (TypeExp Info VName)
retdecl, forall a. a -> Info a
Info ResRetType
ret', Exp
body')
                forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Set VName -> MonoM Exp -> MonoM Exp
scoping (forall a. a -> Set a
S.singleton VName
fname) (Exp -> MonoM Exp
transformExp Exp
e)
                forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
            )
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (forall a. a -> Info a
Info forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AppRes -> MonoM AppRes
transformAppRes AppRes
res)
transformAppExp (If Exp
e1 Exp
e2 Exp
e3 SrcLoc
loc) AppRes
res =
  forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn -> ExpBase f vn -> SrcLoc -> AppExpBase f vn
If forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> MonoM Exp
transformExp Exp
e1 forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> MonoM Exp
transformExp Exp
e2 forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> MonoM Exp
transformExp Exp
e3 forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc) forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (forall a. a -> Info a
Info forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AppRes -> MonoM AppRes
transformAppRes AppRes
res)
transformAppExp (Apply Exp
fe NonEmpty (Info (Diet, Maybe VName), Exp)
args SrcLoc
_) AppRes
res =
  forall vn.
ExpBase Info vn
-> [(Diet, Maybe VName, ExpBase Info vn)]
-> AppRes
-> ExpBase Info vn
mkApply
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> MonoM Exp
transformExp Exp
fe
    forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {t} {t}. (Info (t, t), Exp) -> MonoM (t, t, Exp)
onArg (forall a. NonEmpty a -> [a]
NE.toList NonEmpty (Info (Diet, Maybe VName), Exp)
args)
    forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> AppRes -> MonoM AppRes
transformAppRes AppRes
res
  where
    onArg :: (Info (t, t), Exp) -> MonoM (t, t, Exp)
onArg (Info (t
d, t
ext), Exp
e) = (t
d,t
ext,) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> MonoM Exp
transformExp Exp
e
transformAppExp (DoLoop [VName]
sparams Pat ParamType
pat Exp
e1 LoopFormBase Info VName
form Exp
body SrcLoc
loc) AppRes
res = do
  Exp
e1' <- Exp -> MonoM Exp
transformExp Exp
e1

  let dimArgs :: Set VName
dimArgs = forall a. Ord a => [a] -> Set a
S.fromList [VName]
sparams
  Pat ParamType
pat' <- forall a. Set VName -> MonoM a -> MonoM a
withArgs Set VName
dimArgs forall a b. (a -> b) -> a -> b
$ forall u. Pat (TypeBase Exp u) -> MonoM (Pat (TypeBase Exp u))
transformPat Pat ParamType
pat
  ExpReplacements
params <- Set VName -> MonoM ExpReplacements
parametrizing Set VName
dimArgs
  let sparams' :: [VName]
sparams' = [VName]
sparams forall a. Semigroup a => a -> a -> a
<> forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd ExpReplacements
params
      mergeArgs :: Set VName
mergeArgs = Set VName
dimArgs forall a. Ord a => Set a -> Set a -> Set a
`S.union` forall a. Ord a => [a] -> Set a
S.fromList (forall t. Pat t -> [VName]
patNames Pat ParamType
pat)

  (LoopFormBase Info VName
form', Set VName
formArgs) <- case LoopFormBase Info VName
form of
    For IdentBase Info VName StructType
ident Exp
e2 -> (,forall a. a -> Set a
S.singleton forall a b. (a -> b) -> a -> b
$ forall {k} (f :: k -> *) vn (t :: k). IdentBase f vn t -> vn
identName IdentBase Info VName StructType
ident) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) vn.
IdentBase f vn StructType -> ExpBase f vn -> LoopFormBase f vn
For IdentBase Info VName StructType
ident forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> MonoM Exp
transformExp Exp
e2
    ForIn PatBase Info VName StructType
pat2 Exp
e2 -> do
      PatBase Info VName StructType
pat2' <- forall u. Pat (TypeBase Exp u) -> MonoM (Pat (TypeBase Exp u))
transformPat PatBase Info VName StructType
pat2
      (,forall a. Ord a => [a] -> Set a
S.fromList (forall t. Pat t -> [VName]
patNames PatBase Info VName StructType
pat2)) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) vn.
PatBase f vn StructType -> ExpBase f vn -> LoopFormBase f vn
ForIn PatBase Info VName StructType
pat2' forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> MonoM Exp
transformExp Exp
e2
    While Exp
e2 ->
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((,forall a. Monoid a => a
mempty) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) vn. ExpBase f vn -> LoopFormBase f vn
While) forall a b. (a -> b) -> a -> b
$
        forall a. ExpReplacements -> MonoM a -> MonoM a
withParams ExpReplacements
params forall a b. (a -> b) -> a -> b
$
          Set VName -> MonoM Exp -> MonoM Exp
scoping Set VName
mergeArgs forall a b. (a -> b) -> a -> b
$
            Exp -> MonoM Exp
transformExp Exp
e2
  let argset :: Set VName
argset = Set VName
mergeArgs forall a. Ord a => Set a -> Set a -> Set a
`S.union` Set VName
formArgs

  Exp
body' <- forall a. ExpReplacements -> MonoM a -> MonoM a
withParams ExpReplacements
params forall a b. (a -> b) -> a -> b
$ Set VName -> MonoM Exp -> MonoM Exp
scoping Set VName
argset forall a b. (a -> b) -> a -> b
$ Exp -> MonoM Exp
transformExp Exp
body
  -- Maybe monomorphisation introduced new arrays to the loop, and
  -- maybe they have AnySize sizes.  This is not allowed.  Invent some
  -- sizes for them.
  ([VName]
pat_sizes, Pat ParamType
pat'') <- forall (m :: * -> *).
MonadFreshNames m =>
Pat ParamType -> m ([VName], Pat ParamType)
sizesForPat Pat ParamType
pat'
  AppRes
res' <- AppRes -> MonoM AppRes
transformAppRes AppRes
res
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
[VName]
-> PatBase f vn ParamType
-> ExpBase f vn
-> LoopFormBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
DoLoop ([VName]
sparams' forall a. [a] -> [a] -> [a]
++ [VName]
pat_sizes) Pat ParamType
pat'' Exp
e1' LoopFormBase Info VName
form' Exp
body' SrcLoc
loc) (forall a. a -> Info a
Info AppRes
res')
-- Binary operators are rewritten into curried applications of the
-- (monomorphised) operator function.  The left operand is passed with
-- the first diet/extension annotation, the right with the second.
transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do
  AppRes ret ext <- transformAppRes res
  op <- transformFName loc fname (toStruct t)
  lhs <- transformExp e1
  rhs <- transformExp e2
  if orderZero (typeOf lhs) && orderZero (typeOf rhs)
    then pure $ applyOp ret ext op lhs rhs
    else do
      -- Operator application evaluates its operands left-to-right,
      -- while function application is outside-in.  This matters when
      -- the arguments produce existential sizes, so both operands are
      -- first bound to fresh variables (left operand outermost).  There
      -- are later places in the compiler where we transform BinOp to
      -- Apply, but anything that involves existential sizes will
      -- necessarily go through here.
      (lhs_var, lhs_pat) <- bindOperand lhs
      (rhs_var, rhs_pat) <- bindOperand rhs
      -- XXX: the type annotations here are wrong, but hopefully it
      -- doesn't matter as there will be an outer AppExp to handle
      -- them.
      let res_info = Info (AppRes ret mempty)
          call = applyOp ret ext op lhs_var rhs_var
          inner = AppExp (LetPat [] rhs_pat rhs call loc) res_info
      pure $ AppExp (LetPat [] lhs_pat lhs inner mempty) res_info
  where
    -- Apply the operator to the left operand (with an extension-free
    -- result), then apply that to the right operand with the full
    -- result type and existential sizes.
    applyOp ret ext f a b =
      let partial = mkApply f [(Observe, unInfo d1, a)] (AppRes ret mempty)
       in mkApply partial [(Observe, unInfo d2, b)] (AppRes ret ext)

    -- Invent a fresh variable with the operand's type, returning both a
    -- reference to it and a pattern binding it.
    bindOperand arg = do
      v <- newNameFromString "binop_p"
      let arg_t = Info (typeOf arg)
      pure (Var (qualName v) arg_t mempty, Id v arg_t mempty)
-- In-place update binding: transform both identifiers' types, the
-- slice, the new value, and the body.  The body is transformed with the
-- newly bound name of the first identifier in scope.
transformAppExp (LetWith id1 id2 idxs e1 body loc) res = do
  bound_id <- monoIdent id1
  other_id <- monoIdent id2
  idxs' <- traverse transformDimIndex idxs
  val <- transformExp e1
  body' <- scoping (S.singleton (identName bound_id)) (transformExp body)
  res' <- transformAppRes res
  pure $ AppExp (LetWith bound_id other_id idxs' val body' loc) (Info res')
  where
    -- Only the type annotation of an identifier needs transforming.
    monoIdent (Ident v vt vloc) = do
      vt' <- traverse transformType vt
      pure $ Ident v vt' vloc
-- Indexing: transform the indexed expression, then every slice
-- component, then the result annotation (in that order).
transformAppExp (Index e0 idxs loc) res = do
  arr <- transformExp e0
  slice <- traverse transformDimIndex idxs
  res' <- transformAppRes res
  pure $ AppExp (Index arr slice loc) (Info res')
-- Pattern matching.  Sizes free in the type of the scrutinee that
-- 'askIntros' reports are treated as implicitly bound by every case.
-- When there are any, the scrutinee is first bound to a fresh variable
-- in a 'LetPat' that also binds those sizes, and the match is performed
-- on that variable.
transformAppExp (Match e cs loc) res = do
  implicitDims <- askIntros $ fvVars $ freeInType $ typeOf e
  e' <- transformExp e
  cs' <- mapM (transformCase implicitDims) cs
  res' <- transformAppRes res
  if S.null implicitDims
    then pure $ AppExp (Match e' cs' loc) (Info res')
    else do
      tmpVar <- newNameFromString "matched_variable"
      let e_t = Info $ typeOf e'
      pure $
        AppExp
          ( LetPat
              (map (`SizeBinder` mempty) $ S.toList implicitDims)
              (Id tmpVar e_t mempty)
              e'
              ( AppExp
                  (Match (Var (qualName tmpVar) e_t mempty) cs' loc)
                  -- Use the transformed result annotation here too, as in
                  -- the other branch and the enclosing AppExp; the
                  -- untransformed 'res' may still contain size
                  -- expressions that this pass is meant to eliminate.
                  (Info res')
              )
              mempty
          )
          (Info res')

-- | Monomorphization of expressions.
--
-- Most constructors are rebuilt structurally, transforming every
-- subexpression, type annotation and slice from left to right.  The
-- special cases are:
--
-- * variables, resolved through 'transformFName';
--
-- * 'AppExp', delegated to 'transformAppExp';
--
-- * lambdas, whose bodies are transformed with the parameter names and
--   the sizes reported by 'askIntros' in scope;
--
-- * implicit record fields, expanded into explicit ones;
--
-- * operator, projection and index sections, handed to the
--   corresponding desugaring functions.
transformExp :: Exp -> MonoM Exp
transformExp e@Literal {} = pure e
transformExp e@IntLit {} = pure e
transformExp e@FloatLit {} = pure e
transformExp e@StringLit {} = pure e
transformExp (Parens e loc) = do
  e' <- transformExp e
  pure $ Parens e' loc
transformExp (QualParens qn e loc) = do
  e' <- transformExp e
  pure $ QualParens qn e' loc
transformExp (TupLit es loc) = do
  es' <- traverse transformExp es
  pure $ TupLit es' loc
transformExp (RecordLit fs loc) = do
  fs' <- traverse monoField fs
  pure $ RecordLit fs' loc
  where
    monoField (RecordFieldExplicit name fe floc) = do
      fe' <- transformExp fe
      pure $ RecordFieldExplicit name fe' floc
    -- An implicit field is rewritten into an explicit field whose value
    -- is a variable of the same name, and then transformed as such.
    monoField (RecordFieldImplicit v vt _) = do
      vt' <- traverse transformType vt
      monoField $
        RecordFieldExplicit (baseName v) (Var (qualName v) vt' loc) loc
transformExp (ArrayLit es t loc) = do
  es' <- traverse transformExp es
  t' <- traverse transformType t
  pure $ ArrayLit es' t' loc
transformExp (AppExp e (Info res)) = transformAppExp e res
transformExp (Var fname (Info t) loc) = transformFName loc fname (toStruct t)
transformExp (Hole t loc) = do
  t' <- traverse transformType t
  pure $ Hole t' loc
transformExp (Ascript e tp loc) = do
  e' <- transformExp e
  pure $ Ascript e' tp loc
transformExp (Coerce e te t loc) = do
  e' <- transformExp e
  t' <- traverse transformType t
  pure $ Coerce e' te t' loc
transformExp (Negate e loc) = do
  e' <- transformExp e
  pure $ Negate e' loc
transformExp (Not e loc) = do
  e' <- transformExp e
  pure $ Not e' loc
transformExp (Lambda params e0 decl tp loc) = do
  let bound = S.fromList $ concatMap patNames params
  implicit <- withArgs bound $ askIntros $ foldMap (fvVars . freeInPat) params
  let scope = implicit <> bound
  params' <- traverse transformPat params
  replacements <- parametrizing scope
  body' <- withParams replacements $ scoping scope $ transformExp e0
  tp' <- traverse transformRetType tp
  pure $ Lambda params' body' decl tp' loc
transformExp (OpSection qn t loc) = transformExp $ Var qn t loc
transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do
  let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg
  e' <- transformExp e
  desugarBinOpSection
    fname
    (Just e')
    Nothing
    t
    (xp, xtype, xargext)
    (yp, ytype, Nothing)
    (rettype, retext)
    loc
transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do
  let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg
  e' <- transformExp e
  desugarBinOpSection
    fname
    Nothing
    (Just e')
    t
    (xp, xtype, Nothing)
    (yp, ytype, yargext)
    (rettype, [])
    loc
transformExp (ProjectSection fields (Info t) loc) = do
  t' <- transformType t
  desugarProjectSection fields t' loc
transformExp (IndexSection idxs (Info t) loc) = do
  idxs' <- traverse transformDimIndex idxs
  desugarIndexSection idxs' t loc
transformExp (Project n e tp loc) = do
  tp' <- traverse transformType tp
  e' <- transformExp e
  pure $ Project n e' tp' loc
transformExp (Update e1 idxs e2 loc) = do
  e1' <- transformExp e1
  idxs' <- traverse transformDimIndex idxs
  e2' <- transformExp e2
  pure $ Update e1' idxs' e2' loc
transformExp (RecordUpdate e1 fs e2 t loc) = do
  e1' <- transformExp e1
  e2' <- transformExp e2
  t' <- traverse transformType t
  pure $ RecordUpdate e1' fs e2' t' loc
transformExp (Assert e1 e2 desc loc) = do
  e1' <- transformExp e1
  e2' <- transformExp e2
  pure $ Assert e1' e2' desc loc
transformExp (Constr name all_es t loc) = do
  es' <- traverse transformExp all_es
  t' <- traverse transformType t
  pure $ Constr name es' t' loc
transformExp (Attr info e loc) = do
  e' <- transformExp e
  pure $ Attr info e' loc

-- | Monomorphise a single match case.  The body is transformed with the
-- names bound by the (original) pattern, together with the given
-- implicitly bound sizes, in scope.
transformCase :: S.Set VName -> Case -> MonoM Case
transformCase implicitDims (CasePat p e loc) = do
  p' <- transformPat p
  let inScope = S.fromList (patNames p) <> implicitDims
  e' <- scoping inScope (transformExp e)
  pure $ CasePat p' e' loc

-- | Monomorphise one component of a slice: a fixed index, or a slice
-- whose three optional bounds are transformed in order.
transformDimIndex :: DimIndexBase Info VName -> MonoM (DimIndexBase Info VName)
transformDimIndex (DimFix e) = do
  e' <- transformExp e
  pure $ DimFix e'
transformDimIndex (DimSlice opt1 opt2 opt3) = do
  opt1' <- onOptional opt1
  opt2' <- onOptional opt2
  opt3' <- onOptional opt3
  pure $ DimSlice opt1' opt2' opt3'
  where
    onOptional = traverse transformExp

-- Transform an operator section into a lambda.
desugarBinOpSection ::
  QualName VName ->
  Maybe Exp ->
  Maybe Exp ->
  StructType ->
  (PName, ParamType, Maybe VName) ->
  (PName, ParamType, Maybe VName) ->
  (ResRetType, [VName]) ->
  SrcLoc ->
  MonoM Exp
desugarBinOpSection :: QualName VName
-> Maybe Exp
-> Maybe Exp
-> StructType
-> (PName, ParamType, Maybe VName)
-> (PName, ParamType, Maybe VName)
-> (ResRetType, [VName])
-> SrcLoc
-> MonoM Exp
desugarBinOpSection QualName VName
fname Maybe Exp
e_left Maybe Exp
e_right StructType
t (PName
xp, ParamType
xtype, Maybe VName
xext) (PName
yp, ParamType
ytype, Maybe VName
yext) (RetType [VName]
dims TypeBase Exp Uniqueness
rettype, [VName]
retext) SrcLoc
loc = do
  StructType
t' <- forall u. TypeBase Exp u -> MonoM (TypeBase Exp u)
transformType StructType
t
  Exp
op <- SrcLoc -> QualName VName -> StructType -> MonoM Exp
transformFName SrcLoc
loc QualName VName
fname forall a b. (a -> b) -> a -> b
$ forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct StructType
t
  (VName
v1, Exp -> Exp
wrap_left, Exp
e1, [Pat ParamType]
p1) <- forall {m :: * -> *} {u}.
MonadFreshNames m =>
Maybe Exp
-> TypeBase Exp u
-> m (VName, Exp -> Exp, Exp,
      [PatBase Info VName (TypeBase Exp u)])
makeVarParam Maybe Exp
e_left forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall u. TypeBase Exp u -> MonoM (TypeBase Exp u)
transformType ParamType
xtype
  (VName
v2, Exp -> Exp
wrap_right, Exp
e2, [Pat ParamType]
p2) <- forall {m :: * -> *} {u}.
MonadFreshNames m =>
Maybe Exp
-> TypeBase Exp u
-> m (VName, Exp -> Exp, Exp,
      [PatBase Info VName (TypeBase Exp u)])
makeVarParam Maybe Exp
e_right forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall u. TypeBase Exp u -> MonoM (TypeBase Exp u)
transformType ParamType
ytype
  let apply_left :: Exp
apply_left =
        forall vn.
ExpBase Info vn
-> [(Diet, Maybe VName, ExpBase Info vn)]
-> AppRes
-> ExpBase Info vn
mkApply
          Exp
op
          [(Diet
Observe, Maybe VName
xext, Exp
e1)]
          (StructType -> [VName] -> AppRes
AppRes (forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar forall a b. (a -> b) -> a -> b
$ forall dim u.
u
-> PName
-> Diet
-> TypeBase dim NoUniqueness
-> RetTypeBase dim Uniqueness
-> ScalarTypeBase dim u
Arrow forall a. Monoid a => a
mempty PName
yp (forall shape. TypeBase shape Diet -> Diet
diet ParamType
ytype) (forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct ParamType
ytype) (forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] forall a b. (a -> b) -> a -> b
$ forall u. Uniqueness -> TypeBase Exp u -> TypeBase Exp Uniqueness
toRes Uniqueness
Nonunique StructType
t')) [])
      onDim :: ExpBase f VName -> ExpBase f VName
onDim (Var QualName VName
d f StructType
typ SrcLoc
_)
        | Named VName
p <- PName
xp, forall vn. QualName vn -> vn
qualLeaf QualName VName
d forall a. Eq a => a -> a -> Bool
== VName
p = forall (f :: * -> *) vn.
QualName vn -> f StructType -> SrcLoc -> ExpBase f vn
Var (forall v. v -> QualName v
qualName VName
v1) f StructType
typ SrcLoc
loc
        | Named VName
p <- PName
yp, forall vn. QualName vn -> vn
qualLeaf QualName VName
d forall a. Eq a => a -> a -> Bool
== VName
p = forall (f :: * -> *) vn.
QualName vn -> f StructType -> SrcLoc -> ExpBase f vn
Var (forall v. v -> QualName v
qualName VName
v2) f StructType
typ SrcLoc
loc
      onDim ExpBase f VName
d = ExpBase f VName
d
      rettype' :: TypeBase Exp Uniqueness
rettype' = forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first forall {f :: * -> *}. ExpBase f VName -> ExpBase f VName
onDim TypeBase Exp Uniqueness
rettype
  Exp
body <-
    Set VName -> MonoM Exp -> MonoM Exp
scoping (forall a. Ord a => [a] -> Set a
S.fromList [VName
v1, VName
v2]) forall a b. (a -> b) -> a -> b
$
      forall vn.
ExpBase Info vn
-> [(Diet, Maybe VName, ExpBase Info vn)]
-> AppRes
-> ExpBase Info vn
mkApply Exp
apply_left [(Diet
Observe, Maybe VName
yext, Exp
e2)]
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AppRes -> MonoM AppRes
transformAppRes (StructType -> [VName] -> AppRes
AppRes (forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct TypeBase Exp Uniqueness
rettype') [VName]
retext)
  ResRetType
rettype'' <- forall as.
Set VName -> RetTypeBase Exp as -> MonoM (RetTypeBase Exp as)
transformRetTypeSizes (forall a. Ord a => [a] -> Set a
S.fromList [VName
v1, VName
v2]) forall a b. (a -> b) -> a -> b
$ forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [VName]
dims TypeBase Exp Uniqueness
rettype'
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. Exp -> Exp
wrap_left forall b c a. (b -> c) -> (a -> b) -> a -> c
. Exp -> Exp
wrap_right forall a b. (a -> b) -> a -> b
$
    forall (f :: * -> *) vn.
[PatBase f vn ParamType]
-> ExpBase f vn
-> Maybe (TypeExp f vn)
-> f ResRetType
-> SrcLoc
-> ExpBase f vn
Lambda ([Pat ParamType]
p1 forall a. [a] -> [a] -> [a]
++ [Pat ParamType]
p2) Exp
body forall a. Maybe a
Nothing (forall a. a -> Info a
Info ResRetType
rettype'') SrcLoc
loc
  where
    patAndVar :: TypeBase Exp u
-> m (VName, PatBase Info VName (TypeBase Exp u), Exp)
patAndVar TypeBase Exp u
argtype = do
      VName
x <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newNameFromString [Char]
"x"
      forall (f :: * -> *) a. Applicative f => a -> f a
pure
        ( VName
x,
          forall (f :: * -> *) vn t. vn -> f t -> SrcLoc -> PatBase f vn t
Id VName
x (forall a. a -> Info a
Info TypeBase Exp u
argtype) forall a. Monoid a => a
mempty,
          forall (f :: * -> *) vn.
QualName vn -> f StructType -> SrcLoc -> ExpBase f vn
Var (forall v. v -> QualName v
qualName VName
x) (forall a. a -> Info a
Info (forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct TypeBase Exp u
argtype)) forall a. Monoid a => a
mempty
        )

    makeVarParam :: Maybe Exp
-> TypeBase Exp u
-> m (VName, Exp -> Exp, Exp,
      [PatBase Info VName (TypeBase Exp u)])
makeVarParam (Just Exp
e) TypeBase Exp u
argtype = do
      (VName
v, PatBase Info VName (TypeBase Exp u)
pat, Exp
var_e) <- forall {m :: * -> *} {u}.
MonadFreshNames m =>
TypeBase Exp u
-> m (VName, PatBase Info VName (TypeBase Exp u), Exp)
patAndVar TypeBase Exp u
argtype
      let wrap :: Exp -> Exp
wrap Exp
body =
            forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
[SizeBinder vn]
-> PatBase f vn StructType
-> ExpBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
LetPat [] (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct PatBase Info VName (TypeBase Exp u)
pat) Exp
e Exp
body forall a. Monoid a => a
mempty) (forall a. a -> Info a
Info forall a b. (a -> b) -> a -> b
$ StructType -> [VName] -> AppRes
AppRes (Exp -> StructType
typeOf Exp
body) forall a. Monoid a => a
mempty)
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
v, Exp -> Exp
wrap, Exp
var_e, [])
    makeVarParam Maybe Exp
Nothing TypeBase Exp u
argtype = do
      (VName
v, PatBase Info VName (TypeBase Exp u)
pat, Exp
var_e) <- forall {m :: * -> *} {u}.
MonadFreshNames m =>
TypeBase Exp u
-> m (VName, PatBase Info VName (TypeBase Exp u), Exp)
patAndVar TypeBase Exp u
argtype
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
v, forall a. a -> a
id, Exp
var_e, [PatBase Info VName (TypeBase Exp u)
pat])

-- | Desugar a projection section into an explicit single-parameter lambda
-- that projects the given fields, in order, out of its parameter.
--
-- The given type must be the function type of the section: its parameter
-- type becomes the lambda's parameter type, and its return type (including
-- existential sizes) becomes the lambda's return type.  Any other type, or
-- a field that is missing from a record type along the way, is reported as
-- an internal error.
desugarProjectSection :: [Name] -> StructType -> SrcLoc -> MonoM Exp
desugarProjectSection fields (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do
  p <- newVName "project_p"
  let body = foldl' project (Var (qualName p) (Info t1) mempty) fields
  pure $
    Lambda
      [Id p (Info $ toParam Observe t1) mempty]
      body
      Nothing
      (Info (RetType dims t2))
      loc
  where
    -- Project a single field out of a record-typed expression.
    project :: Exp -> Name -> Exp
    project e field =
      case typeOf e of
        Scalar (Record fs)
          | Just t <- M.lookup field fs ->
              Project field e (Info t) mempty
        t ->
          error $
            "desugarProjectSection: type "
              ++ prettyString t
              ++ " does not have field "
              ++ prettyString field
desugarProjectSection _ t _ =
  error $ "desugarProjectSection: not a function type: " ++ prettyString t

-- | Desugar an index section into an explicit single-parameter lambda
-- whose body indexes that parameter with the given slice.
--
-- The given type must be the function type of the section.  Its parameter
-- and result types are passed through 'transformType' and become the
-- lambda's parameter type and return type.  Its existential sizes are kept
-- on the lambda's return type.  Any non-function type is an internal error.
desugarIndexSection :: [DimIndex] -> StructType -> SrcLoc -> MonoM Exp
desugarIndexSection idxs (Scalar (Arrow _ _ _ arg_t (RetType ret_dims ret_t))) loc = do
  arr <- newVName "index_i"
  arg_t' <- transformType arg_t
  ret_t' <- transformType ret_t
  let arr_e = Var (qualName arr) (Info arg_t') loc
      indexed = AppExp (Index arr_e idxs loc) (Info $ AppRes (toStruct ret_t') [])
      param = Id arr (Info $ toParam Observe arg_t') mempty
  pure $ Lambda [param] indexed Nothing (Info $ RetType ret_dims ret_t') loc
desugarIndexSection _ t _ =
  error $ "desugarIndexSection: not a function type: " ++ prettyString t

-- | Convert a collection of 'ValBind's to a nested sequence of let-bound,
-- monomorphic functions with the given expression at the bottom.
--
-- The first binding in the list becomes the outermost 'LetFun'.  Each
-- 'LetFun' node is annotated with the type of the expression it wraps.
unfoldLetFuns :: [ValBind] -> Exp -> Exp
unfoldLetFuns vbs e = foldr wrapLetFun e vbs
  where
    wrapLetFun (ValBind _ fname _ (Info rettype) dim_params params body _ _ loc) inner =
      AppExp
        (LetFun fname (dim_params, params, Nothing, Info rettype, body) inner loc)
        (Info $ AppRes (typeOf inner) mempty)

-- | Run 'transformType' over every type stored in a pattern.
transformPat :: Pat (TypeBase Size u) -> MonoM (Pat (TypeBase Size u))
transformPat = mapM transformType

-- | A mapping from (size) variable names to the size expressions they are
-- instantiated with; 'dimMapping' builds one.
type DimInst = M.Map VName Size

-- | Walk two types in parallel (via 'matchDims') and record, for each
-- variable size in the first type, the size expression found at the same
-- position in the second type.
--
-- A mapping is only recorded when the second expression mentions none of
-- the names bound at that position.  A variable that names a replaced
-- expression in @r1@ (first type) or @r2@ (second type) is additionally
-- matched through that replaced expression.  Structurally similar
-- expressions (see 'similarExps') are matched component-wise.
dimMapping ::
  (Monoid a) =>
  TypeBase Size a ->
  TypeBase Size a ->
  ExpReplacements ->
  ExpReplacements ->
  DimInst
dimMapping t1 t2 r1 r2 = execState (matchDims onDims t1 t2) M.empty
  where
    -- The replacements pair each replaced expression with its name; flip
    -- them so that they can be looked up by name.
    byName1 = map (\(rexp, name) -> (name, rexp)) r1
    byName2 = map (\(rexp, name) -> (name, rexp)) r2

    onDims bound e1 e2 = e1 <$ onExps bound e1 e2

    onExps bound lhs rhs =
      case (lhs, rhs) of
        (Var v _ _, _) -> do
          let name = qualLeaf v
          when (all (`notElem` bound) (fvVars (freeInExp rhs))) $
            modify (M.insert name rhs)
          forM_ (lookup name byName1) $ \rexp ->
            onExps bound (unReplaced rexp) rhs
        (_, Var v _ _)
          | Just rexp <- lookup (qualLeaf v) byName2 ->
              onExps bound lhs (unReplaced rexp)
        _
          | Just pairs <- similarExps lhs rhs ->
              mapM_ (uncurry (onExps bound)) pairs
          | otherwise -> pure ()

-- | Compute the size arguments for the type parameters @tparams@ of a
-- binding of type @bind_t@ (with replacements @bind_r@) when it is used at
-- type @t@.
--
-- Each parameter is looked up in the mapping that 'dimMapping' produces
-- against the current replacement state.  A size that is found is passed
-- through 'replaceExp'.  A parameter with no mapping defaults to the
-- size 0.
inferSizeArgs :: [TypeParam] -> StructType -> ExpReplacements -> StructType -> MonoM [Exp]
inferSizeArgs tparams bind_t bind_r t = do
  dinst <- gets $ dimMapping bind_t t bind_r
  forM tparams $ \tp ->
    maybe (pure $ sizeFromInteger 0 mempty) replaceExp $
      M.lookup (typeParamName tp) dinst

-- | Erase all parameter names from the function types inside a 'MonoType'.
--
-- Monomorphising higher-order functions can result in function types in
-- which the same named parameter occurs in multiple spots.  We no longer
-- need those names after monomorphisation.  The defunctionaliser can be
-- confused by duplicates, because it does not handle shadowing, so every
-- parameter name is replaced by 'Unnamed' here.  This is safe because a
-- 'MonoType' does not contain sizes anyway.
noNamedParams :: MonoType -> MonoType
noNamedParams = onType
  where
    onType :: TypeBase MonoSize u -> TypeBase MonoSize u
    onType (Scalar st) = Scalar (onScalar st)
    onType (Array u shape st) = Array u shape (onScalar st)

    onScalar :: ScalarTypeBase MonoSize u -> ScalarTypeBase MonoSize u
    onScalar (Arrow u _ d arg_t (RetType dims ret_t)) =
      Arrow u Unnamed d (onType arg_t) (RetType dims (onType ret_t))
    onScalar (Record fields) = Record (onType <$> fields)
    onScalar (Sum constrs) = Sum (map onType <$> constrs)
    onScalar other = other

-- | Run 'transformType' on the type inside a return type, leaving its list
-- of existential sizes untouched.
transformRetType :: RetTypeBase Size u -> MonoM (RetTypeBase Size u)
transformRetType (RetType ext t) = do
  t' <- transformType t
  pure $ RetType ext t'

-- | @arrowArg scope argset args_params rety@ returns @rety@ with its
-- existentially bound sizes moved to the right of arrows, together with
-- the new set of parameters to consider.
--
-- That parameter set is @args_params@ plus the names collected from arrow
-- argument types while traversing @rety@.  A final cleaning pass then
-- removes existentials that are parameters, as well as duplicates.
arrowArg ::
  S.Set VName -> -- scope
  S.Set VName -> -- set of argument
  [VName] -> -- size parameters
  RetTypeBase Size as ->
  (RetTypeBase Size as, S.Set VName)
arrowArg scope argset args_params rety =
  (arrowCleanRetType new_params pushed, new_params)
  where
    (pushed, (fun_args, _)) =
      runWriter $ arrowArgRetType (scope, []) argset rety
    new_params = fun_args `S.union` S.fromList args_params

    -- Process a return type.  The sizes in @dimsToPush@ are appended to
    -- its existentials, and the type is traversed with the argument names
    -- added to the scope.  Only those existentials that occur as variable
    -- sizes in the type (outside nested function types) are kept.
    --
    -- The writer output is (names introduced by arrow arguments, variable
    -- sizes used).  The variable sizes used in this type are removed from
    -- the first component.  The result still needs cleaning, as too many
    -- existentials are introduced.
    arrowArgRetType ::
      (S.Set VName, [VName]) ->
      S.Set VName ->
      RetTypeBase Size as' ->
      Writer (S.Set VName, S.Set VName) (RetTypeBase Size as')
    arrowArgRetType (scope', dimsToPush) argset' (RetType dims ty) = pass $ do
      let candidates = dims <> dimsToPush
      (ty', (_, usedSizes)) <-
        listen $ arrowArgType (argset' `S.union` scope', candidates) ty
      let kept = filter (`S.member` usedSizes) candidates
      pure (RetType kept ty', first (`S.difference` usedSizes))

    arrowArgScalar ::
      (S.Set VName, [VName]) ->
      ScalarTypeBase Size u ->
      Writer (S.Set VName, S.Set VName) (ScalarTypeBase Size u)
    arrowArgScalar env (Record fs) =
      Record <$> mapM (arrowArgType env) fs
    arrowArgScalar env (Sum cs) =
      Sum <$> mapM (mapM (arrowArgType env)) cs
    arrowArgScalar (scope', dimsToPush) (Arrow u argName d argT retT) =
      pass $ do
        -- Non-intrinsic names free in the argument type that are not
        -- already in scope are introduced by this argument; they are no
        -- longer pushed into the return type.
        let intros = S.filter isNonIntrinsic argFree `S.difference` scope'
        retT' <-
          arrowArgRetType
            (scope', filter (`S.notMember` intros) dimsToPush)
            argAndParam
            retT
        -- Report the introduced names, and hide the sizes used inside this
        -- function type from the enclosing levels.
        pure (Arrow u argName d argT retT', bimap (intros `S.union`) (const mempty))
      where
        isNonIntrinsic vn = baseTag vn > maxIntrinsicTag
        argFree = fvVars $ freeInType argT
        argAndParam =
          case argName of
            Named vn -> S.insert vn argFree
            Unnamed -> argFree
    arrowArgScalar env (TypeVar u qn args) =
      TypeVar u qn <$> mapM onArg args
      where
        onArg (TypeArgDim dim) = TypeArgDim <$> arrowArgSize dim
        onArg (TypeArgType ty) = TypeArgType <$> arrowArgType env ty
    arrowArgScalar _ ty = pure ty

    arrowArgType ::
      (S.Set VName, [VName]) ->
      TypeBase Size as' ->
      Writer (S.Set VName, S.Set VName) (TypeBase Size as')
    arrowArgType env (Array u shape scalar) =
      Array u <$> traverse arrowArgSize shape <*> arrowArgScalar env scalar
    arrowArgType env (Scalar ty) =
      Scalar <$> arrowArgScalar env ty

    -- Record a variable size as used; any other size is left alone.
    arrowArgSize :: Size -> Writer (S.Set VName, S.Set VName) Size
    arrowArgSize s@(Var qn _ _) = s <$ tell (mempty, S.singleton (qualLeaf qn))
    arrowArgSize s = pure s

    -- Drop existentials that are already parameters (in @paramed@), and
    -- duplicate existentials, recursively.  Within a return type, the
    -- names it lists as existentials are treated as bound.
    arrowCleanRetType :: S.Set VName -> RetTypeBase Size as -> RetTypeBase Size as
    arrowCleanRetType paramed (RetType dims ty) =
      RetType
        (nubOrd [dim | dim <- dims, dim `S.notMember` paramed])
        (arrowCleanType (paramed `S.union` S.fromList dims) ty)

    arrowCleanScalar :: S.Set VName -> ScalarTypeBase Size as -> ScalarTypeBase Size as
    arrowCleanScalar paramed scalar =
      case scalar of
        Record fs -> Record $ M.map (arrowCleanType paramed) fs
        Sum cs -> Sum $ M.map (map (arrowCleanType paramed)) cs
        Arrow u argName d argT retT ->
          Arrow u argName d argT (arrowCleanRetType paramed retT)
        TypeVar u qn args -> TypeVar u qn (map cleanArg args)
        _ -> scalar
      where
        cleanArg (TypeArgDim dim) = TypeArgDim dim
        cleanArg (TypeArgType ty) = TypeArgType $ arrowCleanType paramed ty

    arrowCleanType :: S.Set VName -> TypeBase Size as -> TypeBase Size as
    arrowCleanType paramed (Array u shape scalar) =
      Array u shape (arrowCleanScalar paramed scalar)
    arrowCleanType paramed (Scalar ty) =
      Scalar (arrowCleanScalar paramed ty)

-- Monomorphise a polymorphic function at the types given in the instance
-- list. Monomorphises the body of the function as well. Returns the fresh name
-- of the generated monomorphic function and its 'ValBind' representation.
monomorphiseBinding ::
  Bool ->
  PolyBinding ->
  MonoType ->
  MonoM (VName, InferSizeArgs, ValBind)
monomorphiseBinding :: Bool
-> PolyBinding -> MonoType -> MonoM (VName, InferSizeArgs, ValBind)
monomorphiseBinding Bool
entry (PolyBinding (VName
name, [TypeParamBase VName]
tparams, [Pat ParamType]
params, ResRetType
rettype, Exp
body, [AttrInfo VName]
attrs, SrcLoc
loc)) MonoType
inst_t = do
  Bool
letFun <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall a b. (a -> b) -> a -> b
$ forall a. Ord a => a -> Set a -> Bool
S.member VName
name forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env -> Set VName
envScope
  let paramGetClean :: Set VName -> MonoM ExpReplacements
paramGetClean Set VName
argset =
        if Bool
letFun
          then Set VName -> MonoM ExpReplacements
parametrizing Set VName
argset
          else do
            ExpReplacements
ret <- forall s (m :: * -> *). MonadState s m => m s
get
            forall s (m :: * -> *). MonadState s m => s -> m ()
put forall a. Monoid a => a
mempty
            forall (f :: * -> *) a. Applicative f => a -> f a
pure ExpReplacements
ret
  (if Bool
letFun then forall a. a -> a
id else forall a. MonoM a -> MonoM a
isolateNormalisation) forall a b. (a -> b) -> a -> b
$ do
    let bind_t :: StructType
bind_t = [Pat ParamType] -> ResRetType -> StructType
funType [Pat ParamType]
params ResRetType
rettype
    (Map VName StructRetType
substs, [TypeParamBase VName]
t_shape_params) <-
      forall (m :: * -> *).
MonadFreshNames m =>
SrcLoc
-> TypeBase () NoUniqueness
-> MonoType
-> m (Map VName StructRetType, [TypeParamBase VName])
typeSubstsM SrcLoc
loc (forall as. TypeBase Exp as -> TypeBase () as
noSizes StructType
bind_t) forall a b. (a -> b) -> a -> b
$ MonoType -> MonoType
noNamedParams MonoType
inst_t
    let shape_names :: Set VName
shape_names = forall a. Ord a => [a] -> Set a
S.fromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall vn. TypeParamBase vn -> vn
typeParamName forall a b. (a -> b) -> a -> b
$ [TypeParamBase VName]
shape_params forall a. [a] -> [a] -> [a]
++ [TypeParamBase VName]
t_shape_params
        substs' :: Map VName (Subst StructRetType)
substs' = forall a b k. (a -> b) -> Map k a -> Map k b
M.map (forall t. [TypeParamBase VName] -> t -> Subst t
Subst []) Map VName StructRetType
substs
        substStructType :: ParamType -> ParamType
substStructType =
          forall as.
Monoid as =>
(VName -> Maybe (Subst (RetTypeBase Exp as)))
-> TypeBase Exp as -> TypeBase Exp as
substTypesAny (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (forall a b. a -> b -> a
const forall a. Monoid a => a
mempty))) forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructRetType)
substs'))
        params' :: [Pat ParamType]
params' = forall a b. (a -> b) -> [a] -> [b]
map (forall t. Bool -> (t -> t) -> Pat t -> Pat t
substPat Bool
entry ParamType -> ParamType
substStructType) [Pat ParamType]
params
    [Pat ParamType]
params'' <- forall a. Set VName -> MonoM a -> MonoM a
withArgs Set VName
shape_names forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall u. Pat (TypeBase Exp u) -> MonoM (Pat (TypeBase Exp u))
transformPat [Pat ParamType]
params'
    ExpReplacements
exp_naming <- Set VName -> MonoM ExpReplacements
paramGetClean Set VName
shape_names

    let args :: Set VName
args = forall a. Ord a => [a] -> Set a
S.fromList forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap forall t. Pat t -> [VName]
patNames [Pat ParamType]
params
        arg_params :: [VName]
arg_params = forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd ExpReplacements
exp_naming

    ResRetType
rettype' <-
      forall a. ExpReplacements -> MonoM a -> MonoM a
withParams ExpReplacements
exp_naming forall a b. (a -> b) -> a -> b
$
        forall a. Set VName -> MonoM a -> MonoM a
withArgs (Set VName
args forall a. Semigroup a => a -> a -> a
<> Set VName
shape_names) forall a b. (a -> b) -> a -> b
$
          forall u. RetTypeBase Exp u -> MonoM (RetTypeBase Exp u)
hardTransformRetType (forall a. Substitutable a => TypeSubs -> a -> a
applySubst (forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructRetType)
substs') ResRetType
rettype)
    ExpReplacements
extNaming <- Set VName -> MonoM ExpReplacements
paramGetClean (Set VName
args forall a. Semigroup a => a -> a -> a
<> Set VName
shape_names)
    Set VName
scope <- forall a. Ord a => Set a -> Set a -> Set a
S.union Set VName
shape_names forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MonoM (Set VName)
askScope'
    let (ResRetType
rettype'', Set VName
new_params) = forall as.
Set VName
-> Set VName
-> [VName]
-> RetTypeBase Exp as
-> (RetTypeBase Exp as, Set VName)
arrowArg Set VName
scope Set VName
args [VName]
arg_params ResRetType
rettype'
        bind_t' :: StructType
bind_t' = forall as.
Monoid as =>
(VName -> Maybe (Subst (RetTypeBase Exp as)))
-> TypeBase Exp as -> TypeBase Exp as
substTypesAny (forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructRetType)
substs') StructType
bind_t
        ([TypeParamBase VName]
shape_params_explicit, [TypeParamBase VName]
shape_params_implicit) =
          forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((forall a. Ord a => a -> Set a -> Bool
`S.member` (StructType -> Set VName
mustBeExplicitInBinding StructType
bind_t'' forall a. Ord a => Set a -> Set a -> Set a
`S.union` StructType -> Set VName
mustBeExplicitInBinding StructType
bind_t')) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall vn. TypeParamBase vn -> vn
typeParamName) forall a b. (a -> b) -> a -> b
$
            [TypeParamBase VName]
shape_params forall a. [a] -> [a] -> [a]
++ [TypeParamBase VName]
t_shape_params forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (forall vn. vn -> SrcLoc -> TypeParamBase vn
`TypeParamDim` forall a. Monoid a => a
mempty) (forall a. Set a -> [a]
S.toList Set VName
new_params)
        exp_naming' :: ExpReplacements
exp_naming' = forall a. (a -> Bool) -> [a] -> [a]
filter ((forall a. Ord a => a -> Set a -> Bool
`S.member` Set VName
new_params) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) (ExpReplacements
extNaming forall a. Semigroup a => a -> a -> a
<> ExpReplacements
exp_naming)

        bind_t'' :: StructType
bind_t'' = [Pat ParamType] -> ResRetType -> StructType
funType [Pat ParamType]
params'' ResRetType
rettype''
        bind_r :: ExpReplacements
bind_r = ExpReplacements
exp_naming forall a. Semigroup a => a -> a -> a
<> ExpReplacements
extNaming
    Exp
body' <- forall {m :: * -> *}. Monad m => TypeSubs -> Exp -> m Exp
updateExpTypes (forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructRetType)
substs') Exp
body
    Exp
body'' <- forall a. ExpReplacements -> MonoM a -> MonoM a
withParams ExpReplacements
exp_naming' forall a b. (a -> b) -> a -> b
$ forall a. Set VName -> MonoM a -> MonoM a
withArgs (Set VName
shape_names forall a. Semigroup a => a -> a -> a
<> Set VName
args) forall a b. (a -> b) -> a -> b
$ Exp -> MonoM Exp
transformExp Exp
body'
    Set VName
scope' <- forall a. Ord a => Set a -> Set a -> Set a
S.union (Set VName
shape_names forall a. Semigroup a => a -> a -> a
<> Set VName
args) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MonoM (Set VName)
askScope'
    Exp
body''' <-
      if Bool
letFun
        then Set VName -> Exp -> MonoM Exp
unscoping (Set VName
shape_names forall a. Semigroup a => a -> a -> a
<> Set VName
args) Exp
body''
        else ExpReplacements -> Exp -> Exp
expReplace ExpReplacements
exp_naming' forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Exp -> ExpReplacements -> MonoM Exp
calculateDims Exp
body'' forall b c a. (b -> c) -> (a -> b) -> a -> c
. Set VName -> ExpReplacements -> ExpReplacements
canCalculate Set VName
scope' forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall s (m :: * -> *). MonadState s m => m s
get)

    Bool
seen_before <- forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
elem VName
name forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map (forall a b. (a, b) -> a
fst forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MonoM Lifts
getLifts
    VName
name' <-
      if forall (t :: * -> *) a. Foldable t => t a -> Bool
null [TypeParamBase VName]
tparams Bool -> Bool -> Bool
&& Bool -> Bool
not Bool
entry Bool -> Bool -> Bool
&& Bool -> Bool
not Bool
seen_before
        then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
name
        else forall (m :: * -> *). MonadFreshNames m => VName -> m VName
newName VName
name

    forall (f :: * -> *) a. Applicative f => a -> f a
pure
      ( VName
name',
        [TypeParamBase VName]
-> StructType -> ExpReplacements -> InferSizeArgs
inferSizeArgs [TypeParamBase VName]
shape_params_explicit StructType
bind_t'' ExpReplacements
bind_r,
        if Bool
entry
          then
            VName
-> [TypeParamBase VName]
-> [Pat ParamType]
-> ResRetType
-> Exp
-> ValBind
toValBinding
              VName
name'
              ([TypeParamBase VName]
shape_params_explicit forall a. [a] -> [a] -> [a]
++ [TypeParamBase VName]
shape_params_implicit)
              [Pat ParamType]
params''
              ResRetType
rettype''
              (ExpReplacements -> Exp -> Exp
entryAssert ExpReplacements
exp_naming Exp
body''')
          else
            VName
-> [TypeParamBase VName]
-> [Pat ParamType]
-> ResRetType
-> Exp
-> ValBind
toValBinding
              VName
name'
              [TypeParamBase VName]
shape_params_implicit
              (forall a b. (a -> b) -> [a] -> [b]
map forall {vn} {dim} {als}.
TypeParamBase vn -> PatBase Info vn (TypeBase dim als)
shapeParam [TypeParamBase VName]
shape_params_explicit forall a. [a] -> [a] -> [a]
++ [Pat ParamType]
params'')
              ResRetType
rettype''
              Exp
body'''
      )
  where
    askScope' :: MonoM (Set VName)
askScope' = forall a. (a -> Bool) -> Set a -> Set a
S.filter (forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` forall dim as. RetTypeBase dim as -> [VName]
retDims ResRetType
rettype) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MonoM (Set VName)
askScope

    shape_params :: [TypeParamBase VName]
shape_params = forall a. (a -> Bool) -> [a] -> [a]
filter (Bool -> Bool
not forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall vn. TypeParamBase vn -> Bool
isTypeParam) [TypeParamBase VName]
tparams

    updateExpTypes :: TypeSubs -> Exp -> m Exp
updateExpTypes TypeSubs
substs = forall x (m :: * -> *).
(ASTMappable x, Monad m) =>
ASTMapper m -> x -> m x
astMap (TypeSubs -> ASTMapper m
mapper TypeSubs
substs)

    hardTransformRetType :: RetTypeBase Exp as -> MonoM (RetTypeBase Exp as)
hardTransformRetType (RetType [VName]
dims TypeBase Exp as
ty) = do
      TypeBase Exp as
ty' <- forall u. TypeBase Exp u -> MonoM (TypeBase Exp u)
transformType TypeBase Exp as
ty
      Set VName
unbounded <- Set VName -> MonoM (Set VName)
askIntros forall a b. (a -> b) -> a -> b
$ FV -> Set VName
fvVars forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Exp u -> FV
freeInType TypeBase Exp as
ty'
      let dims' :: [VName]
dims' = forall a. Set a -> [a]
S.toList Set VName
unbounded
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType ([VName]
dims' forall a. Semigroup a => a -> a -> a
<> [VName]
dims) TypeBase Exp as
ty'

    mapper :: TypeSubs -> ASTMapper m
mapper TypeSubs
substs =
      ASTMapper
        { mapOnExp :: Exp -> m Exp
mapOnExp = TypeSubs -> Exp -> m Exp
updateExpTypes TypeSubs
substs,
          mapOnName :: VName -> m VName
mapOnName = forall (f :: * -> *) a. Applicative f => a -> f a
pure,
          mapOnStructType :: StructType -> m StructType
mapOnStructType = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Substitutable a => TypeSubs -> a -> a
applySubst TypeSubs
substs,
          mapOnParamType :: ParamType -> m ParamType
mapOnParamType = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Substitutable a => TypeSubs -> a -> a
applySubst TypeSubs
substs,
          mapOnResRetType :: ResRetType -> m ResRetType
mapOnResRetType = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Substitutable a => TypeSubs -> a -> a
applySubst TypeSubs
substs
        }

    shapeParam :: TypeParamBase vn -> PatBase Info vn (TypeBase dim als)
shapeParam TypeParamBase vn
tp = forall (f :: * -> *) vn t. vn -> f t -> SrcLoc -> PatBase f vn t
Id (forall vn. TypeParamBase vn -> vn
typeParamName TypeParamBase vn
tp) (forall a. a -> Info a
Info forall dim als. TypeBase dim als
i64) forall a b. (a -> b) -> a -> b
$ forall a. Located a => a -> SrcLoc
srclocOf TypeParamBase vn
tp

    toValBinding :: VName
-> [TypeParamBase VName]
-> [Pat ParamType]
-> ResRetType
-> Exp
-> ValBind
toValBinding VName
name' [TypeParamBase VName]
tparams' [Pat ParamType]
params'' ResRetType
rettype' Exp
body'' =
      ValBind
        { valBindEntryPoint :: Maybe (Info EntryPoint)
valBindEntryPoint = forall a. Maybe a
Nothing,
          valBindName :: VName
valBindName = VName
name',
          valBindRetType :: Info ResRetType
valBindRetType = forall a. a -> Info a
Info ResRetType
rettype',
          valBindRetDecl :: Maybe (TypeExp Info VName)
valBindRetDecl = forall a. Maybe a
Nothing,
          valBindTypeParams :: [TypeParamBase VName]
valBindTypeParams = [TypeParamBase VName]
tparams',
          valBindParams :: [Pat ParamType]
valBindParams = [Pat ParamType]
params'',
          valBindBody :: Exp
valBindBody = Exp
body'',
          valBindDoc :: Maybe DocComment
valBindDoc = forall a. Maybe a
Nothing,
          valBindAttrs :: [AttrInfo VName]
valBindAttrs = [AttrInfo VName]
attrs,
          valBindLocation :: SrcLoc
valBindLocation = SrcLoc
loc
        }

-- | Structurally match the size-erased polymorphic type @orig_t1@ against
-- the monomorphic instance type @orig_t2@.
--
-- The result map binds each type variable occurring in @orig_t1@ to the
-- corresponding part of @orig_t2@.
--
-- * Intrinsic type variables are not bound.  These are the ones whose tag
--   is at most 'maxIntrinsicTag'.
-- * Only the first binding found for a given type variable is kept.
-- * Inside the bound types, every 'MonoKnown' size is replaced by a fresh
--   size variable.  The same 'MonoKnown' index always maps to the same
--   variable.
-- * Each fresh size variable is also returned as a 'TypeParamDim' located
--   at @loc@.
-- * 'MonoAnon' sizes become 'anySize'.
--
-- Calls 'error' if the two types do not have the same structure.
typeSubstsM ::
  (MonadFreshNames m) =>
  SrcLoc ->
  TypeBase () NoUniqueness ->
  MonoType ->
  m (M.Map VName StructRetType, [TypeParam])
typeSubstsM loc orig_t1 orig_t2 =
  runWriterT $ fst <$> execStateT (sub orig_t1 orig_t2) (mempty, mempty)
  where
    -- Like 'sub', but the right-hand side is a return type.  When the left
    -- side is a type variable, the whole return type (including its
    -- existential sizes) becomes the binding.
    subRet (Scalar (TypeVar _ v _)) rt =
      unless (baseTag (qualLeaf v) <= maxIntrinsicTag) $
        addSubst v rt
    subRet t1 (RetType _ t2) =
      sub t1 t2

    -- Arrays: peel off the rank of the polymorphic side from both types.
    sub t1@Array {} t2@Array {}
      | Just t1' <- peelArray (arrayRank t1) t1,
        Just t2' <- peelArray (arrayRank t1) t2 =
          sub t1' t2'
    sub (Scalar (TypeVar _ v _)) t =
      unless (baseTag (qualLeaf v) <= maxIntrinsicTag) $
        addSubst v $
          RetType [] t
    sub (Scalar (Record fields1)) (Scalar (Record fields2)) =
      zipWithM_
        sub
        (map snd $ sortFields fields1)
        (map snd $ sortFields fields2)
    sub (Scalar Prim {}) (Scalar Prim {}) = pure ()
    sub (Scalar (Arrow _ _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ _ t2a t2b)) = do
      sub t1a t2a
      subRet (toStruct t1b) (second (const NoUniqueness) t2b)
    sub (Scalar (Sum cs1)) (Scalar (Sum cs2)) =
      zipWithM_ typeSubstClause (sortConstrs cs1) (sortConstrs cs2)
      where
        typeSubstClause (_, ts1) (_, ts2) = zipWithM_ sub ts1 ts2
    -- Every other combination is a mismatch.  This includes a sum type
    -- matched against a non-sum type.  Such a case must report an error
    -- here rather than call 'sub' again on the same arguments, which would
    -- never terminate.
    sub t1 t2 =
      error $
        unlines
          [ "typeSubstsM: mismatched types:",
            prettyString t1,
            prettyString t2
          ]

    -- Bind type variable @v@ unless it is already bound.  Sizes in the
    -- bound type are replaced via 'onDim'.
    addSubst (QualName _ v) (RetType ext t) = do
      (ts, sizes) <- get
      unless (v `M.member` ts) $ do
        t' <- bitraverse onDim pure t
        put (M.insert v (RetType ext t') ts, sizes)

    -- Turn a monomorphic size into a size expression.  Each known size
    -- index gets one fresh size parameter, which is also emitted with
    -- 'tell'.
    onDim (MonoKnown i) = do
      (ts, sizes) <- get
      case M.lookup i sizes of
        Nothing -> do
          d <- lift $ lift $ newVName "d"
          tell [TypeParamDim d loc]
          put (ts, M.insert i d sizes)
          pure $ sizeFromName (qualName d) mempty
        Just d ->
          pure $ sizeFromName (qualName d) mempty
    onDim MonoAnon = pure anySize

-- | Apply the type transformation @f@ to every type annotation in a
-- pattern, recursing through all sub-patterns.  This includes the
-- payload patterns of constructor patterns.
--
-- Type ascriptions ('PatAscription') are stripped.  When @entry@ is true,
-- the outermost ascriptions are kept.  The recursion beneath an ascription
-- always passes 'False', so nested ascriptions are stripped in either case.
substPat :: Bool -> (t -> t) -> Pat t -> Pat t
substPat entry f pat = case pat of
  TuplePat pats loc -> TuplePat (map recurse pats) loc
  RecordPat fs loc -> RecordPat (map (second recurse) fs) loc
  PatParens p loc -> PatParens (recurse p) loc
  PatAttr attr p loc -> PatAttr attr (recurse p) loc
  Id vn (Info tp) loc -> Id vn (Info $ f tp) loc
  Wildcard (Info tp) loc -> Wildcard (Info $ f tp) loc
  PatAscription p td loc
    | entry -> PatAscription (substPat False f p) td loc
    | otherwise -> substPat False f p
  PatLit e (Info tp) loc -> PatLit e (Info $ f tp) loc
  -- The sub-patterns must be transformed too.  Otherwise their type
  -- annotations would keep the unsubstituted types, unlike every other
  -- compound pattern above.
  PatConstr n (Info tp) ps loc -> PatConstr n (Info $ f tp) (map recurse ps) loc
  where
    recurse = substPat entry f

-- | Repackage a value binding as a 'PolyBinding'.
--
-- The name, type parameters, parameters, return type, body, attributes and
-- location are kept.  The entry point information, return type declaration
-- and doc comment are dropped.
toPolyBinding :: ValBind -> PolyBinding
toPolyBinding vb =
  PolyBinding
    ( valBindName vb,
      valBindTypeParams vb,
      valBindParams vb,
      unInfo (valBindRetType vb),
      valBindBody vb,
      valBindAttrs vb,
      valBindLocation vb
    )

-- | Expand the type abbreviations bound in the environment's
-- 'envTypeBindings' throughout a value binding.
--
-- The return type, the parameter patterns and every type inside the body
-- are rewritten.  Parameter patterns go through 'substPat', so their type
-- ascriptions are stripped unless @entry@ is set.
removeTypeVariables :: Bool -> ValBind -> MonoM ValBind
removeTypeVariables entry valbind = do
  abbrs <- asks $ M.map substFromAbbr . envTypeBindings
  let lookupAbbr = (`M.lookup` abbrs)

      -- Rewrites all types reachable from an expression, recursing into
      -- sub-expressions through itself.
      typeMapper :: ASTMapper MonoM
      typeMapper =
        ASTMapper
          { mapOnExp = astMap typeMapper,
            mapOnName = pure,
            mapOnStructType = pure . applySubst lookupAbbr,
            mapOnParamType = pure . applySubst lookupAbbr,
            mapOnResRetType = pure . applySubst lookupAbbr
          }

  newBody <- astMap typeMapper $ valBindBody valbind

  pure $
    valbind
      { valBindRetType = Info . applySubst lookupAbbr . unInfo $ valBindRetType valbind,
        valBindParams = map (substPat entry (applySubst lookupAbbr)) (valBindParams valbind),
        valBindBody = newBody
      }

-- | Expand the type abbreviations bound in the environment's
-- 'envTypeBindings' within a single type.
removeTypeVariablesInType :: StructType -> MonoM StructType
removeTypeVariablesInType t =
  asks $ \env ->
    let abbrs = M.map substFromAbbr $ envTypeBindings env
     in applySubst (`M.lookup` abbrs) t

-- | Expand environment type abbreviations in every type in an entry point
-- descriptor.  This covers each parameter type and the result type.  The
-- declared type expressions are left unchanged.
transformEntryPoint :: EntryPoint -> MonoM EntryPoint
transformEntryPoint (EntryPoint params ret) = do
  params' <- traverse cleanParam params
  ret' <- cleanType ret
  pure $ EntryPoint params' ret'
  where
    cleanParam (EntryParam v t) = do
      t' <- cleanType t
      pure $ EntryParam v t'

    cleanType (EntryType t te) = do
      t' <- removeTypeVariablesInType t
      pure $ EntryType t' te

-- | Register a top-level value binding.
--
-- * The binding (with environment type abbreviations expanded) is always
--   recorded in the returned environment as a polymorphic binding.
-- * If it is an entry point, it is also monomorphised at its own function
--   type.  The result is emitted with 'tell', carrying the transformed
--   entry point information, and recorded with 'addLifted'.
-- * A binding without parameters contributes the existential sizes of its
--   return type to the global scope.
transformValBind :: ValBind -> MonoM Env
transformValBind valbind = do
  let name = valBindName valbind
      params = valBindParams valbind
      ret = unInfo $ valBindRetType valbind
      entryInfo = valBindEntryPoint valbind

  poly <- toPolyBinding <$> removeTypeVariables (isJust entryInfo) valbind

  forM_ entryInfo $ \(Info entry) -> do
    fun_t <- removeTypeVariablesInType $ funType params ret
    let mono_t = monoType fun_t
    (name', infer, mono_vb) <- monomorphiseBinding True poly mono_t
    entry' <- transformEntryPoint entry
    tell $ Seq.singleton (name', mono_vb {valBindEntryPoint = Just (Info entry')})
    addLifted name mono_t (name', infer)

  let globals
        | null params = S.fromList (retDims ret)
        | otherwise = mempty

  pure mempty {envPolyBindings = M.singleton name poly, envGlobalScope = globals}

-- | Turn a type declaration into an environment fragment containing a
-- single type abbreviation.  Before the abbreviation is stored, its
-- right-hand side is rewritten with 'applySubst', using substitutions
-- derived (via 'substFromAbbr') from every type binding currently in the
-- environment.  The size parameters @dims@ of the elaborated return type
-- are kept unchanged.
transformTypeBind :: TypeBind -> MonoM Env
transformTypeBind (TypeBind name l tparams _ (Info (RetType dims t)) _ _) = do
  inScope <- asks envTypeBindings
  let substs = M.map substFromAbbr inScope
      expanded = applySubst (`M.lookup` substs) t
      abbr = TypeAbbr l tparams (RetType dims expanded)
  pure mempty {envTypeBindings = M.singleton name abbr}

-- | Process top-level declarations from first to last.  A value or type
-- declaration produces an environment fragment, which is brought into
-- scope with 'localEnv' for every declaration that follows it.  Any other
-- kind of declaration means the input still contains module constructs,
-- which this pass does not handle, so it fails with 'error'.
transformDecs :: [Dec] -> MonoM ()
transformDecs decs = case decs of
  [] -> pure ()
  ValDec valbind : rest -> transformValBind valbind >>= continueIn rest
  TypeDec typebind : rest -> transformTypeBind typebind >>= continueIn rest
  dec : _ ->
    error $
      "The monomorphization module expects a module-free "
        ++ "input program, but received: "
        ++ prettyString dec
  where
    -- Handle the remaining declarations with @env@ added to the scope.
    continueIn rest env = localEnv env $ transformDecs rest

-- | Monomorphise a list of top-level declarations.  A module-free input
-- program is expected, so only value declarations and type declarations
-- are accepted.  The pass runs with the name source of the surrounding
-- monad.  The result is the value bindings the pass emitted, in the order
-- of its output sequence.
transformProg :: MonadFreshNames m => [Dec] -> m [ValBind]
transformProg decs =
  collectBindings
    <$> modifyNameSource (\src -> runMonoM src (transformDecs decs))
  where
    -- Ignore the unit result of the pass, and drop the name paired with
    -- each emitted binding, keeping only the bindings themselves.
    collectBindings (_, emitted) = map snd (toList emitted)