{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Rewrite.Types where
import Control.Concurrent.Supply (Supply, freshId)
import Control.DeepSeq (NFData)
import Control.Lens (use, (.=))
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail(fail))
#endif
import Control.Monad.Fix (MonadFix (..), fix)
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.State (MonadState (..))
import Control.Monad.State.Strict (State)
import Control.Monad.Writer (MonadWriter (..))
import Data.Binary (Binary)
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.Monoid (Any)
import qualified Data.Set as Set
import GHC.Generics
import Clash.Core.Evaluator.Types (PrimHeap, PrimStep, PrimUnwind)
import Clash.Core.Term (Term, Context)
import Clash.Core.Type (Type)
import Clash.Core.TyCon (TyConName, TyConMap)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (InScopeSet, VarSet, VarEnv)
import Clash.Driver.Types (BindingMap, DebugLevel)
import Clash.Netlist.Types (FilteredHWType, HWMap)
import Clash.Util
import Clash.Annotations.BitRepresentation.Internal (CustomReprs)
-- | A single recorded rewrite step: the term before and after a
-- transformation, together with the context and some naming information.
-- NOTE(review): the exact producer/consumer of this record is not visible
-- here; the field descriptions below are inferred from types and names.
data RewriteStep = RewriteStep
  { t_ctx :: Context
    -- ^ The 'Context' associated with this step (presumably where in the
    -- enclosing term the rewrite happened -- TODO confirm)
  , t_name, t_bndrS :: String
    -- ^ Two textual labels: presumably the transformation's name and a
    -- rendering of the binder being rewritten (TODO confirm)
  , t_before, t_after :: Term
    -- ^ The term before and after the step, respectively
  } deriving (Show, Generic, NFData, Hashable, Binary)
-- | State threaded through 'RewriteMonad'. The type parameter @extra@ lets
-- a caller carry additional state of its own choosing in '_extra'.
--
-- Lenses for every field (e.g. 'uniqSupply', used by 'getUniqueM') are
-- generated by the Template Haskell splice that follows.
data RewriteState extra = RewriteState
  { _transformCounter :: {-# UNPACK #-} !Int
    -- ^ Strict, unpacked counter; presumably the number of transformations
    -- applied so far (TODO confirm)
  , _bindings :: !BindingMap
    -- ^ The 'BindingMap' of bindings being worked on (strict)
  , _uniqSupply :: !Supply
    -- ^ Supply of fresh uniques; split with 'freshId' by 'getUniqueM'
  , _curFun :: (Id, SrcSpan)
    -- ^ Presumably the function currently being rewritten and its source
    -- span (TODO confirm); this field is lazy
  , _nameCounter :: {-# UNPACK #-} !Int
    -- ^ Strict, unpacked counter; presumably used for generating fresh
    -- names (TODO confirm)
  , _globalHeap :: PrimHeap
    -- ^ A 'PrimHeap' kept in the rewrite state (lazy field)
  , _workFreeBinders :: VarEnv Bool
    -- ^ Per-variable 'Bool' map; presumably caches whether a binder is
    -- work-free (TODO confirm)
  , _extra :: !extra
    -- ^ Caller-supplied additional state (strict)
  }

$(makeLenses ''RewriteState)
-- | Read-only environment of 'RewriteMonad', accessed through its
-- 'MonadReader' instance. Lenses for every field are generated by the
-- Template Haskell splice that follows.
data RewriteEnv = RewriteEnv
  { _dbgLevel :: DebugLevel
    -- ^ Debugging verbosity ('DebugLevel')
  , _dbgTransformations :: Set.Set String
    -- ^ Set of transformation names; presumably restricts which
    -- transformations are debug-traced (TODO confirm)
  , _dbgTransformationsFrom, _dbgTransformationsLimit :: Int
    -- ^ Two 'Int' bounds; presumably the first traced transformation and
    -- the maximum number traced (TODO confirm)
  , _aggressiveXOpt :: Bool
    -- ^ Flag; presumably enables aggressive optimisation of undefined
    -- ("X") values (TODO confirm)
  , _typeTranslator
      :: CustomReprs
      -> TyConMap
      -> Type
      -> State HWMap (Maybe (Either String FilteredHWType))
    -- ^ Translates a 'Type' to a 'FilteredHWType' in a 'State' 'HWMap'
    -- computation; the meaning of 'Nothing' vs. 'Left' is not visible here
  , _tcCache :: TyConMap
    -- ^ The 'TyConMap' of known type constructors
  , _tupleTcCache :: IntMap TyConName
    -- ^ 'IntMap' to tuple type-constructor names; presumably keyed by
    -- tuple arity (TODO confirm)
  , _evaluator :: (PrimStep, PrimUnwind)
    -- ^ The evaluator's primitive-handling functions
  , _topEntities :: VarSet
    -- ^ Set of variables; presumably the design's top entities
  , _customReprs :: CustomReprs
    -- ^ Custom bit representations ('CustomReprs')
  }

$(makeLenses ''RewriteEnv)
-- | The monad rewrites run in: a hand-rolled combination of a reader over
-- 'RewriteEnv', state over @'RewriteState' extra@, and a writer over 'Any'.
-- The 'Any' output is threaded as an accumulator argument (the third
-- parameter) and returned alongside the result and final state.
newtype RewriteMonad extra a = R
  { unR
      :: RewriteEnv
      -> RewriteState extra
      -> Any
      -> (a, RewriteState extra, Any)
  }
-- | Run a 'RewriteMonad' computation in the given environment and initial
-- state, starting the 'Any' accumulator at 'mempty'. Yields the result,
-- the final state, and the accumulated 'Any'.
runR
  :: RewriteMonad extra a
  -> RewriteEnv
  -> RewriteState extra
  -> (a, RewriteState extra, Any)
runR (R step) env st = step env st mempty
-- | 'fail' is not recoverable in 'RewriteMonad': it aborts via 'error',
-- prefixing the message with @RewriteMonad.fail: @.
instance MonadFail (RewriteMonad extra) where
  fail = error . ("RewriteMonad.fail: " ++)
-- | Maps over the result; the state and 'Any' accumulator pass through
-- unchanged. The result triple is matched strictly via 'case'.
instance Functor (RewriteMonad extra) where
  fmap f (R run) = R $ \env st acc ->
    case run env st acc of
      (x, st', acc') -> (f x, st', acc')
  {-# INLINE fmap #-}
-- | 'pure' leaves state and accumulator untouched. '<*>' runs the function
-- computation first and then the argument computation, threading state and
-- the 'Any' accumulator from left to right.
instance Applicative (RewriteMonad extra) where
  pure x = R $ \_ st acc -> (x, st, acc)
  {-# INLINE pure #-}

  R runF <*> R runX = R $ \env st0 acc0 ->
    case runF env st0 acc0 of
      (f, st1, acc1) ->
        case runX env st1 acc1 of
          (x, st2, acc2) -> (f x, st2, acc2)
  {-# INLINE (<*>) #-}
-- | Sequencing for 'RewriteMonad': the continuation receives the state and
-- 'Any' accumulator produced by the first computation; the environment is
-- shared by both.
instance Monad (RewriteMonad extra) where
  -- Canonical definition: 'return' must coincide with 'pure'. Duplicating
  -- the body of 'pure' here is flagged by -Wnoncanonical-monad-instances.
  return = pure
  {-# INLINE return #-}

  m >>= k = R $ \r s w ->
    case unR m r s w of
      (a, s', w') -> unR (k a) r s' w'
  {-# INLINE (>>=) #-}
-- | Direct access to the 'RewriteState'. Every operation leaves the 'Any'
-- accumulator unchanged; 'get' and 'put' are expressed through 'state'.
instance MonadState (RewriteState extra) (RewriteMonad extra) where
  state f = R $ \_ st acc ->
    case f st of
      (x, st') -> (x, st', acc)
  {-# INLINE state #-}

  get = state (\st -> (st, st))
  {-# INLINE get #-}

  put st = state (const ((), st))
  {-# INLINE put #-}
-- | Draws a fresh unique by splitting the '_uniqSupply' field with
-- 'freshId' and storing the remaining supply back into the state. The
-- unique is forced with 'seq' before it is returned.
instance MonadUnique (RewriteMonad extra) where
  getUniqueM = use uniqSupply >>= \supply ->
    let (uniq, rest) = freshId supply
    in  (uniqSupply .= rest) >> (uniq `seq` return uniq)
-- | Writer over the 'Any' monoid. Output is appended to the threaded
-- accumulator, which is forced with 'seq' after each append. 'listen' and
-- 'pass' run the inner computation via 'runR', i.e. from a fresh 'mempty'
-- accumulator. They can therefore observe or transform exactly the inner
-- output before it is appended to the outer accumulator.
instance MonadWriter Any (RewriteMonad extra) where
  writer (x, out) = R $ \_ st acc ->
    let acc' = acc `mappend` out
    in  acc' `seq` (x, st, acc')
  {-# INLINE writer #-}

  tell out = writer ((), out)
  {-# INLINE tell #-}

  listen m = R $ \env st acc ->
    case runR m env st of
      (x, st', inner) ->
        let acc' = acc `mappend` inner
        in  acc' `seq` ((x, inner), st', acc')
  {-# INLINE listen #-}

  pass m = R $ \env st acc ->
    case runR m env st of
      ((x, g), st', inner) ->
        let acc' = acc `mappend` g inner
        in  acc' `seq` (x, st', acc')
  {-# INLINE pass #-}
-- | Apply a function to the 'Any' output of a computation before that
-- output is added to the outer accumulator. Defined in terms of 'pass':
-- the computation's result is paired with @f@, which 'pass' then applies.
censor :: (Any -> Any) -> RewriteMonad extra a -> RewriteMonad extra a
censor f m = pass (fmap (\x -> (x, f)) m)
{-# INLINE censor #-}
-- | Read-only access to the 'RewriteEnv'. 'reader' and 'ask' leave state
-- and accumulator untouched; 'local' runs a computation under a modified
-- environment without affecting the caller's environment.
instance MonadReader RewriteEnv (RewriteMonad extra) where
  reader f = R $ \env st acc -> (f env, st, acc)
  {-# INLINE reader #-}

  ask = reader id
  {-# INLINE ask #-}

  local f m = R $ \env st acc -> unR m (f env) st acc
  {-# INLINE local #-}
-- | Value recursion: the computation's own result is fed back as the
-- argument to @f@. The result triple is deconstructed with a lazy @let@
-- pattern binding. A strict match here would force the knot and diverge.
instance MonadFix (RewriteMonad extra) where
  mfix f = R $ \env st acc ->
    fix $ \result ->
      let (x, _, _) = result
      in  unR (f x) env st acc
  {-# INLINE mfix #-}
-- | Information passed to a 'Transform' alongside the term being rewritten.
data TransformContext = TransformContext
  { tfInScope :: !InScopeSet
    -- ^ The 'InScopeSet' for the term (strict field); presumably the
    -- variables in scope at the term's position (TODO confirm)
  , tfContext :: Context
    -- ^ The 'Context' of the term (lazy field)
  }
-- | A term transformation in monad @m@: given a 'TransformContext' and a
-- 'Term', produce a (possibly rewritten) 'Term'.
type Transform m
  =  TransformContext
  -> Term
  -> m Term

-- | A 'Transform' that runs in 'RewriteMonad'.
type Rewrite extra = Transform (RewriteMonad extra)