{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, TypeFamilies, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
-- | This module defines a convenience monad/typeclass for creating
-- normalised programs.
module Futhark.Binder
  ( -- * A concrete @MonadBinder@ monad.
    BinderT
  , runBinderT
  , BinderOps (..)
  , bindableMkExpAttrB
  , bindableMkBodyB
  , bindableMkLetNamesB
  , Binder
  , runBinder
  , runBinder_
  , joinBinder
  , runBodyBinder
  -- * Non-class interface
  , addBinderStms
  , collectBinderStms
  , certifyingBinder
  -- * The 'MonadBinder' typeclass
  , module Futhark.Binder.Class
  )
where

import Control.Arrow (second)
import Control.Monad.Writer
import Control.Monad.State.Strict
import Control.Monad.Reader
import Control.Monad.Error.Class
import qualified Data.Map.Strict as M
import qualified Control.Monad.Fail as Fail

import Futhark.Binder.Class
import Futhark.Representation.AST

-- | Lore-specific operations used by a 'MonadBinder' whose lore is
-- @lore@ to construct expression attributes, bodies, and statements.
-- For 'Bindable' lores, 'bindableMkExpAttrB', 'bindableMkBodyB', and
-- 'bindableMkLetNamesB' have the right types to serve as the method
-- implementations.
class Attributes lore => BinderOps lore where
  -- | Produce the 'ExpAttr' for an expression bound to a pattern.
  mkExpAttrB :: (MonadBinder m, Lore m ~ lore)
             => Pattern lore
             -> Exp lore
             -> m (ExpAttr lore)
  -- | Produce a 'Body' from statements and a 'Result'.
  mkBodyB :: (MonadBinder m, Lore m ~ lore)
          => Stms lore
          -> Result
          -> m (Body lore)
  -- | Produce a 'Stm' binding the given names to an expression.
  mkLetNamesB :: (MonadBinder m, Lore m ~ lore)
              => [VName]
              -> Exp lore
              -> m (Stm lore)

-- | 'mkExpAttrB' for 'Bindable' lores: the attribute is computed purely
-- with 'mkExpAttr' and returned without further effects.
bindableMkExpAttrB :: (MonadBinder m, Bindable (Lore m)) =>
                      Pattern (Lore m) -> Exp (Lore m) -> m (ExpAttr (Lore m))
bindableMkExpAttrB pat = return . mkExpAttr pat

-- | 'mkBodyB' for 'Bindable' lores: the body is built purely with
-- 'mkBody' and returned without further effects.
bindableMkBodyB :: (MonadBinder m, Bindable (Lore m)) =>
                   Stms (Lore m) -> Result -> m (Body (Lore m))
bindableMkBodyB stms = return . mkBody stms

-- | 'mkLetNamesB' for 'Bindable' lores: delegates directly to
-- 'mkLetNames'.
bindableMkLetNamesB :: (MonadBinder m, Bindable (Lore m)) =>
                       [VName] -> Exp (Lore m) -> m (Stm (Lore m))
bindableMkLetNamesB names e = mkLetNames names e

-- | A monad transformer for building statements.  Its state pairs the
-- statements accumulated so far with the scope currently in effect.
newtype BinderT lore m a =
  BinderT (StateT (Stms lore, Scope lore) m a)
  deriving (Applicative, Functor, Monad)

-- Lifting leaves the accumulated statements and scope untouched.
instance MonadTrans (BinderT lore) where
  lift m = BinderT (lift m)

-- 'fail' in 'BinderT' is not recoverable: it aborts via 'error', with
-- the message prefixed by "BinderT.fail: ".
instance Monad m => Fail.MonadFail (BinderT lore m) where
  fail msg = error ("BinderT.fail: " ++ msg)

-- | A 'BinderT' whose underlying monad is a 'State' monad over a
-- 'VNameSource'.
type Binder lore = BinderT lore (State VNameSource)

-- Fresh names are read from, and written back to, the underlying
-- monad's name source.
instance MonadFreshNames m => MonadFreshNames (BinderT lore m) where
  getNameSource = BinderT $ lift getNameSource
  putNameSource src = BinderT $ lift $ putNameSource src

-- Scope queries read the scope half of the binder state.
instance (Attributes lore, Monad m) =>
         HasScope lore (BinderT lore m) where
  -- An unbound name is reported through 'fail'.
  lookupType name =
    BinderT (gets (M.lookup name . snd)) >>= maybe unbound (return . typeOf)
    where
      unbound =
        fail $ "BinderT.lookupType: unknown variable " ++ pretty name
  askScope = BinderT $ gets snd

-- | Extend the scope for the duration of a computation.  While the
-- inner computation runs, the names in @types@ take precedence over any
-- existing bindings of the same names.  Afterwards the names in @types@
-- are removed again, and any outer binding they shadowed is restored.
-- Other scope entries added by the inner computation (e.g. by
-- 'addBinderStms') are kept.
instance (Attributes lore, Monad m) =>
         LocalScope lore (BinderT lore m) where
  localScope types (BinderT m) = BinderT $ do
    -- Remember the outer entries for names that @types@ is about to hide;
    -- previously these were simply deleted on exit.
    shadowed <- gets $ (`M.intersection` types) . snd
    modify $ second (M.union types)
    x <- m
    -- Drop the local bindings, then reinstate the ones they shadowed.
    modify $ second $ M.union shadowed . (`M.difference` types)
    return x

-- 'BinderT' is a 'MonadBinder': lore-specific construction is delegated
-- to 'BinderOps', and statement bookkeeping to the non-class helpers.
instance (Attributes lore, MonadFreshNames m, BinderOps lore) =>
         MonadBinder (BinderT lore m) where
  type Lore (BinderT lore m) = lore

  mkLetNamesM names e = mkLetNamesB names e
  mkBodyM stms res = mkBodyB stms res
  mkExpAttrM pat e = mkExpAttrB pat e

  certifying cs m = certifyingBinder cs m
  collectStms m = collectBinderStms m
  addStms stms = addBinderStms stms

-- | Run a 'BinderT' computation starting from the given scope and no
-- statements.  Returns the computation's result together with all the
-- statements it added; the final scope is discarded.
runBinderT :: MonadFreshNames m =>
              BinderT lore m a
           -> Scope lore
           -> m (a, Stms lore)
runBinderT (BinderT m) scope =
  runStateT m (mempty, scope) >>= \(x, (stms, _)) -> return (x, stms)

-- | Run a 'Binder' inside any monad that provides fresh names and a
-- compatible scope.  The surrounding scope (cast to the binder's lore)
-- seeds the binder, and the surrounding name source is threaded through
-- the run and updated afterwards.
runBinder :: (MonadFreshNames m,
              HasScope somelore m, SameScope somelore lore) =>
              Binder lore a
           -> m (a, Stms lore)
runBinder m =
  askScope >>= modifyNameSource . runState . runBinderT m . castScope

-- | Like 'runBinder', but discard the computation's result and return
-- only the statements it added.
runBinder_ :: (MonadFreshNames m,
               HasScope somelore m, SameScope somelore lore) =>
              Binder lore a
           -> m (Stms lore)
runBinder_ m = snd <$> runBinder m

-- | As 'runBinder', but uses 'addStms' to add the returned statements
-- to the surrounding monad, and returns only the computation's result.
joinBinder :: MonadBinder m => Binder (Lore m) a -> m a
joinBinder m = do
  (x, stms) <- runBinder m
  addStms stms
  return x

-- | Run a 'Binder' that produces a body, then use 'insertStms' to put
-- the statements the binder added into that body.
runBodyBinder :: (Bindable lore, MonadFreshNames m,
                  HasScope somelore m, SameScope somelore lore) =>
                 Binder lore (Body lore) -> m (Body lore)
runBodyBinder m = attach <$> runBinder m
  -- The lazy pattern mirrors the laziness of 'uncurry'.
  where attach ~(body, stms) = insertStms stms body

-- | Append statements to those accumulated so far, and extend the scope
-- with 'scopeOf' those statements.  Because 'M.union' is left-biased,
-- an entry already in scope wins over a new one with the same name.
addBinderStms :: Monad m =>
                 Stms lore -> BinderT lore m ()
addBinderStms stms = BinderT $ modify extend
  where extend (acc, scope) = (acc <> stms, M.union scope (scopeOf stms))

-- | Run a computation and capture the statements it adds instead of
-- keeping them.  Afterwards both the statement list and the scope are
-- reset to what they were beforehand.  The captured statements are
-- therefore not added, and any scope changes the computation made are
-- dropped.
collectBinderStms :: Monad m =>
                     BinderT lore m a
                  -> BinderT lore m (a, Stms lore)
collectBinderStms m = do
  (outer_stms, outer_scope) <- BinderT get
  -- Start the inner computation with no statements but the same scope.
  BinderT $ put (mempty, outer_scope)
  result <- m
  (inner_stms, _) <- BinderT get
  BinderT $ put (outer_stms, outer_scope)
  return (result, inner_stms)

-- | Run a computation, attach the given certificates (via 'certify') to
-- every statement it added, and add the certified statements in their
-- place.
certifyingBinder :: (MonadFreshNames m, BinderOps lore) =>
                    Certificates -> BinderT lore m a
                 -> BinderT lore m a
certifyingBinder cs m = do
  (x, stms) <- collectBinderStms m
  addBinderStms $ fmap (certify cs) stms
  return x

-- Utility instance definitions for MTL classes.  These require
-- UndecidableInstances, but save on typing elsewhere.

-- | Apply a transformation of underlying-monad actions to a 'BinderT'
-- computation.  The current binder state is fed to the inner action,
-- and the state that the transformed action returns becomes the new
-- binder state.
mapInner :: Monad m =>
            (m (a, (Stms lore, Scope lore))
             -> m (b, (Stms lore, Scope lore)))
         -> BinderT lore m a -> BinderT lore m b
mapInner f (BinderT inner) = BinderT $ StateT $ \s ->
  f (runStateT inner s) >>= \(x, s') -> return (x, s')

-- The environment comes from the underlying monad; 'local' is applied
-- to the inner action via 'mapInner'.
instance MonadReader r m => MonadReader r (BinderT lore m) where
  ask = lift ask
  local f m = mapInner (local f) m

-- 'get' and 'put' access the underlying monad's state, not the binder's
-- own statements and scope.
instance MonadState s m => MonadState s (BinderT lore m) where
  get = lift get
  put st = lift $ put st

-- Writer operations act on the underlying monad.  'listen' and 'pass'
-- reshuffle tuples so that the binder state stays in the outer position
-- expected by 'mapInner'.
instance MonadWriter w m => MonadWriter w (BinderT lore m) where
  tell msg = lift $ tell msg
  listen = mapInner $ \act ->
    listen act >>= \((x, s), out) -> return ((x, out), s)
  pass = mapInner $ \act ->
    pass (act >>= \((x, f), s) -> return ((x, s), f))

-- Errors are raised in the underlying monad.  'catchError' delegates to
-- the 'StateT' instance, unwrapping the handler's 'BinderT' result.
instance MonadError e m => MonadError e (BinderT lore m) where
  throwError err = BinderT $ lift $ throwError err
  catchError (BinderT act) handler =
    BinderT $ catchError act $ \err -> case handler err of BinderT h -> h