{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module Clash.Annotations.TH
(
makeTopEntity
, makeTopEntityWithName
, makeTopEntityWithName'
, buildTopEntity
, maybeBuildTopEntity
, getNameBinding
)
where
import Data.Foldable ( fold)
import qualified Data.Set as Set
import qualified Data.Map as Map
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup as Semigroup
#endif
import Language.Haskell.TH
import Data.Functor.Foldable ( para )
import Data.Functor.Foldable.TH
import Control.Lens ( (%~), (&), (.~)
, _1, _2, _3, view
)
import Control.Monad (mfilter, liftM2)
import Control.Monad.Trans.Reader (ReaderT(..), asks, local)
import Control.Monad.Trans.Class (lift)
import Language.Haskell.TH.Instances ( )
import Language.Haskell.TH.Datatype
import Clash.Annotations.TopEntity ( PortName(..)
, TopEntity(..)
)
import Clash.NamedTypes ((:::))
import Clash.Signal ( HiddenClockResetEnable
, HiddenClock, HiddenReset, HiddenEnable
, Signal)
import Clash.Signal.Delayed (DSignal)
-- Template Haskell splice deriving the base functor for 'Type'. The 'TypeF'
-- type and its constructors (e.g. 'AppTF', 'ConTF') used by 'typeTreeToPorts'
-- below, and the 'para' traversal in 'gatherNames', rely on it.
$(makeBaseFunctor ''Type)
-- | Outcome of gathering port names for (part of) a type.
--
-- * 'Complete' carries the gathered result.
-- * 'HasFail' carries an error message; 'buildPorts' turns it into a 'fail'.
-- * 'BackTrack' carries the set of type names that were encountered again
--   while they were already being visited (see 'typeTreeToPorts').
data Naming a = Complete a | HasFail String | BackTrack (Set.Set Name)
  deriving Functor
-- | Combine two naming results.
--
-- Two 'Complete' values combine their payloads. A 'BackTrack' on either side
-- wins over everything else (two 'BackTrack's union their name sets, and the
-- left one is kept when only one side backtracks). Otherwise a 'HasFail' wins
-- over 'Complete', and two 'HasFail's join their messages with a newline.
instance Semigroup a => Semigroup (Naming a) where
  lhs <> rhs = case (lhs, rhs) of
    (Complete a, Complete b)     -> Complete (a <> b)
    (BackTrack n1, BackTrack n2) -> BackTrack (n1 <> n2)
    (BackTrack n, _)             -> BackTrack n
    (_, BackTrack n)             -> BackTrack n
    (HasFail e1, HasFail e2)     -> HasFail (e1 ++ "\n" ++ e2)
    (_, HasFail e)               -> HasFail e
    (HasFail e, _)               -> HasFail e
-- | The neutral element is a 'Complete' result holding the payload's own
-- 'mempty'.
instance (Semigroup a, Monoid a) => Monoid (Naming a) where
  mempty = Complete mempty
#if !(MIN_VERSION_base(4,11,0))
  -- Only compiled for base older than 4.11: define 'mappend' explicitly in
  -- terms of the Semigroup operator.
  mappend = (Semigroup.<>)
#endif
-- | Text describing what is currently being processed; it is appended to
-- error messages by 'failMsgWithContext'.
type ErrorContext = String
-- | Reader environment: the set of names currently being visited (extended by
-- 'visit') paired with the current error context.
type TrackData = (Set.Set Name, ErrorContext)
-- | A computation in @m@ that can read the 'TrackData' environment.
type Tracked m a = ReaderT TrackData m a
-- | Bidirectional pattern for a function type: @ArrowTy a b@ stands for the
-- Template Haskell encoding of @a -> b@, i.e. 'ArrowT' applied to @a@ and
-- then to @b@.
pattern ArrowTy :: Type -> Type -> Type
pattern ArrowTy a b = (ArrowT `AppT` a) `AppT` b
-- | Flatten a left-nested chain of type applications into its head followed
-- by the arguments in application order. A type that is not an application
-- yields a singleton list.
unapp :: Type -> [Type]
unapp = go []
  where
    -- Walk down the left spine, pushing each argument in front of the ones
    -- already collected so the final order is head, arg1, arg2, ...
    go args (AppT fun arg) = go (arg : args) fun
    go args hd             = hd : args
-- | List the argument types of a (curried) function type. The final result
-- type is not included; a non-function type yields the empty list.
unarrow :: Type -> [Type]
unarrow = \case
  ArrowTy dom cod -> dom : unarrow cod
  _               -> []
-- | Reduce a list of port names to at most one entry: empty and singleton
-- lists are returned unchanged, longer lists are wrapped in a single
-- 'PortProduct' with an empty name.
collapseNames :: [PortName] -> [PortName]
collapseNames ports = case ports of
  []  -> []
  [_] -> ports
  _   -> [PortProduct "" ports]
-- | Prefix a message with the common TopEntity generation error header.
failMsg :: String -> String
failMsg = ("TopEntity generation error: " ++)
-- | Read the current error context (the second component of the tracking
-- environment).
errorContext :: Tracked Q String
errorContext = asks (\env -> snd env)
-- | Build an error message: the 'failMsg'-prefixed text followed by the
-- current error context.
failMsgWithContext :: String -> Tracked Q String
failMsgWithContext s = do
  ctx <- errorContext
  return (failMsg s ++ ctx)
-- | Run a tracked computation in an environment where @name@ has been added
-- to the set of visited names and the error context has been replaced by
-- @show a@.
visit :: (Show b) => Name -> b -> Tracked m a -> Tracked m a
visit name a = local enter
  where
    label = show a
    enter env = env & _1 %~ Set.insert name & _2 .~ label
-- | Names of the type variables of a reified datatype. The implementation
-- is selected by the installed th-abstraction version.
datatypeVars' :: DatatypeInfo -> [Name]
#if MIN_VERSION_th_abstraction(0,3,0)
datatypeVars' = map tvName . datatypeVars
#else
datatypeVars' = map varName . datatypeVars
  where
    -- Here the variables are 'Type's: unwrap kind signatures down to the
    -- variable itself.
    varName ty = case ty of
      VarT n       -> n
      SigT inner _ -> varName inner
      other        -> error $ "Unexpected datatype variable name of type " ++ show other
#endif
-- | Reify @name@ as a datatype and apply @f@ to the result; if that fails,
-- recover with the fallback value @a@.
tryReifyDatatype :: a -> (DatatypeInfo -> a) -> Name -> Tracked Q a
tryReifyDatatype a f name = lift (recover fallback attempt)
  where
    fallback = pure a
    attempt  = fmap f (reifyDatatype name)
-- | Gather port names for a list of types (e.g. constructor fields).
--
-- The names of each individual type are collapsed to at most one port via
-- 'collapseNames' and then combined. If the combined result is non-empty but
-- does not contain exactly one entry per input type, a 'HasFail' reporting
-- partially named constructor arguments is returned instead.
portsFromTypes
  :: [Type]
  -> Tracked Q (Naming [PortName])
portsFromTypes xs = do
  combined <- mconcat <$> mapM namesOf xs
  case combined of
    Complete names
      | not (null names) && length names /= length xs ->
          HasFail <$> failMsgWithContext "Partially named constructor arguments!\n"
    other -> return other
  where
    namesOf ty = fmap collapseNames <$> gatherNames ty
-- | Gather port names for a type with several constructors.
--
-- The fields of every constructor are processed with 'portsFromTypes' and the
-- results combined. Only a completely unnamed result (@Complete []@) is
-- accepted; anything else gets an "Annotated sum types not supported!"
-- failure appended to it.
handleNamesInSum
  :: [ConstructorInfo]
  -> Tracked Q (Naming [PortName])
handleNamesInSum xs = do
  perConstructor <- mapM (portsFromTypes . constructorFields) xs
  case fold perConstructor of
    Complete [] -> return $ Complete []
    x -> do
      msg <- failMsgWithContext "Annotated sum types not supported!\n"
      return (x `mappend` HasFail msg)
-- | Gather the port names for the fields of a single Template Haskell
-- constructor, after applying the given type-variable substitution to the
-- field types.
--
-- Every field of the constructor is considered. In particular both operands
-- of an infix constructor are included, so names on either side contribute
-- ports and the partial-naming check in 'portsFromTypes' sees the complete
-- field list.
constructorToPorts :: Con -> Map.Map Name Type -> Tracked Q (Naming [PortName])
constructorToPorts c m = portsFromTypes (applySubstitution m (ctys c))
  where
    -- Field types of a constructor, dropping strictness and field-name
    -- information.
    ctys (NormalC _ (fmap snd -> tys)) = tys
    ctys (RecC _ (fmap (view _3) -> tys)) = tys
    -- An infix constructor has two fields: its left and right operands.
    ctys (InfixC (snd -> lhs) _ (snd -> rhs)) = [lhs, rhs]
    ctys (ForallC _ _ c') = ctys c'
    ctys (GadtC _ (fmap snd -> tys) _) = tys
    ctys (RecGadtC _ (fmap (view _3) -> tys) _) = tys
-- | Gather port names for the datatype with the given name.
--
-- The datatype is reified (an unreifiable name is treated as having no
-- constructors). No constructors yields @Complete []@, a single constructor
-- is handled by 'portsFromTypes' on its fields, and several constructors by
-- 'handleNamesInSum'.
--
-- If the result is a 'BackTrack' that mentions this very name, a warning is
-- reported and the name is removed from the set; when no names remain the
-- result becomes @Complete []@.
datatypeNameToPorts
  :: Name
  -> Tracked Q (Naming [PortName])
datatypeNameToPorts name = do
  constructors <- tryReifyDatatype [] datatypeCons name
  names <- case constructors of
    []       -> return $ Complete []
    [single] -> portsFromTypes (constructorFields single)
    several  -> handleNamesInSum several
  case names of
    BackTrack pending
      | name `Set.member` pending -> do
          lift $ reportWarning $ "Make sure HDL port names are correct:\n"
                              ++ "Backtracked when constructing " ++ pprint name
                              ++ "\n(Type appears recursive)"
          let remaining = Set.delete name pending
          return $ if Set.null remaining
                     then Complete []
                     else BackTrack remaining
    _ -> return names
-- | One step of the port-name traversal over a type, used as the algebra of
-- the paramorphism in 'gatherNames'.
--
-- Every child position of the 'TypeF' layer holds a pair: the original child
-- 'Type' and the (not yet run) computation that gathers that child's names.
typeTreeToPorts
  :: TypeF (Type, Tracked Q (Naming [PortName]))
  -> Tracked Q (Naming [PortName])
-- A @"name" ::: ty@ annotation: run the computation for @ty@ and attach
-- @name@ to its result.
typeTreeToPorts (AppTF (AppT (ConT split) (LitT (StrTyLit name)), _) (_,c))
  | split == ''(:::)
  = c >>= \case
    -- No inner names: the annotation itself names the port.
    Complete [] -> return $ Complete [PortName name]
    -- A single inner port name: join outer and inner name with an underscore.
    Complete [PortName n2] -> return $ Complete [PortName (name ++ "_" ++ n2)]
    -- Anything else that completed becomes a product named after the annotation.
    Complete xs -> return $ Complete [PortProduct name xs]
    -- Failures and backtracks are passed through unchanged.
    x -> return x
-- A bare type constructor.
typeTreeToPorts (ConTF name) = do
  seen <- asks fst
  if Set.member name seen
    -- Already being visited: signal a backtrack for this name.
    then return $ BackTrack $ Set.singleton name
    else visit name name $ do
      info <- lift $ reify name
      case info of
        PrimTyConI _ _ _ -> return $ Complete []
        -- Type synonym: gather names from its right-hand side.
        TyConI (TySynD _ _ t) -> gatherNames t
        _ -> datatypeNameToPorts name
-- A type application; decisions are made on the flattened head and arguments.
typeTreeToPorts f@(AppTF (a,a') (b,b')) = do
  case unapp (AppT a b) of
    -- 'Signal' applied to two arguments / 'DSignal' applied to three: use the
    -- names gathered for the last argument.
    (ConT x : _ : _ : []) | x == ''Clash.Signal.Signal -> b'
    (ConT x : _ : _ : _ : []) | x == ''Clash.Signal.Delayed.DSignal -> b'
    (ConT x : xs) -> do
      info <- lift $ reify x
      case info of
        -- Applied type synonym: substitute the arguments for the synonym's
        -- variables and gather names from the resulting type.
        (TyConI (TySynD _ synvars def)) -> do
          gatherNames $ applyContext xs (tvName <$> synvars) def
        -- Closed type family applied to as many arguments as it has binders:
        -- keep the equations whose arguments, after substitution, equal the
        -- applied arguments, and follow the right-hand side when exactly one
        -- remains.
        FamilyI (ClosedTypeFamilyD (TypeFamilyHead _ bds _ _) eqs) _
          | length bds == length xs ->
              case filter ((==) xs . applyFamilyBindings xs info . tySynArgs) eqs of
#if MIN_VERSION_template_haskell(2,15,0)
                [TySynEqn _ _ r] ->
#else
                [TySynEqn _ r] ->
#endif
                  gatherNames (applyFamilyBindings xs info r)
                _ -> return $ Complete []
        -- Any other family (see 'familyBindings') applied to as many
        -- arguments as it has binders: look up its instances for these
        -- arguments.
        _ | familyArity info == Just (length xs) -> do
          (lift $ reifyInstances x xs) >>= \case
#if MIN_VERSION_template_haskell(2,15,0)
            [TySynInstD (TySynEqn _ _ r)] ->
#else
            [TySynInstD _ (TySynEqn _ r)] ->
#endif
              gatherNames (applyFamilyBindings xs info r)
            [NewtypeInstD _ _ _ _ c _] -> constructorToPorts c (familyTyMap xs info)
            [DataInstD _ _ _ _ cs _] -> do
              case cs of
                [c] -> constructorToPorts c (familyTyMap xs info)
                -- Zero or several constructors: no names.
                _ -> return $ Complete []
            y -> fail $ failMsg "Encountered unexpected type during family application!"
                      ++ pprint y
        -- Otherwise try to treat the head as an ordinary datatype.
        _ -> do
          dataTy <- tryReifyDatatype Nothing Just x
          let
            -- Only proceed when the number of applied arguments equals the
            -- number of datatype variables.
            hasAllArgs = \vs -> length xs == length (datatypeVars vs)
            constructors = applyDatatypeContext xs <$> mfilter hasAllArgs dataTy
            -- 'Just' the constructor when there is exactly one; the failed
            -- pattern match in 'Maybe' gives 'Nothing' otherwise.
            getSingleConstructor cs = do [c] <- cs; return c
            constructor = getSingleConstructor constructors
          -- Without a single constructor, fall back to the computation paired
          -- with the function part of the application.
          maybe a' (visit x (ppr x) . portsFromTypes . constructorFields) constructor
    -- List and tuple heads: combine the names gathered for all children.
    (ListT:_) -> fold <$> mapM snd f
    (TupleT _:_) -> fold <$> mapM snd f
    -- Any other head: warn, then combine the children's names.
    _ -> do
      lift $ reportWarning $ "Make sure HDL port names are correct:\n"
                          ++ "Type application with non ConT head:\n:("
                          ++ pprint (AppT a b)
      f' <- mapM snd f
      return $ fold f'
  where
    -- Substitution map pairing each hole (type variable name) with the
    -- argument in the same position.
    tyMap ctx holes = Map.fromList $ zip holes ctx
    familyTyMap ctx (familyBindings -> Just holes) = tyMap ctx (tvName <$> holes)
    familyTyMap _ _ = error "familyTyMap called with non family argument!"
    applyContext ctx holes = applySubstitution (tyMap ctx holes)
    applyDatatypeContext ctx d = applyContext ctx (datatypeVars' d) <$> datatypeCons d
    applyFamilyBindings ctx (familyBindings -> Just holes) t
      = applyContext ctx (tvName <$> holes) t
    -- NOTE(review): this message names familyTyMap although it is raised from
    -- applyFamilyBindings; looks like a copy-paste, confirm before changing.
    applyFamilyBindings _ _ _ = error "familyTyMap called with non family argument!"
#if MIN_VERSION_template_haskell(2,15,0)
    -- Here the equation's left-hand side is a single applied type, so the
    -- family head is dropped to obtain the arguments.
    tySynArgs (TySynEqn _ args _) = tail (unapp args)
#else
    tySynArgs (TySynEqn args _) = args
#endif
    -- Type variable binders of a closed/open type family or data family;
    -- 'Nothing' for anything else.
    familyBindings (FamilyI (ClosedTypeFamilyD (TypeFamilyHead _ xs _ _) _) _) = Just xs
    familyBindings (FamilyI (OpenTypeFamilyD (TypeFamilyHead _ xs _ _)) _) = Just xs
    familyBindings (FamilyI (DataFamilyD _ xs _) _) = Just xs
    familyBindings _ = Nothing
    familyArity = fmap length . familyBindings
-- Every other node: run the computations of all children and combine them.
typeTreeToPorts f = do
  f' <- mapM snd f
  return $ fold f'
-- | Gather the port names annotated inside a type by folding
-- 'typeTreeToPorts' over it as a paramorphism.
gatherNames
  :: Type
  -> Tracked Q (Naming [PortName])
gatherNames ty = para typeTreeToPorts ty
-- | Gather the port names of a type, starting with an empty set of visited
-- names and an empty error context.
--
-- A 'HasFail' result fails with its message; a 'BackTrack' that survives to
-- this level fails with a "recursive type at entry" message.
buildPorts
  :: Type
  -> Q [PortName]
buildPorts x = runReaderT (gatherNames x >>= finish) (Set.empty, "")
  where
    finish (Complete xs) = return xs
    finish (HasFail err) = fail err
    finish (BackTrack n) = fail $ failMsg "Encountered recursive type at entry! " ++ show n
-- | Port name for the result of a (possibly curried) function type.
--
-- Arrows are skipped until the final result type is reached. Its gathered
-- names must be non-empty; a single name is returned as is and several names
-- are wrapped in an unnamed 'PortProduct'.
toReturnName :: Type -> Q PortName
toReturnName (ArrowTy _ b) = toReturnName b
toReturnName b = do
  ports <- buildPorts b
  case ports of
    []  -> fail $ failMsg "No return name specified!"
    [x] -> return x
    xs  -> return $ PortProduct "" xs
-- | Port names for each argument of a function type, one entry per argument.
--
-- An argument without any names is an error; several names for one argument
-- are wrapped in an unnamed 'PortProduct'.
toArgNames :: Type -> Q [PortName]
toArgNames ty = traverse namedArg (unarrow ty)
  where
    namedArg argTy = do
      ports <- buildPorts argTy
      case ports of
        []  -> fail $ failMsg "Unnamed argument " ++ pprint argTy
        [a] -> return a
        xs  -> return $ PortProduct "" xs
-- | Classification of the hidden clock constraints found by
-- 'handleConstraints'.
--
-- * 'None': no hidden clock constraint was seen.
-- * 'SingleClockResetEnable': exactly one @HiddenClockResetEnable@ constraint.
-- * 'Other': any other combination; rejected by 'clockToPorts'.
data ClockType = None | SingleClockResetEnable | Other
  deriving Eq
-- | Strip the outer @forall@/context of a type and classify its hidden clock
-- constraints.
--
-- A @forall@ that binds type variables is rejected. Otherwise the context is
-- folded into the running 'ClockType' and the body is processed in turn. A
-- type without an outer 'ForallT' is returned together with the
-- classification accumulated so far.
handleConstraints :: Type -> ClockType -> Q (Type, ClockType)
handleConstraints ty clk = case ty of
  ForallT [] [] x      -> handleConstraints x clk
  ForallT xs@(_:_) _ _ -> fail $ failMsg "Free type variables!\n"
                              ++ pprint xs
  ForallT _ c x        -> handleConstraints x (foldl findHiddenClocks clk c)
  _                    -> return (ty, clk)
  where
    -- Update the classification for one constraint; constraints that are not
    -- one of the hidden clock classes leave it unchanged.
    findHiddenClocks acc (AppT (ConT cls) _)
      | cls == ''Clash.Signal.HiddenClockResetEnable
      = if acc == None then SingleClockResetEnable else Other
      | cls `elem` [ ''Clash.Signal.HiddenClock
                   , ''Clash.Signal.HiddenReset
                   , ''Clash.Signal.HiddenEnable ]
      = Other
    findHiddenClocks acc _ = acc
-- | Input ports implied by the hidden clock classification: none for 'None',
-- an unnamed product of @clk@, @rst@ and @en@ for 'SingleClockResetEnable',
-- and a failure for 'Other'.
clockToPorts :: ClockType -> Q [PortName]
clockToPorts clk = case clk of
  None -> return []
  SingleClockResetEnable ->
    return [PortProduct "" (map PortName ["clk", "rst", "en"])]
  Other ->
    fail $ failMsg "TH generation for multiple hidden clocks and"
        ++ " HiddenClock/HiddenReset/HiddenEnable currently unsupported!"
-- | Build a typed expression for a 'Synthesize' annotation from a binding's
-- name and type.
--
-- The inputs are the ports implied by hidden clock constraints followed by
-- one port per function argument; the output is the port of the result type.
-- The entity is named after the given override, or after the base name of the
-- binding when no override is supplied.
buildTopEntity :: Maybe String -> (Name, Type) -> TExpQ TopEntity
buildTopEntity topName (name, ty) = do
  (ty', clock) <- handleConstraints ty None
  ins <- liftM2 (<>) (clockToPorts clock) (toArgNames ty')
  out <- toReturnName ty'
  let outName = maybe (nameBase name) id topName
  [|| Synthesize
        { t_name   = outName
        , t_inputs = ins
        , t_output = out
        } ||]
-- | Like 'buildTopEntity' for the binding with the given name, but yields an
-- expression for 'Nothing' instead of failing when the annotation cannot be
-- built.
maybeBuildTopEntity :: Maybe String -> Name -> Q (TExp (Maybe TopEntity))
maybeBuildTopEntity topName name =
  recover [|| Nothing ||] [|| Just ($$entity) ||]
  where
    entity = getNameBinding name >>= buildTopEntity topName
-- | Reify a name and return its name and type. Fails when the name does not
-- reify to a value binding ('VarI').
getNameBinding :: Name -> Q (Name, Type)
getNameBinding n = do
  info <- reify n
  case info of
    VarI name ty _ -> return (name, ty)
    _ -> fail "getNameBinding: Invalid Name, must be a top-level binding!"
-- | Produce an @ANN@ pragma declaration attaching the generated 'TopEntity'
-- to the given binding, optionally overriding the entity name.
makeTopEntityWithName' :: Name -> Maybe String -> DecQ
makeTopEntityWithName' n topName = do
  binding@(name, _) <- getNameBinding n
  topEntity <- buildTopEntity topName binding
  return $ PragmaD (AnnP (valueAnnotation name) (unType topEntity))
-- | Generate the 'TopEntity' annotation for a binding, using the given string
-- as the entity name. The single declaration is returned as a list.
makeTopEntityWithName :: Name -> String -> DecsQ
makeTopEntityWithName nam top = do
  dec <- makeTopEntityWithName' nam (Just top)
  return [dec]
-- | Generate the 'TopEntity' annotation for a binding, naming the entity
-- after the binding itself. The single declaration is returned as a list.
makeTopEntity :: Name -> DecsQ
makeTopEntity nam = do
  dec <- makeTopEntityWithName' nam Nothing
  return [dec]