{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.Cast
( argCastSpec
, caseCast
, elimCastCast
, letCast
, splitCastWork
) where
import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.Monoid as Monoid (Any(..))
import GHC.Stack (HasCallStack)
import Clash.Core.Name (nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term (LetBinding, Term(..), collectArgs, stripTicks)
import Clash.Core.TermInfo (isCast)
import Clash.Core.Type (normalizeType)
import Clash.Core.Var (isGlobalId, varName)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Debug (trace)
import Clash.Normalize.Transformations.Specialize (specialize)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types
(TransformContext(..), bindings, curFun, tcCache, workFreeBinders)
import Clash.Rewrite.Util (changed, mkDerivedName, mkTmBinderFor)
import Clash.Rewrite.WorkFree (isWorkFree)
import Clash.Util (ClashException(..), curLoc)
-- | Specialize a global function on an argument that is a cast.
--
-- Fires on @App f (Cast e' _ _)@ (ticks around the cast are ignored) when:
--
--   * the cast's inner term @e'@ is not itself a cast, and
--   * the head of @f@ (as returned by 'collectArgs') is a global 'Id'.
--
-- The specialization itself is delegated to 'specialize'. If 'isWorkFree'
-- reports @e'@ as not work-free, the specialization still happens, but a
-- warning is traced first because the generated HDL may duplicate work.
-- Every other term is returned unchanged.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App f (stripTicks -> Cast e' _ _))
  -- Casts-of-casts are skipped here; presumably 'elimCastCast' removes them
  -- while normalizing the current function.
  | not (isCast e')
  -- Only arguments of global binders are specialized.
  , (Var g, _) <- collectArgs f
  , isGlobalId g
  = do
      bndrs <- Lens.use bindings
      workFree <- isWorkFree workFreeBinders bndrs e'
      if workFree
        then specializeHere
        else trace warningMsg specializeHere
  where
    specializeHere = specialize ctx e
    warningMsg = unwords
      [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
      , "cast. Generated HDL implementation might contain duplicate work."
      , "Please report this as a bug.", "\n\nExpression where this occured:"
      , "\n\n" ++ showPpr e
      ]
argCastSpec _ e = return e
{-# SCC argCastSpec #-}
-- | Push a cast over a @case@ expression into each of its alternatives.
--
-- @
-- cast ty1 ty2 (case subj of { p_i -> e_i })
--   ==>
-- case subj of { p_i -> cast ty1 ty2 e_i }
-- @
--
-- Ticks between the cast and the @case@ are dropped.
--
-- The rebuilt @case@ is annotated with @ty2@, the cast's target type, not
-- with its original annotation. Every alternative is now wrapped in a cast
-- to @ty2@, so keeping the old annotation would describe the alternatives
-- with their pre-cast type.
--
-- Every other term is returned unchanged.
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj _ alts) ty1 ty2) =
  changed (Case subj ty2 [ (p, Cast alt ty1 ty2) | (p, alt) <- alts ])
caseCast _ e = return e
{-# SCC caseCast #-}
-- | Eliminate two directly nested casts that cancel each other out.
--
-- @
-- cast tyB' tyC (cast tyA tyB x)  ==>  x   -- requires tyB == tyB', tyA == tyC
-- @
--
-- All four types are compared after 'normalizeType' with the 'TyConMap'
-- from 'tcCache'. Ticks between the two casts are ignored.
--
-- If the nested casts do not line up, a 'ClashException' is thrown. It is
-- located at the 'SrcSpan' stored in 'curFun' and names that function.
--
-- Every other term is returned unchanged.
elimCastCast :: HasCallStack => NormRewrite
elimCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do
  tcm <- Lens.view tcCache
  let norm = normalizeType tcm
      -- The inner cast's target must be the outer cast's source ...
      middleAgrees = norm tyB == norm tyB'
      -- ... and the outer cast must lead back to where the inner one started.
      roundTrips = norm tyA == norm tyC
  if middleAgrees && roundTrips
    then changed e
    else do
      (nm, sp) <- Lens.use curFun
      throw $ ClashException sp
        ( $(curLoc) ++ showPpr nm
            ++ ": Found 2 nested casts whose types don't line up:\n"
            ++ showPpr c )
        Nothing
elimCastCast _ e = return e
{-# SCC elimCastCast #-}
-- | Push a cast over a @let@ expression into its body.
--
-- @
-- cast ty1 ty2 (let binds in body)  ==>  let binds in cast ty1 ty2 body
-- @
--
-- Ticks between the cast and the @let@ are dropped. Every other term is
-- returned unchanged.
letCast :: HasCallStack => NormRewrite
letCast _ term = case term of
  Cast (stripTicks -> Let binds body) ty1 ty2 ->
    changed (Let binds (Cast body ty1 ty2))
  _ -> return term
{-# SCC letCast #-}
-- | Make casts inside a @letrec@ work-free by moving the cast's operand into
-- a fresh binding of its own.
--
-- @
-- letrec x = cast (f a b) in e
--   ==>
-- letrec x' = f a b
--        x  = cast x'
--   in e
-- @
--
-- A binding is left as it is when its tick-stripped right-hand side is:
--
--   * not a cast,
--   * a cast of a variable, or
--   * a cast of a cast.
--
-- The fresh binder's name comes from 'mkDerivedName', applied to the
-- occurrence name of the original binder.
--
-- The result is marked 'changed' only when at least one binding was split;
-- otherwise the original term is returned as-is.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) unchanged@(Letrec vs e') = do
  (groups, Monoid.Any anySplit) <- listen (traverse (splitBinding is0) vs)
  if anySplit
    then changed (Letrec (concat groups) e')
    else return unchanged
  where
    splitBinding
      :: InScopeSet
      -> LetBinding
      -> NormalizeSession [LetBinding]
    splitBinding isN bnd@(nm, rhs) = case stripTicks rhs of
      Cast inner ty1 ty2
        | carriesWork inner -> do
            tcm <- Lens.view tcCache
            let derived = mkDerivedName ctx (nameOcc (varName nm))
            nm' <- mkTmBinderFor isN tcm derived inner
            changed [ (nm', inner)
                    , (nm, Cast (Var nm') ty1 ty2)
                    ]
      _ -> return [bnd]

    -- A cast of a variable is already work-free. A cast of a cast is left
    -- alone (presumably for 'elimCastCast'). Any other operand gets split.
    carriesWork :: Term -> Bool
    carriesWork operand = case operand of
      Var{}  -> False
      Cast{} -> False
      _      -> True
splitCastWork _ e = return e
{-# SCC splitCastWork #-}