{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Core.Type
( Type (..)
, TypeView (..)
, ConstTy (..)
, LitTy (..)
, Kind
, KindOrType
, KiName
, TyName
, TyVar
, tyView
, coreView
, coreView1
, typeKind
, mkTyConTy
, mkFunTy
, mkPolyFunTy
, mkTyConApp
, splitFunTy
, splitFunTys
, splitFunForallTy
, splitCoreFunForallTy
, splitTyConAppM
, isPolyFunTy
, isPolyFunCoreTy
, isPolyTy
, isTypeFamilyApplication
, isFunTy
, isClassTy
, applyFunTy
, findFunSubst
, reduceTypeFamily
, undefinedTy
, isIntegerTy
, normalizeType
, varAttrs
, typeAttrs
)
where
import Control.DeepSeq as DS
import Data.Binary (Binary)
import Data.Coerce (coerce)
import Data.Hashable (Hashable)
import Data.List (foldl')
import Data.List.Extra (splitAtList)
import Data.Maybe (isJust, mapMaybe)
import GHC.Base (isTrue#,(==#))
import GHC.Generics (Generic(..))
import GHC.Integer (smallInteger)
import GHC.Integer.Logarithms (integerLogBase#)
#if __GLASGOW_HASKELL__ >= 808
import PrelNames
(ordLTDataConKey, ordEQDataConKey, ordGTDataConKey)
#else
import Unique (Unique)
import PrelNames
(ltDataConKey, eqDataConKey, gtDataConKey)
#endif
import PrelNames
(integerTyConKey, typeNatAddTyFamNameKey, typeNatExpTyFamNameKey,
typeNatLeqTyFamNameKey, typeNatMulTyFamNameKey, typeNatSubTyFamNameKey,
typeNatCmpTyFamNameKey,
typeSymbolAppendFamNameKey, typeSymbolCmpTyFamNameKey)
import SrcLoc (wiredInSrcSpan)
import Unique (getKey)
import Clash.Core.DataCon
import Clash.Core.Name
import {-# SOURCE #-} Clash.Core.Subst
import Clash.Core.TyCon
import Clash.Core.TysPrim
import Clash.Core.Var
import Clash.Unique
import Clash.Util
-- Compatibility aliases for the promoted Ordering constructor keys.
-- From GHC 8.8 on, PrelNames exports ordLTDataConKey, ordEQDataConKey and
-- ordGTDataConKey; older compilers only export ltDataConKey, eqDataConKey
-- and gtDataConKey. Defining the new names here lets the rest of the module
-- use them unconditionally. This condition must be the exact complement of
-- the PrelNames import switch (which tests for 808 or later); otherwise a
-- GHC 8.7 development build would import the old names without getting
-- these aliases.
#if __GLASGOW_HASKELL__ < 808
ordLTDataConKey, ordEQDataConKey, ordGTDataConKey :: Unique.Unique
ordLTDataConKey = ltDataConKey
ordEQDataConKey = eqDataConKey
ordGTDataConKey = gtDataConKey
#endif
-- | The attributes attached to the type of an 'Id'.
--
-- Returns the attribute list of the 'Id''s type when that type is an
-- 'AnnType', and @[]@ for any other type. Passing a 'TyVar' is an error.
varAttrs :: Var a -> [Attr']
varAttrs v = case v of
  TyVar {} ->
    error $ $(curLoc) ++ "Unexpected argument: " ++ show v
  Id _ _ idTy _
    | AnnType attrs _ <- idTy -> attrs
    | otherwise -> []
-- | Types in the Clash core language.
data Type
  = VarTy !TyVar             -- ^ Type variable
  | ConstTy !ConstTy         -- ^ Type constant: a type constructor or the function arrow
  | ForAllTy !TyVar !Type    -- ^ Polymorphic type: @forall tv . ty@
  | AppTy !Type !Type        -- ^ Type application: function type, then argument type
  | LitTy !LitTy             -- ^ Type-level literal
  | AnnType [Attr'] !Type    -- ^ A type annotated with a list of attributes
  deriving (Show,Generic,NFData,Hashable,Binary)
-- | A view on a 'Type' that exposes function types and type-constructor
-- applications directly; produced by 'tyView'.
data TypeView
  = FunTy !Type !Type           -- ^ Function type: argument type, result type
  | TyConApp !TyConName [Type]  -- ^ Type constructor applied to its (possibly empty) argument list
  | OtherType !Type             -- ^ Any other type, carried unchanged
  deriving Show
-- | Type constants that can head a chain of 'AppTy' nodes.
data ConstTy
  = TyCon !TyConName  -- ^ A named type constructor
  | Arrow             -- ^ The function arrow
  deriving (Eq,Ord,Show,Generic,NFData,Hashable,Binary)
-- | Type-level literals.
data LitTy
  = NumTy !Integer  -- ^ Numeric literal; 'typeKind' gives it the natural-number kind
  | SymTy !String   -- ^ Symbol (string) literal
  deriving (Eq,Ord,Show,Generic,NFData,Hashable,Binary)
-- | Kinds share the representation of types.
type Kind = Type
-- | A kind or a type; used where either may appear, e.g. the argument
-- instantiating a kind in 'kindFunResult'.
type KindOrType = Type
-- | Name of a type.
type TyName = Name Type
-- | Name of a kind.
type KiName = Name Kind
-- | View a 'Type' as a function type, a type-constructor application, or
-- "something else".
--
-- * A bare @'ConstTy' ('TyCon' tc)@ becomes @'TyConApp' tc []@.
-- * A left-nested chain of 'AppTy' headed by a 'TyCon' becomes a
--   'TyConApp' with the arguments in application order. Chains of up to four
--   arguments are matched explicitly; longer ones are flattened by @go@.
-- * 'Arrow' applied to exactly two arguments, or to exactly four arguments,
--   becomes 'FunTy' of the last two. NOTE(review): the four-argument form
--   presumably carries two leading representation/kind arguments of the
--   arrow; confirm against the GHC-to-core translation.
-- * Everything else (including 'Arrow' with any other number of arguments)
--   is returned as 'OtherType' wrapping the original type.
tyView :: Type -> TypeView
tyView tOrig = case tOrig of
  ConstTy c -> case c of
    TyCon tc -> TyConApp tc []
    _ -> OtherType tOrig
  AppTy l0 res -> case l0 of
    ConstTy (TyCon tc) -> TyConApp tc [res]
    AppTy l1 arg -> case l1 of
      ConstTy Arrow -> FunTy arg res
      ConstTy (TyCon tc) -> TyConApp tc [arg,res]
      AppTy l2 resK -> case l2 of
        ConstTy (TyCon tc) -> TyConApp tc [resK,arg,res]
        AppTy l3 argK -> case l3 of
          ConstTy (TyCon tc) -> TyConApp tc [argK,resK,arg,res]
          -- Arrow with four arguments: only the last two are the
          -- argument and result types.
          ConstTy Arrow -> FunTy arg res
          -- More than four arguments: flatten the remaining spine.
          _ -> case go [argK,resK,arg,res] l3 of
            (ConstTy (TyCon tc),args)
              -> TyConApp tc args
            _ -> OtherType tOrig
        _ -> OtherType tOrig
      _ -> OtherType tOrig
    _ -> OtherType tOrig
  _ -> OtherType tOrig
 where
  -- Walk down the application spine, collecting arguments; returns the
  -- head of the spine together with all arguments in order.
  go args (AppTy ty1 ty2) = go (ty2:args) ty1
  go args t1 = (t1,args)
-- | Expand a type with 'coreView1' until no further expansion applies,
-- returning the fully expanded type (or the input if nothing expands).
coreView :: TyConMap -> Type -> Type
coreView tcm ty = maybe ty (coreView tcm) (coreView1 tcm ty)
-- | Perform a single expansion step on a type-constructor application.
--
-- * @BiSignalIn@ / @BiSignalOut@ with four arguments expand to their last
--   argument (the element type).
-- * @Signal@ with two arguments expands to its element type.
-- * A newtype expands to its instantiated representation
--   (see 'newTyConInstRhs').
-- * Any other type constructor is handed to 'reduceTypeFamily'.
--
-- Returns 'Nothing' when no step applies, including for every type that is
-- not a type-constructor application.
coreView1 :: TyConMap -> Type -> Maybe Type
coreView1 tcMap ty = case tyView ty of
  TyConApp tcNm args -> case (nameOcc tcNm, args) of
    ("Clash.Signal.BiSignal.BiSignalIn", [_,_,_,elTy]) -> Just elTy
    ("Clash.Signal.BiSignal.BiSignalOut", [_,_,_,elTy]) -> Just elTy
    ("Clash.Signal.Internal.Signal", [_,elTy]) -> Just elTy
    _ -> expandTyCon tcNm args
  _ -> Nothing
 where
  -- Newtypes unwrap to their representation; everything else gets one
  -- attempt at type-family reduction.
  expandTyCon tcNm args = case tcMap `lookupUniqMap'` tcNm of
    AlgTyCon {algTcRhs = NewTyCon _ nt} -> newTyConInstRhs nt args
    _ -> reduceTypeFamily tcMap ty
-- | Instantiate a newtype's representation type.
--
-- Given the newtype's type variables and right-hand side, plus the argument
-- types it is applied to, the first @length tvs@ arguments are substituted
-- for the variables. Any remaining arguments are applied to the result.
-- Returns 'Nothing' when there are fewer arguments than type variables.
newTyConInstRhs :: ([TyVar],Type) -> [Type] -> Maybe Type
newTyConInstRhs (tvs, rhs) tys
  | length tys < length tvs = Nothing
  | otherwise = Just (foldl' AppTy (substTyWith tvs instArgs rhs) extraArgs)
  where
    (instArgs, extraArgs) = splitAtList tvs tys
-- | Build the function type @argTy -> resTy@ by applying 'Arrow' to both.
mkFunTy :: Type -> Type -> Type
mkFunTy argTy resTy = AppTy (AppTy (ConstTy Arrow) argTy) resTy
-- | Apply a type constructor to a list of argument types.
--
-- Builds the left-nested 'AppTy' spine headed by @'ConstTy' ('TyCon' tc)@;
-- an empty argument list yields the bare constructor. Uses the strict
-- 'foldl'' (like the rest of this module) rather than lazy 'foldl'. Since
-- 'AppTy' has strict fields, forcing the result forces the whole spine
-- anyway, so this only avoids building an intermediate thunk chain.
mkTyConApp :: TyConName -> [Type] -> Type
mkTyConApp tc = foldl' AppTy (ConstTy $ TyCon tc)
-- | A type constructor used as a type on its own, with no arguments.
mkTyConTy :: TyConName -> Type
mkTyConTy = ConstTy . TyCon
-- | Split a type-constructor application into the constructor and its
-- arguments, or 'Nothing' if 'tyView' does not see a 'TyConApp'. No
-- expansion of newtypes or type families is performed.
splitTyConAppM :: Type
               -> Maybe (TyConName,[Type])
splitTyConAppM ty = case tyView ty of
  TyConApp tc args -> Just (tc, args)
  _ -> Nothing
-- | Whether a type is a bare type constructor whose entry in the
-- 'TyConMap' is a 'SuperKindTyCon'. The constructor is looked up with
-- 'lookupUniqMap'', so it must be present in the map.
isSuperKind :: TyConMap -> Type -> Bool
isSuperKind tcMap ty
  | ConstTy (TyCon tcNm) <- ty
  , SuperKindTyCon {} <- lookupUniqMap' tcMap tcNm
  = True
  | otherwise
  = False
-- | Compute the kind of a type.
--
-- * Variables have the kind stored in the variable.
-- * @forall@ types and annotated types have the kind of their body.
-- * Numeric and symbol literals have the type-level natural and symbol kinds.
-- * A function type has its result's kind when that kind is a super kind
--   (see 'isSuperKind'), and 'liftedTypeKind' otherwise.
-- * A type-constructor application instantiates the constructor's kind with
--   each argument in turn ('kindFunResult').
-- * Any other application instantiates the kind of its function part.
--
-- A bare 'Arrow' has no kind here and raises an error.
typeKind :: TyConMap -> Type -> Kind
typeKind m ty = case ty of
  VarTy tv -> varType tv
  ForAllTy _ body -> typeKind m body
  LitTy (NumTy _) -> typeNatKind
  LitTy (SymTy _) -> typeSymbolKind
  AnnType _ann inner -> typeKind m inner
  ConstTy ct -> case ct of
    TyCon tc -> tyConKind (m `lookupUniqMap'` tc)
    Arrow -> error $ $(curLoc) ++ "typeKind: naked ConstTy: " ++ show ct
  AppTy fun arg -> case tyView ty of
    FunTy _ res ->
      let resKind = typeKind m res
      in if isSuperKind m resKind then resKind else liftedTypeKind
    TyConApp tc args ->
      foldl' kindFunResult (tyConKind (m `lookupUniqMap'` tc)) args
    OtherType _ -> kindFunResult (typeKind m fun) arg
-- | The kind that remains after applying something of kind @k@ to @arg@.
--
-- For an arrow kind this is its result kind. For a @forall@ kind the bound
-- variable is instantiated with @arg@. Any other kind raises an error.
kindFunResult :: Kind -> KindOrType -> Kind
kindFunResult k arg
  | FunTy _ res <- tyView k = res
  | ForAllTy kv ki <- k = substTyWith [kv] [arg] ki
  | otherwise = error $ $(curLoc) ++ "kindFunResult: " ++ show (k,arg)
-- | Whether a type is a @forall@ type, possibly after any number of
-- function arrows (e.g. @a -> forall b . b@). No type expansion is done.
isPolyTy :: Type -> Bool
isPolyTy ty
  | ForAllTy {} <- ty = True
  | FunTy _ res <- tyView ty = isPolyTy res
  | otherwise = False
-- | Split a function type into argument and result type.
--
-- The type is first expanded with 'coreView1' as far as possible. Returns
-- 'Nothing' if the expanded type is not a function type.
splitFunTy :: TyConMap
           -> Type
           -> Maybe (Type, Type)
splitFunTy m ty = case coreView1 m ty of
  Just ty' -> splitFunTy m ty'
  Nothing
    | FunTy arg res <- tyView ty -> Just (arg, res)
    | otherwise -> Nothing
-- | Split a type into all argument types of its (curried) function arrows
-- and the final result type.
--
-- Types are expanded with 'coreView1' to look through to arrows. The
-- returned result type is the one found after the last arrow, without the
-- expansions that were applied to it while searching for further arrows.
splitFunTys :: TyConMap
            -> Type
            -> ([Type],Type)
splitFunTys m ty0 = loop [] ty0 ty0
  where
    -- acc: argument types seen so far, in reverse order
    -- unexpanded: the current result type before any coreView1 expansion
    loop acc unexpanded ty = case coreView1 m ty of
      Just ty' -> loop acc unexpanded ty'
      Nothing -> case tyView ty of
        FunTy arg res -> loop (arg : acc) res res
        _ -> (reverse acc, unexpanded)
-- | Split off all leading @forall@ binders ('Left') and function argument
-- types ('Right'), in order, together with the remaining result type.
-- No type expansion is performed; see 'splitCoreFunForallTy' for that.
splitFunForallTy :: Type
                 -> ([Either TyVar Type],Type)
splitFunForallTy = loop []
  where
    loop acc ty = case ty of
      ForAllTy tv body -> loop (Left tv : acc) body
      _ | FunTy arg res <- tyView ty -> loop (Right arg : acc) res
        | otherwise -> (reverse acc, ty)
-- | Rebuild a type from a result type and a list of binders: 'Left' binders
-- become @forall@ quantifiers and 'Right' binders become function
-- arguments, outermost first. This is the inverse of 'splitFunForallTy'.
mkPolyFunTy
  :: Type
  -> [Either TyVar Type]
  -> Type
mkPolyFunTy res = build
  where
    build [] = res
    build (Left tv : rest) = ForAllTy tv (build rest)
    build (Right argTy : rest) = mkFunTy argTy (build rest)
-- | Like 'splitFunForallTy', but expands types with 'coreView1' to find
-- further @forall@ binders and function arrows.
--
-- The returned result type is the one reached after the last binder,
-- without the expansions that were applied to it while searching.
splitCoreFunForallTy :: TyConMap
                     -> Type
                     -> ([Either TyVar Type], Type)
splitCoreFunForallTy tcm ty0 = loop [] ty0 ty0
  where
    loop acc unexpanded ty
      | Just ty' <- coreView1 tcm ty = loop acc unexpanded ty'
      | ForAllTy tv res <- ty = loop (Left tv : acc) res res
      | FunTy arg res <- tyView ty = loop (Right arg : acc) res res
      | otherwise = (reverse acc, unexpanded)
-- | Whether a type has at least one leading @forall@ binder or function
-- argument (according to 'splitFunForallTy', i.e. without expansion).
isPolyFunTy :: Type
            -> Bool
isPolyFunTy ty = case splitFunForallTy ty of
  ([], _) -> False
  _ -> True
-- | Whether a type, after expansion with 'coreView1', is a function type or
-- a @forall@ type.
isPolyFunCoreTy :: TyConMap
                -> Type
                -> Bool
isPolyFunCoreTy m ty = case coreView1 m ty of
  Just ty' -> isPolyFunCoreTy m ty'
  Nothing -> case ty of
    ForAllTy {} -> True
    _ | FunTy {} <- tyView ty -> True
      | otherwise -> False
-- | The attributes of an outermost 'AnnType', or @[]@ for any other type.
typeAttrs
  :: Type
  -> [Attr']
typeAttrs ty = case ty of
  AnnType attrs _ -> attrs
  _ -> []
-- | Whether a type is a function type after expansion (see 'splitFunTy').
isFunTy :: TyConMap
        -> Type
        -> Bool
isFunTy m ty = case splitFunTy m ty of
  Just _ -> True
  Nothing -> False
-- | The result type of applying a function of the given type to an
-- argument.
--
-- The function type is expanded with 'coreView1' as needed. The argument
-- itself is not inspected. Raises an error if the type is not a function
-- type.
applyFunTy :: TyConMap
           -> Type
           -> Type
           -> Type
applyFunTy m funTy arg
  | Just funTy' <- coreView1 m funTy = applyFunTy m funTy' arg
  | FunTy _ resTy <- tyView funTy = resTy
  | otherwise = error $ $(curLoc) ++ "Report as bug: not a FunTy"
-- | Reduce a type-family application using the first of the given
-- equations (left-hand-side patterns, right-hand side) that matches the
-- arguments (see 'funSubsts'). Returns 'Nothing' if none matches.
findFunSubst :: TyConMap -> [([Type],Type)] -> [Type] -> Maybe Type
findFunSubst tcm eqns args = foldr tryEqn Nothing eqns
  where
    -- Lazy right fold: later equations are only tried if this one fails.
    tryEqn eqn next = case funSubsts tcm eqn args of
      Nothing -> next
      found -> found
-- | Try to reduce a type-family application with a single equation.
--
-- @funSubsts tcm (lhs, rhs) args@ matches the patterns in @lhs@ against the
-- leading @args@ (see 'funSubst'). On success, the collected substitution
-- is applied to @rhs@, and any surplus arguments are applied to the result.
--
-- Returns 'Nothing' when a pattern does not match, or when there are fewer
-- arguments than patterns. Previously 'zip' silently dropped the unmatched
-- patterns in that case. An unsaturated application could then "match" and
-- produce a right-hand side with unbound type variables.
funSubsts :: TyConMap -> ([Type],Type) -> [Type] -> Maybe Type
funSubsts tcm (tcSubstLhs,tcSubstRhs) args
  | length argsLhs /= length tcSubstLhs = Nothing
  | otherwise = do
      tySubts <- foldl' (funSubst tcm) (Just []) (zip tcSubstLhs argsLhs)
      let tyRhs = uncurry substTyWith (unzip tySubts) tcSubstRhs
      -- With no surplus arguments this is just tyRhs.
      return (foldl' AppTy tyRhs argsRest)
  where
    -- Arguments consumed by the equation's patterns, and the surplus ones.
    (argsLhs, argsRest) = splitAtList tcSubstLhs args
-- | Match one left-hand-side pattern type against an argument type,
-- extending the substitution accumulated so far.
--
-- A 'Nothing' substitution stays 'Nothing', so a failure anywhere makes the
-- whole match fail. The pair is (pattern, argument).
funSubst
  :: TyConMap
  -> Maybe [(TyVar,Type)]
  -> (Type,Type)
  -> Maybe [(TyVar,Type)]
funSubst _ Nothing = const Nothing
funSubst tcm (Just s) = uncurry go
  where
    -- Pattern variable: bind it on first occurrence; a repeated occurrence
    -- must be alpha-equivalent to the earlier binding.
    go (VarTy nmF) ty = case lookup nmF s of
      Nothing -> Just ((nmF,ty):s)
      Just ty' | ty' `aeqType` ty -> Just s
      _ -> Nothing
    -- Otherwise, if the argument is itself a reducible type-family
    -- application, reduce it one step and retry. This comes after the
    -- variable case, so variables bind the unreduced argument.
    go ty1 (reduceTypeFamily tcm -> Just ty2) = go ty1 ty2
    -- Applications match component-wise, threading the substitution from
    -- the function part into the argument part.
    go (AppTy a1 r1) (AppTy a2 r2) = do
      s1 <- funSubst tcm (Just s) (a1, a2)
      funSubst tcm (Just s1) (r1, r2)
    -- Constants and literals must be alpha-equivalent to the argument.
    go ty1@(ConstTy _) ty2 =
      if ty1 `aeqType` ty2 then Just s else Nothing
    go ty1@(LitTy _) ty2 =
      if ty1 `aeqType` ty2 then Just s else Nothing
    -- Any other pattern shape (e.g. ForAllTy, AnnType) does not match.
    go _ _ = Nothing
-- | Try to reduce a type-family application by one step.
--
-- Families with built-in evaluation, when their arguments reduce to
-- literals (via 'litView' / 'symLitView'):
--
-- * @+@, @*@, @^@ and @-@ on naturals (@-@ only when the result is
--   non-negative);
-- * @<=?@, which yields the promoted constructors of the result kind's type
--   constructor, taken in order as (false, true);
-- * @CmpNat@ and @CmpSymbol@, which yield a promoted 'Ordering' constructor;
-- * @AppendSymbol@;
-- * @FLog@, @CLog@ (via 'clogBase'), @Log@, @GCD@, @LCM@, @Div@ and @Mod@
--   from @GHC.TypeLits.Extra@ / @GHC.TypeNats@, each within its domain
--   guard. @Log@ yields 'Nothing' when the argument is not an exact power of
--   the base.
--
-- Any other type family is reduced with its equations via 'findFunSubst'.
-- Returns 'Nothing' when no reduction applies.
reduceTypeFamily :: TyConMap -> Type -> Maybe Type
reduceTypeFamily tcm (tyView -> TyConApp tc tys)
  | nameUniq tc == getKey typeNatAddTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 + i2)))
  | nameUniq tc == getKey typeNatMulTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 * i2)))
  | nameUniq tc == getKey typeNatExpTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 ^ i2)))
  | nameUniq tc == getKey typeNatSubTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , let z = i1 - i2
  , z >= 0
  = Just (LitTy (NumTy z))
  | nameUniq tc == getKey typeNatLeqTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , Just (FunTyCon {tyConKind = tck}) <- lookupUniqMap tc tcm
  , (_,tyView -> TyConApp boolTcNm []) <- splitFunTys tcm tck
  , Just boolTc <- lookupUniqMap boolTcNm tcm
    -- Matched in a guard rather than a refutable 'let', so an unexpected
    -- constructor list falls through instead of crashing when forced.
  , [falseTc,trueTc] <- map (coerce . dcName) (tyConDataCons boolTc)
  = Just (mkTyConApp (if i1 <= i2 then trueTc else falseTc) [])
  | nameUniq tc == getKey typeNatCmpTyFamNameKey
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (orderingTy (compare i1 i2))
  | nameUniq tc == getKey typeSymbolCmpTyFamNameKey
  , [s1, s2] <- mapMaybe (symLitView tcm) tys
  = Just (orderingTy (compare s1 s2))
  | nameUniq tc == getKey typeSymbolAppendFamNameKey
  , [s1, s2] <- mapMaybe (symLitView tcm) tys
  = Just (LitTy (SymTy (s1 ++ s2)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.FLog", "GHC.TypeNats.FLog"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i1 > 1
  , i2 > 0
  = Just (LitTy (NumTy (smallInteger (integerLogBase# i1 i2))))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.CLog", "GHC.TypeNats.CLog"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , Just k <- clogBase i1 i2
  = Just (LitTy (NumTy (toInteger k)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.Log", "GHC.TypeNats.Log"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i1 > 1
  , i2 > 0
  = if i2 == 1
       then Just (LitTy (NumTy 0))
       -- i2 is an exact power of i1 iff its floor-log differs from that of
       -- its predecessor.
       else let z1 = integerLogBase# i1 i2
                z2 = integerLogBase# i1 (i2-1)
            in if isTrue# (z1 ==# z2)
                  then Nothing
                  else Just (LitTy (NumTy (smallInteger z1)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.GCD", "GHC.TypeNats.GCD"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 `gcd` i2)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.LCM", "GHC.TypeNats.LCM"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 `lcm` i2)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.Div", "GHC.TypeNats.Div"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i2 > 0
  = Just (LitTy (NumTy (i1 `div` i2)))
  | nameOcc tc `elem` ["GHC.TypeLits.Extra.Mod", "GHC.TypeNats.Mod"]
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i2 > 0
  = Just (LitTy (NumTy (i1 `mod` i2)))
  | Just (FunTyCon {tyConSubst = tcSubst}) <- lookupUniqMap tc tcm
  = findFunSubst tcm tcSubst tys
  where
    -- The promoted 'Ordering' constructor (GHC.Types.LT/EQ/GT) for a
    -- comparison result; shared by CmpNat and CmpSymbol.
    orderingTy :: Ordering -> Type
    orderingTy o = ConstTy $ TyCon $ case o of
      LT -> Name User "GHC.Types.LT" (getKey ordLTDataConKey) wiredInSrcSpan
      EQ -> Name User "GHC.Types.EQ" (getKey ordEQDataConKey) wiredInSrcSpan
      GT -> Name User "GHC.Types.GT" (getKey ordGTDataConKey) wiredInSrcSpan
reduceTypeFamily _ _ = Nothing
-- | Whether a type is an application of a type constructor that the
-- 'TyConMap' records as a 'FunTyCon' (a type family).
isTypeFamilyApplication :: TyConMap -> Type -> Bool
isTypeFamilyApplication tcm ty = case tyView ty of
  TyConApp tcNm _
    | Just (FunTyCon {}) <- lookupUniqMap tcNm tcm -> True
  _ -> False
-- | The numeric literal a type denotes, reducing type families (repeatedly)
-- as needed; 'Nothing' if it does not reduce to a numeric literal.
litView :: TyConMap -> Type -> Maybe Integer
litView m ty = case ty of
  LitTy (NumTy i) -> Just i
  _ -> reduceTypeFamily m ty >>= litView m
-- | The symbol literal a type denotes, reducing type families (repeatedly)
-- as needed; 'Nothing' if it does not reduce to a symbol literal.
symLitView :: TyConMap -> Type -> Maybe String
symLitView m ty = case ty of
  LitTy (SymTy s) -> Just s
  _ -> symLitView m =<< reduceTypeFamily m ty
-- | The type @forall (a :: Type) . a@. The variable has unique 0 and is
-- built with 'mkUnsafeSystemName'.
undefinedTy :: Type
undefinedTy = ForAllTy aTv (VarTy aTv)
  where
    aTv = TyVar (mkUnsafeSystemName "a" 0) 0 liftedTypeKind
-- | Whether a type is exactly the bare @Integer@ type constructor,
-- identified by GHC's 'integerTyConKey' unique.
isIntegerTy :: Type -> Bool
isIntegerTy ty = case ty of
  ConstTy (TyCon nm) -> nameUniq nm == getKey integerTyConKey
  _ -> False
-- | Normalise a type recursively.
--
-- * @Signal dom a@ is replaced by the normalised @a@.
-- * The sized-number constructors (@Bit@, @BitVector@, @Index@, @Signed@,
--   @Unsigned@) are kept, with their arguments normalised.
-- * A newtype is replaced by its normalised representation. If it has too
--   few arguments to instantiate, it is returned unchanged.
-- * Any other type-constructor application has its arguments normalised and
--   is then reduced one step with 'reduceTypeFamily' when possible.
-- * Function types and @forall@ bodies are normalised structurally.
-- * All other types are returned unchanged.
normalizeType :: TyConMap -> Type -> Type
normalizeType tcMap = normalise
  where
    normalise ty = case tyView ty of
      TyConApp tcNm args
        | nameOcc tcNm == "Clash.Signal.Internal.Signal"
        , [_,elTy] <- args
        -> normalise elTy
        | nameOcc tcNm `elem` keptTyConNames
        -> mkTyConApp tcNm (map normalise args)
        | otherwise
        -> normaliseTyConApp ty tcNm args
      FunTy argTy resTy -> mkFunTy (normalise argTy) (normalise resTy)
      OtherType (ForAllTy tv body) -> ForAllTy tv (normalise body)
      _ -> ty

    -- Newtypes unwrap; other constructors get normalised arguments and one
    -- attempt at type-family reduction.
    normaliseTyConApp ty tcNm args = case lookupUniqMap' tcMap tcNm of
      AlgTyCon {algTcRhs = NewTyCon _ nt} ->
        maybe ty normalise (newTyConInstRhs nt args)
      _ ->
        let rebuilt = mkTyConApp tcNm (map normalise args)
        in case reduceTypeFamily tcMap rebuilt of
             Just reduced -> reduced
             Nothing -> rebuilt

    -- Type constructors that are kept as-is (only their arguments change).
    keptTyConNames =
      [ "Clash.Sized.Internal.BitVector.Bit"
      , "Clash.Sized.Internal.BitVector.BitVector"
      , "Clash.Sized.Internal.Index.Index"
      , "Clash.Sized.Internal.Signed.Signed"
      , "Clash.Sized.Internal.Unsigned.Unsigned"
      ]
-- | Whether a type is an application of an algebraic type constructor that
-- the 'TyConMap' marks as a class ('isClassTc'). Unknown or non-algebraic
-- constructors, and non-constructor types, give 'False'.
isClassTy
  :: TyConMap
  -> Type
  -> Bool
isClassTy tcm ty = case tyView ty of
  TyConApp tcNm _
    | Just (AlgTyCon {isClassTc = cls}) <- lookupUniqMap tcNm tcm
    -> cls
  _ -> False