{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module defines a convenience monad/typeclass for creating
-- normalised programs.  The fundamental building block is 'BinderT'
-- and its execution functions, but it is usually easier to use
-- 'Binder'.
--
-- See "Futhark.Construct" for a high-level description.
module Futhark.Binder
  ( -- * A concrete @MonadBinder@ monad.
    BinderT,
    runBinderT,
    runBinderT_,
    runBinderT',
    runBinderT'_,
    BinderOps (..),
    Binder,
    runBinder,
    runBinder_,
    runBodyBinder,

    -- * The 'MonadBinder' typeclass
    module Futhark.Binder.Class,
  )
where

import Control.Arrow (second)
import Control.Monad.Error.Class
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import qualified Data.Map.Strict as M
import Futhark.Binder.Class
import Futhark.IR

-- | Lores that know how to construct their own statements.  A
-- 'BinderT' (and therefore also a 'Binder') is an instance of
-- 'MonadBinder' only for lores that have an instance of this class.
--
-- Every method has a default implementation in terms of 'Bindable',
-- so a 'Bindable' lore can typically use an empty instance.
class ASTLore lore => BinderOps lore where
  -- | Build the expression decoration for a statement that binds the
  -- given pattern to the given expression.
  mkExpDecB ::
    (MonadBinder m, Lore m ~ lore) =>
    Pattern lore ->
    Exp lore ->
    m (ExpDec lore)

  -- | Build a body from a sequence of statements and a result.
  mkBodyB ::
    (MonadBinder m, Lore m ~ lore) =>
    Stms lore ->
    Result ->
    m (Body lore)

  -- | Build a statement binding the given names to an expression.
  mkLetNamesB ::
    (MonadBinder m, Lore m ~ lore) =>
    [VName] ->
    Exp lore ->
    m (Stm lore)

  default mkExpDecB ::
    (MonadBinder m, Bindable lore) =>
    Pattern lore ->
    Exp lore ->
    m (ExpDec lore)
  -- The pure 'Bindable' constructor needs no monadic context.
  mkExpDecB pat = pure . mkExpDec pat

  default mkBodyB ::
    (MonadBinder m, Bindable lore) =>
    Stms lore ->
    Result ->
    m (Body lore)
  mkBodyB stms = pure . mkBody stms

  default mkLetNamesB ::
    (MonadBinder m, Lore m ~ lore, Bindable lore) =>
    [VName] ->
    Exp lore ->
    m (Stm lore)
  mkLetNamesB names e = mkLetNames names e

-- | A monad transformer that accumulates statements together with the
-- scope they are built in.  It is an instance of 'MonadBinder' when the
-- underlying monad provides a name source.  Nearly all statement
-- construction should go through this type, possibly as one layer of
-- a larger monad stack; hand-writing a 'MonadBinder' instance from
-- scratch is almost always a mistake.
--
-- The state holds the statements added so far paired with the current
-- scope.
newtype BinderT lore m a = BinderT (StateT (Stms lore, Scope lore) m a)
  deriving (Functor, Applicative, Monad)

-- Lifting goes straight through the internal 'StateT' layer.
instance MonadTrans (BinderT lore) where
  lift inner = BinderT (lift inner)

-- | The most commonly used binder monad: 'BinderT' running directly
-- over a 'State' monad whose state is a 'VNameSource'.  'runBinder'
-- threads that state through 'modifyNameSource' of the calling monad.
type Binder lore = BinderT lore (State VNameSource)

-- The name source lives in the underlying monad; both operations are
-- simply lifted through the internal state layer.
instance MonadFreshNames m => MonadFreshNames (BinderT lore m) where
  getNameSource = BinderT (lift getNameSource)
  putNameSource src = BinderT (lift (putNameSource src))

-- The scope is the second component of the binder state.
instance
  (ASTLore lore, Monad m) =>
  HasScope lore (BinderT lore m)
  where
  -- A missing variable is an internal error, reported with 'error'.
  lookupType name =
    askScope >>= maybe unknown (return . typeOf) . M.lookup name
    where
      unknown = error $ "BinderT.lookupType: unknown variable " ++ pretty name

  askScope = BinderT $ snd <$> get

-- Temporarily extends the scope held in the binder state.
instance
  (ASTLore lore, Monad m) =>
  LocalScope lore (BinderT lore m)
  where
  -- Run the action with @types@ added to the scope (taking precedence
  -- over existing bindings of the same names).  Afterwards the local
  -- bindings are removed again, but any outer binding they shadowed is
  -- reinstated.  Previously a shadowed outer binding was deleted along
  -- with the local one, so a later 'lookupType' on that name failed.
  -- Names bound by statements added during the action (via 'addStms')
  -- stay in scope, as before.
  localScope types (BinderT m) = BinderT $ do
    outer <- gets snd
    -- Outer bindings hidden by the local ones, to be put back afterwards.
    let shadowed = outer `M.intersection` types
    modify $ second (M.union types)
    x <- m
    modify $ second $ \scope -> (scope `M.difference` types) `M.union` shadowed
    return x

-- The canonical 'MonadBinder' instance: statements are accumulated in
-- the binder state and their bound names are added to the scope.
instance
  (ASTLore lore, MonadFreshNames m, BinderOps lore) =>
  MonadBinder (BinderT lore m)
  where
  type Lore (BinderT lore m) = lore

  -- Construction of decorations, bodies and statements is delegated
  -- to the lore's 'BinderOps' instance.
  mkExpDecM pat e = mkExpDecB pat e
  mkBodyM stms res = mkBodyB stms res
  mkLetNamesM names e = mkLetNamesB names e

  -- Append the statements after those already accumulated and bring
  -- the names they bind into scope.  The existing scope wins on
  -- conflicts (left-biased union).
  addStms new_stms = BinderT $ modify extend
    where
      extend (acc, scope) = (acc <> new_stms, scope `M.union` scopeOf new_stms)

  -- Run the action starting from an empty statement list, then put the
  -- original statements and scope back, returning the statements the
  -- action produced alongside its result.
  collectStms m = do
    (saved_stms, saved_scope) <- BinderT get
    BinderT $ put (mempty, saved_scope)
    result <- m
    (collected, _) <- BinderT get
    BinderT $ put (saved_stms, saved_scope)
    return (result, collected)

-- | Execute a binder action starting from the given scope and no
-- statements.  Returns the action's result together with every
-- statement added ('addStm') while it ran; the final scope is dropped.
runBinderT ::
  MonadFreshNames m =>
  BinderT lore m a ->
  Scope lore ->
  m (a, Stms lore)
runBinderT (BinderT action) initial_scope = do
  (result, (added, _final_scope)) <- runStateT action (mempty, initial_scope)
  return (result, added)

-- | Variant of 'runBinderT' that discards the (unit) result and yields
-- only the added statements.
runBinderT_ ::
  MonadFreshNames m =>
  BinderT lore m () ->
  Scope lore ->
  m (Stms lore)
runBinderT_ m scope = snd <$> runBinderT m scope

-- | Variant of 'runBinderT' that takes its initial scope from the
-- enclosing monad (via 'askScope'), converted with 'castScope'.
runBinderT' ::
  (MonadFreshNames m, HasScope somelore m, SameScope somelore lore) =>
  BinderT lore m a ->
  m (a, Stms lore)
runBinderT' m = askScope >>= runBinderT m . castScope

-- | Variant of 'runBinderT'' that discards the action's result and
-- yields only the added statements.
runBinderT'_ ::
  (MonadFreshNames m, HasScope somelore m, SameScope somelore lore) =>
  BinderT lore m a ->
  m (Stms lore)
runBinderT'_ m = snd <$> runBinderT' m

-- | Execute a 'Binder' action inside any monad that supplies both a
-- scope and a name source.  Returns the action's result and the
-- statements it added ('addStm').  The initial scope comes from
-- 'askScope', and the name source is threaded via 'modifyNameSource'.
runBinder ::
  ( MonadFreshNames m,
    HasScope somelore m,
    SameScope somelore lore
  ) =>
  Binder lore a ->
  m (a, Stms lore)
runBinder m =
  askScope >>= \outer ->
    modifyNameSource . runState . runBinderT m $ castScope outer

-- | Variant of 'runBinder' that discards the action's result and
-- yields only the statements it added.
runBinder_ ::
  ( MonadFreshNames m,
    HasScope somelore m,
    SameScope somelore lore
  ) =>
  Binder lore a ->
  m (Stms lore)
runBinder_ m = snd <$> runBinder m

-- | Execute a 'Binder' action whose result is a t'Body', and prepend
-- the statements it added to that t'Body' (using 'insertStms').
runBodyBinder ::
  ( Bindable lore,
    MonadFreshNames m,
    HasScope somelore m,
    SameScope somelore lore
  ) =>
  Binder lore (Body lore) ->
  m (Body lore)
runBodyBinder m = uncurry (flip insertStms) <$> runBinder m

-- Utility instance definitions for MTL classes.  These require
-- UndecidableInstances, but save on typing elsewhere.

-- | Transform the underlying-monad computation of a 'BinderT'.  The
-- action is run from the current binder state, the given function is
-- applied to the resulting computation in the underlying monad, and the
-- state it produces becomes the new binder state.  This is used to lift
-- MTL operations that wrap an action, such as 'local'.
mapInner ::
  Monad m =>
  ( m (a, (Stms lore, Scope lore)) ->
    m (b, (Stms lore, Scope lore))
  ) ->
  BinderT lore m a ->
  BinderT lore m b
mapInner f (BinderT inner) = BinderT $ do
  before <- get
  (result, after) <- lift . f $ runStateT inner before
  put after
  return result

-- Reader operations act on the underlying monad's environment.
instance MonadReader r m => MonadReader r (BinderT lore m) where
  ask = lift ask
  local modifier action = mapInner (local modifier) action

-- State operations act on the underlying monad's state, not on the
-- binder's own statement/scope state.
instance MonadState s m => MonadState s (BinderT lore m) where
  get = lift get
  put new_state = lift (put new_state)

-- | Writer operations are delegated to the underlying monad @m@.  For
-- 'pass' and 'listen', the binder state that 'mapInner' pairs with each
-- result is moved in and out of the inner computation's result, so that
-- @m@'s own 'pass' and 'listen' can be reused.
instance MonadWriter w m => MonadWriter w (BinderT lore m) where
  tell w = BinderT (lift (tell w))

  -- Move the output-transforming function @f@ outward so the inner
  -- 'pass' applies it, while the binder state @st@ stays with @x@.
  pass =
    mapInner $ \inner ->
      pass $
        inner >>= \((x, f), st) -> return ((x, st), f)

  -- Capture the inner output @out@ and place it next to the result,
  -- leaving the binder state @st@ where 'mapInner' expects it.
  listen =
    mapInner $ \inner ->
      listen inner >>= \((x, st), out) -> return ((x, out), st)

-- | Errors are thrown and caught in the underlying monad @m@.
instance MonadError e m => MonadError e (BinderT lore m) where
  throwError e = lift (throwError e)

  -- Catch at the wrapped 'StateT' layer.  The handler's 'BinderT'
  -- result is unwrapped so both branches have the same type.
  -- NOTE(review): with mtl's 'StateT' instance, the handler presumably
  -- resumes from the binder state as it was before @m@ ran, which would
  -- discard statements @m@ added before failing — confirm.
  catchError (BinderT m) handler =
    BinderT (catchError m (\e -> case handler e of BinderT m' -> m'))