{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.
                    2022     , QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Transformations on primitives with multiple results.
-}

{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.Normalize.Transformations.MultiPrim
  ( setupMultiResultPrim
  ) where

import qualified Control.Lens as Lens
import qualified Data.Either as Either
import Data.Text.Extra (showt)
import GHC.Stack (HasCallStack)

import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.Name (mkUnsafeInternalName)
import Clash.Core.Term
  ( IsMultiPrim(..), MultiPrimInfo(..), PrimInfo(..), Term(..), WorkInfo(..)
  , mkAbstraction, mkApps, mkTmApps, mkTyApps, PrimUnfolding(..))
import Clash.Core.TermInfo (multiPrimInfo')
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type(..), mkPolyFunTy, splitFunForallTy)
import Clash.Core.Util (listToLets)
import Clash.Core.Var (mkLocalId)
import Clash.Normalize.Types (NormRewrite)
import Clash.Primitives.Types (Primitive(..))
import Clash.Rewrite.Types (tcCache, primitives)
import Clash.Rewrite.Util (changed)

-- Note [MultiResult type]
--
-- A multi result primitive assigns its results to multiple result variables
-- instead of one. Besides producing nicer HDL it works around issues with
-- synthesis tooling described in:
--
--   https://github.com/clash-lang/clash-compiler/issues/1555
--
-- This transformation rewrites primitives indicating they can assign their
-- results to multiple signals, such that netlist can easily render it. This
-- involves inserting additional arguments for each of the result values, and
-- then using the c$multiPrimSelect primitive to select individual results.
--
-- Example:
--
-- @
-- prim :: forall a b c. a -> (b, c)
-- @
--
-- will be rewritten to:
--
-- @
--   \(x :: a) ->
--         let
--            r  = prim @a @b @c x r0 r1 -- With 'Clash.Core.Term.MultiPrim'
--            r0 = c$multiPrimSelect r0 r
--            r1 = c$multiPrimSelect r1 r
--          in
--            (r0, r1)
-- @
--
-- Netlist will not render any @multiPrimSelect@ primitives. Similar to
-- primitives having a /void/ return type, /r/ is not rendered either.
--
-- This transformation is currently hardcoded to recognize tuples as return
-- types, not any product type. It will error if it sees a multi result primitive
-- with a non-tuple return type.
--
-- | Rewrite a primitive whose black box declares multiple results into the
-- form described in Note [MultiResult type].
--
-- Only a bare 'Prim' that is still marked 'SingleResult' is considered. Its
-- name is looked up in the primitive map of the rewrite environment; when the
-- entry found there (after 'extractPrim') is a 'BlackBox' or
-- 'BlackBoxHaskell' with @multiResult=True@, the term is replaced by the
-- result of 'setupMultiResultPrim'' and reported through 'changed'. Every
-- other term is returned untouched.
setupMultiResultPrim :: HasCallStack => NormRewrite
setupMultiResultPrim _ctx e@(Prim pInfo@PrimInfo{primMultiResult=SingleResult}) = do
  -- Matching on 'SingleResult' keeps this equation from firing again on the
  -- primitive it produces, which is tagged 'MultiResult'.
  tcm <- Lens.view tcCache
  prim <- Lens.view (primitives . Lens.at (primName pInfo))

  case prim >>= extractPrim of
    Just (BlackBoxHaskell{multiResult=True}) ->
      changed (setupMultiResultPrim' tcm pInfo)
    Just (BlackBox{multiResult=True}) ->
      changed (setupMultiResultPrim' tcm pInfo)
    -- Unknown primitive, guarded-away primitive, or one without multiple
    -- results: nothing to do.
    _ ->
      return e

-- Anything that is not a single-result 'Prim' is left as is.
setupMultiResultPrim _ e = return e

-- | Build the replacement term for a multi result primitive.
--
-- The result abstracts over the type variables and term arguments found in
-- the primitive's type (type variables first, then term arguments) and has a
-- body of the shape:
--
-- @
-- let r  = prim \@tyVars.. args.. r0 .. rN   -- tagged 'MultiResult'
--     r0 = c$multiPrimSelect r0 r
--     ..
--     rN = c$multiPrimSelect rN r
--  in resultDc \@resTypes.. r0 .. rN
-- @
--
-- The result data constructor and the individual result types are obtained
-- from 'multiPrimInfo''.
setupMultiResultPrim' :: HasCallStack => TyConMap -> PrimInfo -> Term
setupMultiResultPrim' tcm primInfo@PrimInfo{primType} =
  mkAbstraction letTerm (map Right typeVars <> map Left argIds)
 where
  -- Argument list (type variables and term argument types) and result type
  -- of the primitive.
  (pArgs, pResTy) = splitFunForallTy primType
  typeVars = Either.lefts pArgs
  argTys = Either.rights pArgs

  MultiPrimInfo{mpi_resultDc=tupTc, mpi_resultTypes=resTypes} =
    multiPrimInfo' tcm primInfo

  nTermArgs = length argTys
  nResults = length resTypes

  -- Internal names are the prefix followed by the number; the same number is
  -- passed as the second argument of 'mkUnsafeInternalName'.
  internalNm prefix n = mkUnsafeInternalName (prefix <> showt n) n
  internalId prefix typ n = mkLocalId typ (internalNm prefix n)

  -- Numbering: arguments take 1..nTermArgs, the individual results take the
  -- next nResults numbers, and the binder for the primitive application
  -- itself takes the number after that.
  argIds = zipWith (internalId "a") argTys [1..nTermArgs]
  resIds = zipWith (internalId "r") resTypes [nTermArgs+1..nTermArgs+nResults]
  -- Unlike 'resIds', the name here is the bare text "r": the number is not
  -- appended to it.
  resId = mkLocalId pResTy (mkUnsafeInternalName "r" (nTermArgs+nResults+1))

  -- One @c$multiPrimSelect@ binding per result: each result variable is
  -- bound to a select applied to itself and the primitive's binder.
  multiPrimSelect r t =
    (r, mkTmApps (Prim (multiPrimSelectInfo t)) [Var r, Var resId])
  multiPrimSelectBinds = zipWith multiPrimSelect resIds resTypes

  -- The primitive, now tagged 'MultiResult', applied to its type variables,
  -- its original arguments and the extra result variables.
  multiPrimTermArgs = map (Left . Var) (argIds <> resIds)
  multiPrimTypeArgs = map (Right . VarTy) typeVars
  multiPrimBind =
    mkApps
      (Prim primInfo{primMultiResult=MultiResult})
      (multiPrimTypeArgs <> multiPrimTermArgs)

  multiPrimSelectInfo t = PrimInfo
    { primName = "c$multiPrimSelect"
    , primType = mkPolyFunTy pResTy [Right pResTy, Right t]
    , primWorkInfo = WorkAlways
    , primMultiResult = SingleResult
    , primUnfolding = NoUnfolding
    }

  -- Bind the primitive and the selects, then rebuild the result value from
  -- the individual result variables.
  letTerm =
    listToLets
      ((resId, multiPrimBind) : multiPrimSelectBinds)
      (mkTmApps (mkTyApps (Data tupTc) resTypes) (map Var resIds))