{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Internalise.Monad
  ( InternaliseM,
    runInternaliseM,
    throwError,
    VarSubstitutions,
    InternaliseEnv (..),
    FunInfo,
    substitutingVars,
    lookupSubst,
    addFunDef,
    lookupFunction,
    lookupFunction',
    lookupConst,
    bindFunction,
    bindConstant,
    localConstsScope,
    assert,

    -- * Convenient reexports
    module Futhark.Tools,
  )
where

import Control.Monad.Except
import Control.Monad.RWS
import qualified Data.Map.Strict as M
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (takeLast)

-- | Information recorded for every function registered with
-- 'bindFunction'.  The four components are, in order:
--
-- * a list of names (NOTE(review): presumably the function's size/shape
--   parameters — confirm against the callers of 'bindFunction');
-- * a list of declared types (presumably the value parameter types);
-- * the function's parameters;
-- * a function that, given argument subexpressions paired with their
--   types, yields the return types, or 'Nothing'.
type FunInfo =
  ([VName], [DeclType], [FParam], [(SubExp, Type)] -> Maybe [DeclExtType])

-- | All functions registered so far, keyed by their name.
type FunTable = M.Map VName FunInfo

-- | A mapping from external variable names to the internalised
-- subexpressions that replace them.
type VarSubstitutions = M.Map VName [SubExp]

-- | The read-only environment of 'InternaliseM'.
data InternaliseEnv = InternaliseEnv
  { -- | Variable substitutions currently in scope.  'lookupSubst'
    -- consults these before the constant substitutions.
    envSubsts :: VarSubstitutions,
    -- | When 'False', 'asserting' skips its action, so no assertions
    -- are emitted.  'runInternaliseM' starts with this set to 'True'.
    envDoBoundsChecks :: Bool,
    -- | The safety flag given to 'runInternaliseM'.
    envSafe :: Bool,
    -- | Attributes currently in effect.  'assert' derives the
    -- attributes of the assertions it emits from these.
    envAttrs :: Attrs
  }

-- | The mutable state threaded through 'InternaliseM'.
data InternaliseState = InternaliseState
  { -- | The fresh-name source, read and written by the
    -- 'MonadFreshNames' instance for 'RWST' over this state.
    stateNameSource :: VNameSource,
    -- | Functions registered with 'bindFunction'.
    stateFunTable :: FunTable,
    -- | Substitutions for constants registered with 'bindConstant'.
    stateConstSubsts :: VarSubstitutions,
    -- | Scope of the statements that define those constants.
    -- 'localConstsScope' brings this scope into effect.
    stateConstScope :: Scope SOACS
  }

-- | The writer output of 'InternaliseM': statements and function
-- definitions, each concatenated in the order they were emitted.
--
-- NOTE(review): within this module only the function-definition half is
-- ever written ('addFunDef'), and 'runInternaliseM' discards the
-- statement half.
data InternaliseResult = InternaliseResult (Stms SOACS) [FunDef SOACS]

-- | Component-wise concatenation.
instance Semigroup InternaliseResult where
  InternaliseResult xs1 ys1 <> InternaliseResult xs2 ys2 =
    InternaliseResult (xs1 <> xs2) (ys1 <> ys2)

-- | No statements and no function definitions.
instance Monoid InternaliseResult where
  mempty = InternaliseResult mempty mempty

-- | The internalisation monad.  It is a 'BinderT' over 'SOACS', which
-- accumulates the statements being built.  That binder sits on top of
-- an 'RWS' layer with 'InternaliseEnv' as environment,
-- 'InternaliseResult' as writer output and 'InternaliseState' as state.
-- All listed instances are obtained by newtype deriving from the
-- wrapped stack.
newtype InternaliseM a
  = InternaliseM
      ( BinderT
          SOACS
          ( RWS
              InternaliseEnv
              InternaliseResult
              InternaliseState
          )
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader InternaliseEnv,
      MonadState InternaliseState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Any 'RWST' whose state is 'InternaliseState' draws fresh names from
-- 'stateNameSource'.  This instance lets 'runBinderT' (which needs
-- 'MonadFreshNames') run inside the 'RWS' layer of 'InternaliseM'.
instance (Monoid w, Monad m) => MonadFreshNames (RWST r w InternaliseState m) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | Every binder operation is delegated unchanged to the wrapped
-- 'BinderT'.
instance MonadBinder InternaliseM where
  type Lore InternaliseM = SOACS
  mkExpDecM pat e = InternaliseM $ mkExpDecM pat e
  mkBodyM bnds res = InternaliseM $ mkBodyM bnds res
  mkLetNamesM pat e = InternaliseM $ mkLetNamesM pat e

  addStms = InternaliseM . addStms
  collectStms (InternaliseM m) = InternaliseM $ collectStms m

-- | Run an internalisation action from an empty starting point:
--
-- * no variable substitutions and no attributes;
-- * bounds checks enabled;
-- * the given 'Bool' stored as 'envSafe';
-- * an empty function table and no constants.
--
-- Returns two things.  The first is the statements collected by the
-- top-level 'BinderT'; these are, e.g., the statements added by
-- 'bindConstant'.  The second is every function definition emitted
-- with 'addFunDef'.  The statement half of the writer's
-- 'InternaliseResult' is discarded.  Fresh names are taken from, and
-- handed back to, the surrounding 'MonadFreshNames' monad.
runInternaliseM ::
  MonadFreshNames m =>
  Bool ->
  InternaliseM () ->
  m (Stms SOACS, [FunDef SOACS])
runInternaliseM safe (InternaliseM m) =
  modifyNameSource $ \src ->
    let ((_, consts), s, InternaliseResult _ funs) =
          runRWS (runBinderT m mempty) newEnv (newState src)
     in ((consts, funs), stateNameSource s)
  where
    newEnv :: InternaliseEnv
    newEnv =
      InternaliseEnv
        { envSubsts = mempty,
          envDoBoundsChecks = True,
          envSafe = safe,
          envAttrs = mempty
        }

    newState :: VNameSource -> InternaliseState
    newState src =
      InternaliseState
        { stateNameSource = src,
          stateFunTable = mempty,
          stateConstSubsts = mempty,
          stateConstScope = mempty
        }

-- | Run an action with additional variable substitutions in scope.  The
-- new substitutions are combined with the existing ones using '<>' on
-- maps, which is a left-biased union.  So on a key clash the new
-- substitution wins.
substitutingVars :: VarSubstitutions -> InternaliseM a -> InternaliseM a
substitutingVars substs = local $ \env -> env {envSubsts = substs <> envSubsts env}

-- | Look up the substitution for a variable.  The local substitutions
-- ('envSubsts') are consulted first, then the constant substitutions
-- ('stateConstSubsts').  Returns 'Nothing' if neither has an entry.
lookupSubst :: VName -> InternaliseM (Maybe [SubExp])
lookupSubst v = do
  env_substs <- asks $ M.lookup v . envSubsts
  const_substs <- gets $ M.lookup v . stateConstSubsts
  return $ env_substs `mplus` const_substs

-- | Add a function definition to the program being constructed.  The
-- definition is written to the 'RWS' writer beneath the 'BinderT'.  It
-- is therefore unaffected by 'collectStms' and appears in the function
-- list returned by 'runInternaliseM'.
addFunDef :: FunDef SOACS -> InternaliseM ()
addFunDef fd =
  InternaliseM $ lift $ tell $ InternaliseResult mempty [fd]

-- | Look up the 'FunInfo' registered for a function by 'bindFunction'.
-- Returns 'Nothing' if the name is unknown.
lookupFunction' :: VName -> InternaliseM (Maybe FunInfo)
lookupFunction' fname = gets $ M.lookup fname . stateFunTable

-- | Like 'lookupFunction'', but a missing function is treated as an
-- internal invariant violation: it calls 'error' with the function's
-- name instead of returning 'Nothing'.
lookupFunction :: VName -> InternaliseM FunInfo
lookupFunction fname = maybe bad return =<< lookupFunction' fname
  where
    bad = error $ "Internalise.lookupFunction: Function '" ++ pretty fname ++ "' not found."

-- | Look up the substitution recorded for a constant by 'bindConstant'.
-- Unlike 'lookupSubst', this ignores the local substitutions.
lookupConst :: VName -> InternaliseM (Maybe [SubExp])
lookupConst fname = gets $ M.lookup fname . stateConstSubsts

-- | Register a function.  Its definition is emitted with 'addFunDef',
-- and its 'FunInfo' is recorded under the given name.  Any earlier
-- entry for that name is replaced.
bindFunction :: VName -> FunDef SOACS -> FunInfo -> InternaliseM ()
bindFunction fname fd info = do
  addFunDef fd
  modify $ \s -> s {stateFunTable = M.insert fname info $ stateFunTable s}

-- | Bind a constant whose value is computed by the body of a function
-- definition.  Three things happen:
--
-- * the body's statements are added to the current statements
--   ('addStms');
-- * the last @n@ body results become the constant's substitution, where
--   @n@ is the number of return types.  NOTE(review): presumably any
--   earlier results are not part of the constant's value; confirm with
--   the producer of @fd@;
-- * the scope of those statements is added to 'stateConstScope', for
--   use by 'localConstsScope'.
bindConstant :: VName -> FunDef SOACS -> InternaliseM ()
bindConstant cname fd = do
  let body = funDefBody fd
      stms = bodyStms body
      substs = takeLast (length (funDefRetType fd)) (bodyResult body)
  addStms stms
  modify $ \s ->
    s
      { stateConstSubsts = M.insert cname substs $ stateConstSubsts s,
        stateConstScope = scopeOf stms <> stateConstScope s
      }

-- | Run an action with the scope of all constants bound so far (by
-- 'bindConstant') added to its local scope.
localConstsScope :: InternaliseM a -> InternaliseM a
localConstsScope m = do
  scope <- gets stateConstScope
  localScope scope m

-- | Construct an 'Assert' statement, but taking attributes into
-- account.  Always use this function, and never construct 'Assert'
-- directly in the internaliser!
--
-- The assertion is bound with 'letExp', under the current attributes
-- transformed by 'attrsForAssert'.  The bound name is returned as a
-- single certificate.  When 'envDoBoundsChecks' is 'False', nothing is
-- emitted and no certificates are returned.
assert ::
  -- | Description passed to 'letExp' for the bound variable.
  String ->
  -- | The value being asserted.
  SubExp ->
  -- | Error message attached to the assertion.
  ErrorMsg SubExp ->
  -- | Source location of the assertion.
  SrcLoc ->
  InternaliseM Certificates
assert desc se msg loc = assertingOne $ do
  attrs <- asks $ attrsForAssert . envAttrs
  attributing attrs $
    letExp desc $
      BasicOp $ Assert se msg (loc, mempty)

-- | Execute the given action if 'envDoBoundsChecks' is true.
-- Otherwise skip it and return no certificates ('mempty').
asserting ::
  InternaliseM Certificates ->
  InternaliseM Certificates
asserting m = do
  doBoundsChecks <- asks envDoBoundsChecks
  if doBoundsChecks
    then m
    else return mempty

-- | Like 'asserting', but for an action that produces a single name.
-- When 'envDoBoundsChecks' is true, the action runs and its resulting
-- name is wrapped as a one-element 'Certificates'.  Otherwise the
-- action is skipped and no certificates ('mempty') are returned.
assertingOne ::
  InternaliseM VName ->
  InternaliseM Certificates
assertingOne m = asserting $ Certificates . pure <$> m