{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Clash.Primitives.Util
( generatePrimMap
, hashCompiledPrimMap
, constantArgs
, decodeOrErr
, getFunctionPlurality
) where
import Control.DeepSeq (force)
import Control.Monad (join)
import Data.Aeson.Extra (decodeOrErr)
import qualified Data.ByteString.Lazy as LZ
import qualified Data.HashMap.Lazy as HashMap
import qualified Data.HashMap.Strict as HashMapStrict
import qualified Data.Set as Set
import Data.Hashable (hash)
import Data.List (find, foldl', isSuffixOf, sort)
import Data.Maybe (fromMaybe)
import qualified Data.Text as TS
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy.IO as T
import GHC.Stack (HasCallStack)
import qualified System.Directory as Directory
import qualified System.FilePath as FilePath
import System.IO.Error (tryIOError)
import Clash.Annotations.Primitive
  ( PrimitiveGuard(HasBlackBox, WarnNonSynthesizable, WarnAlways, DontTranslate)
  , extractPrim)
import Clash.Core.Term (Term)
import Clash.Core.Type (Type)
import Clash.Primitives.Types
  ( Primitive(BlackBox), CompiledPrimitive, ResolvedPrimitive, ResolvedPrimMap
  , includes, template, TemplateSource(TFile, TInline), Primitive(..)
  , UnresolvedPrimitive, CompiledPrimMap, GuardedResolvedPrimitive)
import Clash.Netlist.Types (BlackBox(..), NetlistMonad)
import Clash.Netlist.Util (preserveState)
import Clash.Netlist.BlackBox.Util
  (walkElement)
import Clash.Netlist.BlackBox.Types
  (Element(Const, Lit), BlackBoxMeta(..))
-- | Hash a compiled primitive so that changes to its definition can be
-- detected by comparing hashes.
--
-- * 'Primitive': hashes its name and 'primSort'.
-- * 'BlackBoxHaskell': reuses the precomputed hash stored as the first
--   component of 'function'.
-- * 'BlackBox': hashes every part that is rendered into the output. That
--   covers the main template, the include templates, the optional
--   result-name and result-init templates, and the function plurality table.
--   Template-function blackboxes ('BBFunction') contribute their name and
--   stored hash rather than the function itself.
hashCompiledPrimitive :: CompiledPrimitive -> Int
hashCompiledPrimitive (Primitive {name, primSort}) = hash (name, primSort)
hashCompiledPrimitive (BlackBoxHaskell {function}) = fst function
hashCompiledPrimitive (BlackBox { name, kind, outputReg, libraries, imports
                                , includes, template, resultName, resultInit
                                , functionPlurality }) =
  -- hashable only provides tuple instances up to 7 components, so the fields
  -- are grouped into nested tuples.
  hash ( (name, kind, outputReg, libraries, imports)
       , (includes', hashBlackbox template)
       , ( fmap hashBlackbox resultName
         , fmap hashBlackbox resultInit
         , functionPlurality ) )
 where
  includes' = map (\(nms, bb) -> (nms, hashBlackbox bb)) includes
  hashBlackbox (BBTemplate bbTemplate) = hash bbTemplate
  hashBlackbox (BBFunction bbName bbHash _bbFunc) = hash (bbName, bbHash)
-- | Hash a whole compiled primitive map.
--
-- Each guarded primitive is hashed with 'hashCompiledPrimitive' (inside its
-- guard). The entries are visited in ascending key order so the result does
-- not depend on the internal layout of the 'HashMap'.
hashCompiledPrimMap :: CompiledPrimMap -> Int
hashCompiledPrimMap primMap = hash hashedEntries
 where
  hashedEntries =
    [ hashCompiledPrimitive <$> primMap HashMapStrict.! key
    | key <- sort (HashMap.keys primMap) ]
-- | Obtain the text of a template.
--
-- Inline templates ('TInline') are returned unchanged. File templates
-- ('TFile') are read from the path obtained by substituting the template
-- path for the file-name component of @metaPath@
-- ('FilePath.replaceFileName'). Any 'IOError' raised while opening or reading
-- the file is rethrown immediately as an 'error' carrying a call stack.
resolveTemplateSource
  :: HasCallStack
  => FilePath
  -- ^ Base path against which relative template file paths are resolved
  -> TemplateSource
  -> IO Text
resolveTemplateSource _metaPath (TInline text) =
  return text
resolveTemplateSource metaPath (TFile path) = do
  let path' = FilePath.replaceFileName metaPath path
  -- 'T.readFile' is lazy I/O. Without forcing, the handle stays open until
  -- the text is consumed, possibly much later and with many files open at
  -- once, and read errors would surface outside of 'tryIOError'. Forcing the
  -- contents here reads the whole file and closes the handle.
  result <- tryIOError (T.readFile path' >>= \contents -> return $! force contents)
  either (error . show) return result
-- | Replace every 'TemplateSource' inside an unresolved primitive by the
-- template text, and pair the result with the primitive's name.
--
-- A 'BlackBox' with a @warning@ becomes 'WarnNonSynthesizable'. Every other
-- primitive is wrapped in 'HasBlackBox'. Template files are resolved
-- relative to @metaPath@ (see 'resolveTemplateSource').
resolvePrimitive'
  :: HasCallStack
  => FilePath
  -> UnresolvedPrimitive
  -> IO (TS.Text, GuardedResolvedPrimitive)
resolvePrimitive' _metaPath (Primitive primName wf primType) =
  pure (primName, HasBlackBox (Primitive primName wf primType))
resolvePrimitive' metaPath BlackBox{template=tmpl, includes=incs, resultName=resNm, resultInit=resInit, ..} = do
  -- Effects run in the same order as the fields appear in the constructor:
  -- includes, result name, result init, then the main template.
  incs' <- traverse (traverse resolveNested) incs
  resNm' <- traverse resolveNested resNm
  resInit' <- traverse resolveNested resInit
  tmpl' <- resolveNested tmpl
  let bb = BlackBox name workInfo renderVoid kind () outputReg libraries
             imports functionPlurality incs' resNm' resInit' tmpl'
      guarded = maybe (HasBlackBox bb)
                      (\w -> WarnNonSynthesizable (TS.unpack w) bb)
                      warning
  pure (name, guarded)
 where
  resolveNested = traverse (traverse (resolveTemplateSource metaPath))
resolvePrimitive' metaPath (BlackBoxHaskell bbName wf usedArgs funcName t) = do
  resolved <- traverse (resolveTemplateSource metaPath) t
  pure (bbName, HasBlackBox (BlackBoxHaskell bbName wf usedArgs funcName resolved))
-- | Read a JSON primitive definition file, decode it with 'decodeOrErr', and
-- resolve every primitive in it. Template files referenced by those
-- primitives are resolved relative to @fileName@.
resolvePrimitive
  :: HasCallStack
  => FilePath
  -> IO [(TS.Text, GuardedResolvedPrimitive)]
resolvePrimitive fileName =
  LZ.readFile fileName >>= traverse (resolvePrimitive' fileName) . decodeOrErr fileName
-- | Apply the guards collected from primitive annotations to a primitive map.
--
-- Each @(name, guard)@ pair is handled in list order and overwrites the entry
-- for @name@:
--
-- * An existing primitive is re-wrapped in the given guard.
-- * 'DontTranslate' without an existing primitive stores 'DontTranslate'.
-- * Any other mismatch raises 'error'. That is either a guard implying a
--   BlackBox for which no definition exists, or 'DontTranslate' for a name
--   that does have one.
addGuards
  :: ResolvedPrimMap
  -> [(TS.Text, PrimitiveGuard ())]
  -> ResolvedPrimMap
-- Strict left fold: the lazy 'foldl' would build a chain of thunks, one per
-- guard, before any insertion happened.
addGuards = foldl' go
 where
  lookupPrim :: TS.Text -> ResolvedPrimMap -> Maybe ResolvedPrimitive
  lookupPrim nm primMap = join (extractPrim <$> HashMapStrict.lookup nm primMap)

  go primMap (nm, guard) =
    HashMapStrict.insert
      nm
      (case (lookupPrim nm primMap, guard) of
        (Nothing, HasBlackBox _) ->
          error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
               ++ " though this value was annotated with 'HasBlackBox'."
        (Nothing, WarnNonSynthesizable _ _) ->
          error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
               ++ " though this value was annotated with 'WarnNonSynthesizable'"
               ++ ", implying it has a BlackBox."
        (Nothing, WarnAlways _ _) ->
          error $ "No BlackBox definition for '" ++ TS.unpack nm ++ "' even"
               ++ " though this value was annotated with 'WarnAlways'"
               ++ ", implying it has a BlackBox."
        (Just _, DontTranslate) ->
          error (TS.unpack nm ++ " was annotated with DontTranslate, but a "
                              ++ "BlackBox definition was found anyway.")
        (Nothing, DontTranslate) -> DontTranslate
        (Just p, g) -> fmap (const p) g)
      primMap
-- | Build the map of all known primitives.
--
-- Primitives come from two sources:
--
-- * every @.json@ file directly inside one of @filePaths@ (entries of
--   @filePaths@ that are not existing directories are skipped), and
-- * @unresolvedPrims@. For these, the primitive's own name is passed to
--   'resolvePrimitive'' as the base path for resolving template files.
--
-- When several definitions share a name, the last one wins ('HashMap.fromList'
-- semantics). Entries from @unresolvedPrims@ therefore override file-based
-- ones, and files are ordered by @filePaths@ and then by sorted file name
-- within each directory. Finally the guards in @primGuards@ are applied (see
-- 'addGuards') and the resulting map is fully evaluated before it is returned.
generatePrimMap
  :: HasCallStack
  => [UnresolvedPrimitive]
  -> [(TS.Text, PrimitiveGuard ())]
  -> [FilePath]
  -> IO ResolvedPrimMap
generatePrimMap unresolvedPrims primGuards filePaths = do
  primitiveFiles <- concat <$> mapM jsonFilesIn filePaths
  primitives0 <- concat <$> mapM resolvePrimitive primitiveFiles
  let metapaths = map (TS.unpack . name) unresolvedPrims
  primitives1 <- sequence (zipWith resolvePrimitive' metapaths unresolvedPrims)
  let primMap = HashMap.fromList (primitives0 ++ primitives1)
  -- 'return (force m)' would only wrap an unevaluated thunk. '$!' makes the
  -- deep evaluation (and any 'error' from 'addGuards') happen right here.
  return $! force (addGuards primMap primGuards)
 where
  -- List the @.json@ files directly inside a directory, or nothing if the
  -- path is not an existing directory. 'Directory.getDirectoryContents'
  -- returns entries in an unspecified, filesystem-dependent order. Sorting
  -- makes the choice between duplicate definitions reproducible.
  jsonFilesIn :: FilePath -> IO [FilePath]
  jsonFilesIn filePath = do
    isDir <- Directory.doesDirectoryExist filePath
    if isDir
      then map (FilePath.combine filePath) . sort . filter (isSuffixOf ".json")
             <$> Directory.getDirectoryContents filePath
      else return []
{-# SCC generatePrimMap #-}
-- | Argument positions of primitive @nm@ that must be constants.
--
-- Only a 'BlackBox' whose main template is a 'BBTemplate' is inspected. It
-- contributes:
--
-- * the positions referenced by 'Lit' or 'Const' elements in the main
--   template and in the optional result-init template (when that is also a
--   'BBTemplate'), and
-- * a fixed list of positions for a handful of primitives known by name.
--
-- Every other primitive yields the empty set.
constantArgs :: TS.Text -> CompiledPrimitive -> Set.Set Int
constantArgs nm BlackBox {template = BBTemplate mainElems, resultInit = initBB} =
  Set.fromList (forcedByName ++ initArgs ++ mainArgs)
 where
  mainArgs = constantsIn mainElems
  initArgs = case initBB of
    Just (BBTemplate initElems) -> constantsIn initElems
    _ -> []

  constantsIn elems = concatMap (walkElement constantIndex) elems

  constantIndex el = case el of
    Lit i -> Just i
    Const i -> Just i
    _ -> Nothing

  forcedByName = fromMaybe [] (lookup nm knownConstantArgs)

  knownConstantArgs :: [(TS.Text, [Int])]
  knownConstantArgs =
    [ ("Clash.Sized.Internal.BitVector.fromInteger#", [2])
    , ("Clash.Sized.Internal.BitVector.fromInteger##", [0,1])
    , ("Clash.Sized.Internal.Index.fromInteger#", [1])
    , ("Clash.Sized.Internal.Signed.fromInteger#", [1])
    , ("Clash.Sized.Internal.Unsigned.fromInteger#", [1])
    , ("Clash.Sized.Vector.index_int", [1,2])
    , ("Clash.Sized.Vector.replace_int", [1,2])
    ]
constantArgs _ _ = Set.empty
-- | Look up the plurality recorded for argument position @n@ in an
-- association list of @(position, plurality)@ pairs. The first matching
-- entry wins. Positions that are not listed default to 1.
getFunctionPlurality' :: [(Int, Int)] -> Int -> Int
getFunctionPlurality' functionPlurality n = fromMaybe 1 plurality
 where
  plurality = snd <$> find isArgN functionPlurality
  isArgN (argIx, _) = argIx == n
-- | Determine the function plurality of argument @n@ of a primitive.
--
-- * 'Primitive': always 1.
-- * 'BlackBoxHaskell': runs the primitive's blackbox function inside
--   'preserveState' and reads 'bbFunctionPlurality' from the 'BlackBoxMeta'
--   it returns. If the function returns an error message instead, this
--   calls 'error' with that message.
-- * 'BlackBox': uses the primitive's own 'functionPlurality' table.
--
-- In the last two cases, positions missing from the table default to 1.
getFunctionPlurality
  :: HasCallStack
  => CompiledPrimitive
  -> [Either Term Type]
  -> Type
  -> Int
  -> NetlistMonad Int
getFunctionPlurality (Primitive {}) _args _resTy _n = pure 1
getFunctionPlurality (BlackBoxHaskell {name, function, functionName}) args resTy n = do
  -- NOTE(review): 'preserveState' presumably discards any netlist state
  -- changes made while querying the function; confirm in Clash.Netlist.Util.
  -- The meaning of the 'False' flag is defined by the blackbox function's
  -- type, which is not visible here.
  errOrMeta <- preserveState ((snd function) False name args resTy)
  case errOrMeta of
    Left err ->
      error $ concat [ "Tried to determine function plurality for "
                     , TS.unpack name, " by querying ", show functionName
                     , ". Function returned an error message instead:\n\n"
                     , err ]
    Right (BlackBoxMeta {bbFunctionPlurality}, _bb) ->
      pure (getFunctionPlurality' bbFunctionPlurality n)
getFunctionPlurality (BlackBox {functionPlurality}) _args _resTy n =
  pure (getFunctionPlurality' functionPlurality n)