{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Rewrite.Types where
import Control.Concurrent.Supply (Supply, freshId)
import Control.Lens (use, (.=), (<<%=))
import Control.Monad
import Control.Monad.Fix (MonadFix (..), fix)
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.State (MonadState (..))
import Control.Monad.Writer (MonadWriter (..))
import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet)
import Data.IntMap.Strict (IntMap)
import Data.Monoid (Any)
import Unbound.Generics.LocallyNameless (Fresh (..))
import Unbound.Generics.LocallyNameless.Name (Name (..))
import SrcLoc (SrcSpan)
import Clash.Core.Evaluator (GlobalHeap, PrimEvaluator)
import Clash.Core.Term (Term, TmName, TmOccName)
import Clash.Core.Type (Type)
import Clash.Core.TyCon (TyCon, TyConName, TyConOccName)
import Clash.Core.Var (Id, TyVar)
import Clash.Driver.Types (BindingMap, DebugLevel)
import Clash.Netlist.Types (HWType)
import Clash.Util
-- | Describes one piece of the surroundings of the term that a
-- 'Transform' is applied to; a 'Transform' receives a list of these
-- alongside the term itself.
--
-- NOTE(review): the per-constructor descriptions below are read off the
-- constructor names and payload types; confirm them against the traversal
-- code that builds these contexts.
data CoreContext
  = AppFun
    -- ^ Presumably: the function position of an application.
  | AppArg
    -- ^ Presumably: an argument position of an application.
  | TyAppC
    -- ^ Presumably: inside a type application.
  | LetBinding Id [Id]
    -- ^ Presumably: the right-hand side of a let-binder, carrying that
    -- binder and the let-bound variables around it.
  | LetBody [Id]
    -- ^ Presumably: the body of a let-expression, carrying the let-bound
    -- variables.
  | LamBody Id
    -- ^ Presumably: the body of a lambda, carrying the abstracted variable.
  | TyLamBody TyVar
    -- ^ Presumably: the body of a type-lambda, carrying the abstracted
    -- type variable.
  | CaseAlt [TyVar] [Id]
    -- ^ Presumably: the right-hand side of a case alternative, carrying the
    -- type and term variables bound by its pattern.
  | CaseScrut
    -- ^ Presumably: the scrutinee of a case expression.
  | CastBody
    -- ^ Presumably: the body of a cast.
  deriving (Eq, Show)
-- | State threaded through every 'RewriteMonad' action.
--
-- The @extra@ parameter lets a client of the rewrite system carry its own
-- state alongside the common fields; it is stored strictly. Lenses for
-- every field (without the leading underscore) are generated by the
-- 'makeLenses' splice directly below the declaration.
data RewriteState extra
  = RewriteState
  { _transformCounter :: {-# UNPACK #-} !Int
    -- ^ An unpacked counter; presumably the number of transformations
    -- applied so far.
  , _bindings         :: !BindingMap
    -- ^ The bindings being rewritten (see 'BindingMap').
  , _uniqSupply       :: !Supply
    -- ^ Supply of unique 'Int's; split by the 'MonadUnique' instance.
  , _curFun           :: (TmName,SrcSpan)
    -- ^ A term name with its source span; presumably the function that is
    -- currently being rewritten. Note this field is lazy.
  , _nameCounter      :: {-# UNPACK #-} !Int
    -- ^ Counter from which the 'Fresh' instance numbers fresh names.
  , _globalHeap       :: GlobalHeap
    -- ^ Heap for the core evaluator (see "Clash.Core.Evaluator"). Note
    -- this field is lazy.
  , _extra            :: !extra
    -- ^ Client-specific additional state.
  }

makeLenses ''RewriteState
-- | Read-only environment of 'RewriteMonad', available through its
-- 'MonadReader' instance. Lenses for every field (without the leading
-- underscore) are generated by the 'makeLenses' splice below.
data RewriteEnv
  = RewriteEnv
  { _dbgLevel       :: DebugLevel
    -- ^ Debug level of the rewrite system (see 'DebugLevel').
  , _typeTranslator :: HashMap TyConOccName TyCon
                    -> Bool
                    -> Type
                    -> Maybe (Either String HWType)
    -- ^ Translates a core 'Type' to an 'HWType', given a type-constructor
    -- table and a 'Bool' flag whose meaning is not visible here
    -- (TODO confirm). Presumably 'Nothing' / 'Left' signal a type that
    -- cannot be translated.
  , _tcCache        :: HashMap TyConOccName TyCon
    -- ^ Type constructors keyed by their occurrence name.
  , _tupleTcCache   :: IntMap TyConName
    -- ^ Tuple type-constructor names; presumably keyed by arity.
  , _evaluator      :: PrimEvaluator
    -- ^ Primitive evaluator (see "Clash.Core.Evaluator").
  , _allowZero      :: Bool
    -- ^ A flag; presumably whether zero-width types are allowed.
  , _topEntities    :: HashSet TmOccName
    -- ^ Occurrence names of the top entities.
  }

makeLenses ''RewriteEnv
-- | The rewrite monad: a hand-rolled combination of a reader over
-- 'RewriteEnv', a state over @'RewriteState' extra@, and a writer over
-- 'Any'.
--
-- 'runR' runs an action against an environment and an initial state, and
-- yields the result, the final state, and the accumulated 'Any' output.
-- 'Any' combines with logical OR; presumably it records whether any
-- rewrite fired.
newtype RewriteMonad extra a = R
  { runR :: RewriteEnv
         -> RewriteState extra
         -> (a, RewriteState extra, Any)
  }
-- | Maps over the result only; the final state and the writer output of
-- the action are passed through untouched.
instance Functor (RewriteMonad extra) where
  fmap g act = R $ \env st -> case runR act env st of
    (x, st', out) -> (g x, st', out)
-- | 'pure' yields its argument, leaves the state untouched, and emits an
-- empty ('mempty') writer output.
--
-- 'pure' is defined here directly and 'return' is deliberately not
-- overridden in the 'Monad' instance, so it defaults to @return = pure@.
-- This is the canonical arrangement; GHC's
-- @-Wnoncanonical-monad-instances@ warns about the reverse (@pure = return@).
instance Applicative (RewriteMonad extra) where
  pure a = R (\_ s -> (a, s, mempty))
  (<*>) = ap

-- | Sequencing hands the same environment to both actions. It threads the
-- state produced by the first action into the second, and combines the two
-- writer outputs with 'mappend'.
--
-- The combined output is forced with 'seq' so that long chains of binds do
-- not accumulate a thunk of pending 'mappend's.
instance Monad (RewriteMonad extra) where
  m >>= k = R (\r s -> case runR m r s of
    (a, s', w) -> case runR (k a) r s' of
      (b, s'', w') ->
        let w'' = mappend w w'
        in  w'' `seq` (b, s'', w''))
-- | Direct access to the threaded 'RewriteState'. None of these operations
-- emit writer output (they all produce 'mempty').
instance MonadState (RewriteState extra) (RewriteMonad extra) where
  get = R $ \_ st -> (st, st, mempty)

  put st' = R $ \_ _ -> ((), st', mempty)

  -- The 'case' (rather than a lazy 'let') matches the pair returned by
  -- the step function strictly once the result triple is demanded.
  state step = R $ \_ st -> case step st of
    (res, st') -> (res, st', mempty)
-- | Fresh names are numbered from the '_nameCounter' field of
-- 'RewriteState'. The counter is incremented, and its value from before
-- the increment becomes the index of the new free name ('Fn'). Bound
-- names ('Bn') are returned unchanged.
instance Fresh (RewriteMonad extra) where
  fresh nm = case nm of
    Fn str _ -> do
      counter <- nameCounter <<%= (+ 1)
      let idx = toInteger counter
      -- Force the index so the name does not carry a conversion thunk.
      idx `seq` pure (Fn str idx)
    Bn {} -> pure nm
-- | Splits the 'Supply' held in '_uniqSupply' with 'freshId'. The remaining
-- supply is stored back into the state, and the drawn unique is forced
-- before being returned.
instance MonadUnique (RewriteMonad extra) where
  getUniqueM = do
    (uniq, rest) <- freshId <$> use uniqSupply
    uniqSupply .= rest
    uniq `seq` return uniq
-- | Writer operations on the 'Any' output; the state is passed through
-- unchanged.
--
-- * 'writer' and 'tell' emit the given output directly.
-- * 'listen' also returns the output produced by the action.
-- * 'pass' applies the function returned by the action to that action's
--   output.
instance MonadWriter Any (RewriteMonad extra) where
  writer (x, out) = R $ \_ st -> (x, st, out)

  tell out = R $ \_ st -> ((), st, out)

  listen act = R $ \env st -> case runR act env st of
    (x, st', out) -> ((x, out), st', out)

  pass act = R $ \env st -> case runR act env st of
    ((x, g), st', out) -> (x, st', g out)
-- | Read access to the 'RewriteEnv'.
--
-- 'ask' and 'reader' leave the state untouched and emit no writer
-- output. 'local' runs an action under a modified environment and passes
-- that action's state and output through unchanged.
instance MonadReader RewriteEnv (RewriteMonad extra) where
  ask = reader id

  local adjust act = R $ \env st -> runR act (adjust env) st

  reader proj = R $ \env st -> (proj env, st, mempty)
-- | Value recursion: the result produced by running @f x@ is fed back as
-- @x@ itself, via a fixed point over the whole result triple.
instance MonadFix (RewriteMonad extra) where
  mfix f = R (\env st -> fix (knot env st))
    where
      -- The lazy pattern is essential: a strict match would demand the
      -- triple before the fixed point has produced it, and loop.
      knot env st ~(x, _, _) = runR (f x) env st
-- | A term transformation running in some monad @m@. It receives the
-- list of 'CoreContext's describing where the term sits, plus the term
-- itself, and produces a (possibly rewritten) term.
type Transform m
  =  [CoreContext]
  -> Term
  -> m Term

-- | A 'Transform' specialised to 'RewriteMonad'.
type Rewrite extra = Transform (RewriteMonad extra)