{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.XOptimize
( xOptimize
) where
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import qualified Data.List.Extra as List
import qualified Data.Text.Extra as Text (showt)
import GHC.Stack (HasCallStack)
import Clash.XException (errorX)
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.DataCon (DataCon)
import Clash.Core.HasType
import Clash.Core.Term
( Alt, IsMultiPrim(..), LetBinding, Pat(..), PrimInfo(..), Term(..)
, WorkInfo(..), collectArgs, PrimUnfolding(..))
import Clash.Core.Type (TyVar, Type)
import Clash.Core.Util (mkInternalVar)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Netlist.BlackBox.Types (Element(Err))
import Clash.Netlist.Types (BlackBox(..))
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Primitives.Types (Primitive(..))
import Clash.Rewrite.Types
(TransformContext(..), aggressiveXOpt, tcCache, primitives)
import Clash.Rewrite.Util (changed)
import Clash.Util (MonadUnique, curLoc)
-- | Drop the alternatives of a case expression whose right-hand side is an
-- error primitive (as decided by 'isPrimError').
--
-- The rewrite only fires when the 'aggressiveXOpt' flag in the rewrite
-- environment is set. The alternatives are split into undefined and defined
-- ones, after which:
--
-- * no undefined alternatives: the term is returned untouched;
-- * no defined alternatives: the whole case becomes a primitive named after
--   'errorX', at the result type of the case;
-- * exactly one defined alternative: the case is replaced through
--   'xOptimizeSingle';
-- * several defined alternatives: the case is rebuilt through
--   'xOptimizeMany'.
--
-- Any term that is not a case expression is returned untouched.
xOptimize :: HasCallStack => NormRewrite
xOptimize (TransformContext inScope _) term@(Case scrut resTy alts) = do
  enabled <- Lens.view aggressiveXOpt
  if not enabled
    then return term
    else do
      -- First component: alternatives that are error primitives.
      -- Second component: everything else.
      (undefinedAlts, definedAlts) <- List.partitionM (isPrimError . snd) alts
      case (undefinedAlts, definedAlts) of
        ([], _) -> return term
        (_, []) -> changed (Prim errorPrim)
        (_, [only]) -> xOptimizeSingle inScope scrut only
        (_, _) -> xOptimizeMany inScope scrut resTy definedAlts
 where
  -- Primitive standing in for a case without any defined alternative.
  errorPrim =
    PrimInfo (Text.showt 'errorX) resTy WorkConstant SingleResult NoUnfolding
xOptimize _ term = return term
{-# SCC xOptimize #-}
-- | Build an expression that stands in for a single alternative without the
-- enclosing case expression.
--
-- For a constructor pattern the subject is bound to a fresh @subj@ variable
-- and every pattern variable is bound, via 'mkFieldSelector', to a projection
-- out of that variable; the alternative's body becomes the body of the
-- resulting 'Letrec'. For any other pattern the alternative's body is
-- returned as is. Both cases are reported through 'changed'.
xOptimizeSingle :: InScopeSet -> Term -> Alt -> NormalizeSession Term
xOptimizeSingle inScope scrut alt = case alt of
  (DataPat dc tvs vars, body) -> do
    tcm <- Lens.view tcCache
    scrutId <- mkInternalVar inScope "subj" (inferCoreTypeOf tcm scrut)
    let fieldTys = map coreTypeOf vars
    -- One selector binding per pattern variable, numbered from zero.
    selectors <- Monad.forM (zip vars [0 ..]) $ \(var, position) ->
      mkFieldSelector inScope scrutId dc tvs fieldTys var position
    changed (Letrec ((scrutId, scrut) : selectors) body)
  (_, body) -> changed body
-- | Rebuild a case expression from the alternatives that are defined.
--
-- When one of the alternatives already is a default pattern, the case is
-- rebuilt from exactly the given alternatives. Otherwise the first
-- alternative is converted with 'xOptimizeSingle' and appended, as the
-- default alternative, after the remaining ones.
--
-- Calling this with an empty list of alternatives is a bug and raises an
-- error.
xOptimizeMany
  :: HasCallStack
  => InScopeSet
  -> Term
  -> Type
  -> [Alt]
  -> NormalizeSession Term
xOptimizeMany inScope scrut resTy definedAlts@(firstAlt : restAlts)
  | hasDefault = changed (Case scrut resTy definedAlts)
  | otherwise = do
      fallback <- xOptimizeSingle inScope scrut firstAlt
      changed (Case scrut resTy (restAlts ++ [(DefaultPat, fallback)]))
 where
  -- True when any of the given alternatives matches on the default pattern.
  hasDefault :: Bool
  hasDefault = any (\(pat, _) -> pat == DefaultPat) definedAlts
xOptimizeMany _ _ _ [] =
  error $ $(curLoc) ++ "Report as bug: xOptimizeMany error: No defined alternatives"
-- | Create a let-binding that projects one field out of a constructor
-- application.
--
-- The binding has the shape
--
-- > nm = case subj of
-- >        dc tvs field_0 .. field_n -> field_index
--
-- where every @field_i@ is a fresh internal variable of the corresponding
-- type in the list of field types.
--
-- An index outside the list of field types is reported as an explicit
-- internal error instead of surfacing as an anonymous list-indexing
-- exception.
mkFieldSelector
  :: MonadUnique m
  => InScopeSet
  -> Id
  -- ^ Variable holding the subject of the case expression
  -> DataCon
  -- ^ Constructor to match on
  -> [TyVar]
  -- ^ Type variables bound by the constructor pattern
  -> [Type]
  -- ^ Types of the constructor fields, in order
  -> Id
  -- ^ Variable the selected field is bound to
  -> Int
  -- ^ Zero-based position of the field to select
  -> m LetBinding
mkFieldSelector is0 subj dc tvs fieldTys nm index = do
  fields <- traverse (mkInternalVar is0 "field") fieldTys
  case pick (zip fields fieldTys) of
    Just (field, fieldTy) ->
      return (nm, Case (Var subj) fieldTy [(DataPat dc tvs fields, Var field)])
    Nothing ->
      error $ $(curLoc) ++ "Report as bug: mkFieldSelector error: field index "
        ++ show index ++ " out of range for "
        ++ show (length fieldTys) ++ " fields"
 where
  -- Total replacement for '!!': negative and too-large positions both yield
  -- 'Nothing'. The negative check is needed because 'drop' treats a negative
  -- count as zero.
  pick candidates
    | index < 0 = Nothing
    | otherwise = case drop index candidates of
        (hit : _) -> Just hit
        [] -> Nothing
-- | Decide whether a term is an application of a primitive whose black box
-- template consists of a single 'Err' element.
--
-- The head of the application is looked up by name in the primitive map of
-- the rewrite environment. The answer is 'False' when the head is not a
-- primitive, when the primitive is not in the map, or when 'extractPrim'
-- yields nothing for the entry that was found.
isPrimError :: Term -> NormalizeSession Bool
isPrimError term = case collectArgs term of
  (Prim pInfo, _) -> do
    guarded <- Lens.view (primitives . Lens.at (primName pInfo))
    return (maybe False rendersOnlyError (guarded >>= extractPrim))
  _ -> return False
 where
  -- Matches only black boxes whose template is exactly one 'Err' element.
  rendersOnlyError :: Primitive a BlackBox c d -> Bool
  rendersOnlyError prim = case prim of
    BlackBox{template = BBTemplate [Err _]} -> True
    _ -> False