{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.,
                    2021-2022, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  The X-optimization transformation.
-}

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Normalize.Transformations.XOptimize
  ( xOptimize
  ) where

import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import qualified Data.List.Extra as List
import qualified Data.Text.Extra as Text (showt)
import GHC.Stack (HasCallStack)

import Clash.XException (errorX)

import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.DataCon (DataCon)
import Clash.Core.HasType
import Clash.Core.Term
  ( Alt, IsMultiPrim(..), LetBinding, Pat(..), PrimInfo(..), Term(..)
  , WorkInfo(..), collectArgs, PrimUnfolding(..))
import Clash.Core.Type (TyVar, Type)
import Clash.Core.Util (mkInternalVar)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Netlist.BlackBox.Types (Element(Err))
import Clash.Netlist.Types (BlackBox(..))
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Primitives.Types (Primitive(..))
import Clash.Rewrite.Types
  (TransformContext(..), aggressiveXOpt, tcCache, primitives)
import Clash.Rewrite.Util (changed)
import Clash.Util (MonadUnique, curLoc)

-- | Remove all undefined alternatives from case expressions, replacing them
-- with the value of another defined alternative. If there is one defined
-- alternative, the entire expression is replaced with that alternative. If
-- there are no defined alternatives, the entire expression is replaced with
-- a call to 'errorX'.
--
-- The transformation only fires when 'aggressiveXOpt' is set in the rewrite
-- environment; otherwise, and for every term that is not a 'Case', the term
-- is returned unchanged.
--
-- e.g. It converts
--
-- @
-- case x of
--   D1 a -> f a
--   D2   -> undefined
--   D3   -> undefined
-- @
--
-- to
--
-- @
-- let subj = x
--     a    = case subj of
--              D1 field0 -> field0
--  in f a
-- @
--
-- where fieldN is an internal variable referring to the nth argument of a
-- data constructor.
--
xOptimize :: HasCallStack => NormRewrite
xOptimize (TransformContext is0 _) e@(Case subj ty alts) = do
  runXOpt <- Lens.view aggressiveXOpt

  if runXOpt then do
    -- Split the alternatives on whether their right-hand side is an error
    -- primitive: undefined alternatives first, defined alternatives second.
    defPart <- List.partitionM (isPrimError . snd) alts

    case defPart of
      -- Nothing is undefined, so there is nothing to optimize.
      ([], _)    -> return e
      -- Everything is undefined: the whole case collapses to 'errorX'.
      (_, [])    ->
        changed
          (Prim
            (PrimInfo
              (Text.showt 'errorX) ty WorkConstant SingleResult NoUnfolding))
      -- Exactly one defined alternative replaces the case expression.
      (_, [alt]) -> xOptimizeSingle is0 subj alt
      -- Several defined alternatives: keep a case with only defined results.
      (_, defs)  -> xOptimizeMany is0 subj ty defs
  else
    return e

xOptimize _ e = return e
{-# SCC xOptimize #-}

-- Return an expression equivalent to the alternative given. When only one
-- alternative is defined the result of this function is used to replace the
-- case expression.
--
-- For a data constructor pattern the subject is bound to a fresh "subj"
-- variable, and every variable bound by the pattern is rebound to a field
-- selector on that variable (see mkFieldSelector), so the alternative's
-- right-hand side can be used outside of the original case expression. For
-- any other pattern the right-hand side is returned as is.
--
-- Both equations report the result through 'changed'.
--
xOptimizeSingle :: InScopeSet -> Term -> Alt -> NormalizeSession Term
xOptimizeSingle is subj (DataPat dc tvs vars, expr) = do
  tcm    <- Lens.view tcCache
  subjId <- mkInternalVar is "subj" (inferCoreTypeOf tcm subj)

  -- One selector per pattern variable, numbered by its position in the
  -- constructor's field list.
  let fieldTys = fmap coreTypeOf vars
  lets <- Monad.zipWithM (mkFieldSelector is subjId dc tvs fieldTys) vars [0..]

  changed (Letrec ((subjId, subj) : lets) expr)

xOptimizeSingle _ _ (_, expr) = changed expr

-- Given a list of alternatives which are defined, create a new case
-- expression which only ever returns a defined value.
--
-- If one of the alternatives already is a default alternative, the case
-- expression is rebuilt from the given alternatives unchanged. Otherwise the
-- first alternative is turned into a standalone expression with
-- xOptimizeSingle and appended after the remaining alternatives as the new
-- default.
--
-- Calling this function with an empty list of alternatives is a bug and
-- raises an error.
--
xOptimizeMany
  :: HasCallStack
  => InScopeSet
  -> Term
  -> Type
  -> [Alt]
  -> NormalizeSession Term
xOptimizeMany is subj ty defs@(d:ds)
  | isAnyDefault defs = changed (Case subj ty defs)
  | otherwise = do
      newAlt <- xOptimizeSingle is subj d
      changed (Case subj ty $ ds <> [(DefaultPat, newAlt)])
 where
  -- Whether any alternative matches on the default pattern.
  isAnyDefault :: [Alt] -> Bool
  isAnyDefault = any ((== DefaultPat) . fst)

xOptimizeMany _ _ _ [] =
  error $ $(curLoc) ++ "Report as bug: xOptimizeMany error: No defined alternatives"

-- Create a let binding which binds the given identifier to a case expression
-- that scrutinises the subject variable with a single alternative for the
-- given data constructor, and returns the field at the given index.
--
-- A fresh "field" variable is created for every entry in the list of field
-- types; together they form the constructor pattern of the alternative.
--
-- The index must be a valid position in the list of field types, otherwise
-- evaluating the resulting term fails in (!!).
--
mkFieldSelector
  :: MonadUnique m
  => InScopeSet
  -- ^ in-scope set handed to 'mkInternalVar' for the fresh field variables
  -> Id
  -- ^ subject id
  -> DataCon
  -- ^ data constructor matched by the selector
  -> [TyVar]
  -- ^ type variables bound by the constructor pattern
  -> [Type]
  -- ^ concrete types of fields
  -> Id
  -- ^ binder of the resulting let binding
  -> Int
  -- ^ zero-based position of the field to select
  -> m LetBinding
mkFieldSelector is0 subj dc tvs fieldTys nm index = do
  fields <- mapM (mkInternalVar is0 "field") fieldTys
  let alt = (DataPat dc tvs fields, Var $ fields !! index)
  -- The case expression has the type of the selected field.
  return (nm, Case (Var subj) (fieldTys !! index) [alt])

-- Check whether a term is really a black box primitive representing an error.
-- Such values are undefined and are removed in X Optimization.
--
-- The term is stripped of its arguments first. If its head is a primitive,
-- that primitive is looked up by name in the primitive map of the rewrite
-- environment, and the term counts as an error when the primitive found is a
-- black box whose template consists of exactly one 'Err' element. Terms whose
-- head is not a primitive, and primitives which cannot be found or extracted,
-- are not errors.
--
isPrimError :: Term -> NormalizeSession Bool
isPrimError (collectArgs -> (Prim pInfo, _)) = do
  prim <- Lens.view (primitives . Lens.at (primName pInfo))

  return (maybe False isErr (prim >>= extractPrim))
 where
  isErr BlackBox{template=(BBTemplate [Err _])} = True
  isErr _ = False

isPrimError _ = return False