{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Core.Type
( Type (..)
, TypeView (..)
, ConstTy (..)
, LitTy (..)
, Kind
, KindOrType
, KiName
, KiOccName
, TyName
, TyOccName
, TyVar
, tyView
, coreView
, typeKind
, mkTyConTy
, mkFunTy
, mkTyConApp
, splitFunTy
, splitFunTys
, splitFunForallTy
, splitCoreFunForallTy
, splitTyConAppM
, isPolyFunTy
, isPolyFunCoreTy
, isPolyTy
, isFunTy
, applyFunTy
, applyTy
, findFunSubst
, reduceTypeFamily
, undefinedTy
, isIntegerTy
, normalizeType
)
where
import Control.DeepSeq as DS
import Data.Hashable (Hashable)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.List (foldl',isPrefixOf)
import Data.Maybe (isJust, mapMaybe)
import GHC.Base (isTrue#,(==#))
import GHC.Generics (Generic(..))
import GHC.Integer (smallInteger)
import GHC.Integer.Logarithms (integerLogBase#)
import Unbound.Generics.LocallyNameless (Alpha(..),Bind,Fresh,
Subst(..),SubstName(..),
acompare,aeq,bind,embed,
gacompare,gaeq,gfvAny,
runFreshM,unbind)
import Unbound.Generics.LocallyNameless.Extra ()
import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)
import Clash.Core.DataCon
import Clash.Core.Name
import Clash.Core.Subst
import {-# SOURCE #-} Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.TysPrim
import Clash.Core.Var
import Clash.Util
-- | Types in CoreHW.
--
-- Type-variable occurrences carry their kind (see 'typeKind'). Type
-- applications, including function types (built by 'mkFunTy' as
-- @AppTy (AppTy (ConstTy Arrow) a) b@), are represented with 'AppTy'.
data Type
  = VarTy    !Kind !TyName       -- ^ Type variable, annotated with its kind
  | ConstTy  !ConstTy            -- ^ Type constant
  | ForAllTy !(Bind TyVar Type)  -- ^ Polymorphic type
  | AppTy    !Type !Type         -- ^ Type application
  | LitTy    !LitTy              -- ^ Type literal
  deriving (Show,Generic,NFData,Hashable)
-- | An easier-to-match view on a 'Type', produced by 'tyView'.
data TypeView
  = FunTy     !Type !Type        -- ^ Function type: argument and result type
  | TyConApp  !TyConName [Type]  -- ^ Type constructor applied to zero or more arguments
  | OtherType !Type              -- ^ Any other type, unchanged
  deriving Show
-- | Type constants.
data ConstTy
  = TyCon !TyConName  -- ^ A type constructor
  | Arrow             -- ^ The function-type constructor (see 'mkFunTy')
  deriving (Show,Generic,NFData,Alpha,Hashable)
-- | Type-level literals.
data LitTy
  = NumTy !Integer  -- ^ Numeric literal; 'typeKind' gives it 'typeNatKind'
  | SymTy !String   -- ^ Symbol literal; 'typeKind' gives it 'typeSymbolKind'
  deriving (Show,Generic,NFData,Alpha,Hashable)
-- | Kinds are represented with the same datatype as types.
type Kind = Type
-- | Either a kind or a type, e.g. the argument passed to 'kindFunResult'.
type KindOrType = Type
-- | Name of a type (variable or constructor).
type TyName = Name Type
-- | Occurrence name of a 'TyName' (obtained with 'nameOcc'); this is the key
-- used in type substitutions in this module.
type TyOccName = OccName Type
-- | Name of a kind.
type KiName = Name Kind
-- | Occurrence name of a 'KiName'.
type KiOccName = OccName Kind
-- | Alpha-equivalence machinery for 'Type'.
--
-- A 'VarTy' is treated by its name only:
--
-- * the free-variable traversal ('fvAny'') visits the name but not the kind
--   annotation;
-- * equality ('aeq'') and ordering ('acompare'') of two variables compare
--   only their names, ignoring their kinds.
--
-- All other cases use the generic implementations.
instance Alpha Type where
  fvAny' c nfn (VarTy t n) = fmap (VarTy t) $ fvAny' c nfn n
  fvAny' c nfn t = fmap to . gfvAny c nfn $ from t

  aeq' c (VarTy _ n) (VarTy _ m) = aeq' c n m
  aeq' c t1 t2 = gaeq c (from t1) (from t2)

  acompare' c (VarTy _ n) (VarTy _ m) = acompare' c n m
  acompare' c t1 t2 = gacompare c (from t1) (from t2)
-- | Type literals contain no variables: substitution is the identity.
instance Subst a LitTy where
  subst _ _ = id
  substs _  = id

-- | Type constants contain no variables: substitution is the identity.
instance Subst a ConstTy where
  subst _ _ = id
  substs _  = id
-- | Substituting terms into a type: relies entirely on the class defaults.
instance Subst Term Type

-- | Substituting types into a type: every 'VarTy' is a substitutable
-- variable, identified by the occurrence name of its 'TyName'.
instance Subst Type Type where
  isvar ty = case ty of
    VarTy _ v -> Just (SubstName (nameOcc v))
    _         -> Nothing
-- | Two types are equal when they are alpha-equivalent.
instance Eq Type where
  t1 == t2 = aeq t1 t2

-- | Types are ordered consistently with alpha-equivalence.
instance Ord Type where
  compare t1 t2 = acompare t1 t2
-- | Classify a 'Type' as a function type, a (possibly nullary)
-- type-constructor application, or something else.
tyView :: Type -> TypeView
tyView ty = case ty of
  AppTy {} -> case splitTyAppM ty of
    Just (ConstTy Arrow, [argTy, resTy]) -> FunTy argTy resTy
    Just (ConstTy (TyCon tc), args)      -> TyConApp tc args
    _                                    -> OtherType ty
  ConstTy (TyCon tc) -> TyConApp tc []
  _                  -> OtherType ty
-- | Perform one expansion step on a type-constructor application:
--
-- * @Clash.Signal.Internal.Signal dom a@ is looked through to @a@;
-- * a newtype applied to at least as many arguments as it has type
--   variables is replaced by its instantiated representation;
-- * otherwise the application is reduced with 'reduceTypeFamily'.
--
-- Returns 'Nothing' when no step applies. The type constructor of a
-- non-'Signal' application must be present in the map ('HashMap.!').
coreView :: HashMap TyConOccName TyCon -> Type -> Maybe Type
coreView tcMap ty
  | TyConApp tcNm args <- tyView ty
  = expand tcNm args
  | otherwise
  = Nothing
  where
    expand tcNm args
      | name2String tcNm == "Clash.Signal.Internal.Signal"
      , [_, elTy] <- args
      = Just elTy
      | otherwise
      = case tcMap HashMap.! nameOcc tcNm of
          AlgTyCon {algTcRhs = (NewTyCon _ nt)} -> newTyConInstRhs nt args
          _                                     -> reduceTypeFamily tcMap ty
-- | Instantiate a newtype's representation @(tvs, rhs)@ with arguments @tys@.
--
-- The first @length tvs@ arguments are substituted for the newtype's type
-- variables; any remaining arguments are applied to the result. Returns
-- 'Nothing' when there are fewer arguments than type variables.
newTyConInstRhs :: ([TyName],Type) -> [Type] -> Maybe Type
newTyConInstRhs (tvs,ty) tys =
  if length tys < length tvs
    then Nothing
    else Just (foldl AppTy instantiated extraTys)
  where
    (boundTys, extraTys) = splitAtList tvs tys
    instantiated         = substTys (zip (map nameOcc tvs) boundTys) ty
-- | Build the function type @argTy -> resTy@.
mkFunTy :: Type -> Type -> Type
mkFunTy argTy resTy = AppTy (AppTy (ConstTy Arrow) argTy) resTy

-- | Apply a type constructor to a list of argument types (left to right).
mkTyConApp :: TyConName -> [Type] -> Type
mkTyConApp tc args = foldl AppTy (ConstTy (TyCon tc)) args

-- | A type constructor on its own, without arguments.
mkTyConTy :: TyConName -> Type
mkTyConTy tc = ConstTy (TyCon tc)
-- | Split a type-constructor application into the constructor and its
-- arguments. Gives 'Nothing' for any other shape, including function types.
splitTyConAppM :: Type -> Maybe (TyConName,[Type])
splitTyConAppM ty = case tyView ty of
  TyConApp tc args -> Just (tc, args)
  _                -> Nothing
-- | Is the type a bare type constructor that the map records as a
-- 'SuperKindTyCon'? The constructor must be present in the map.
isSuperKind :: HashMap TyConOccName TyCon -> Type -> Bool
isSuperKind tcMap ty = case ty of
  ConstTy (TyCon tcNm)
    | SuperKindTyCon {} <- tcMap HashMap.! nameOcc tcNm
    -> True
  _ -> False
-- | Compute the kind of a type.
--
-- The equations are tried in order and some overlap. Function types are
-- recognised (via 'tyView') before general type-constructor applications,
-- and those before plain 'AppTy'. The final equation is therefore only
-- reached for a naked 'Arrow', which is an error. Type constructors are
-- looked up with 'HashMap.!', so they must be present in the map.
typeKind :: HashMap TyConOccName TyCon -> Type -> Kind
typeKind _ (VarTy k _) = k
-- The kind of a forall type is the kind of its body.
typeKind m (ForAllTy b) = let (_,ty) = runFreshM $ unbind b
                          in typeKind m ty
typeKind _ (LitTy (NumTy _)) = typeNatKind
typeKind _ (LitTy (SymTy _)) = typeSymbolKind
-- A function type has its result's kind when that is a super kind;
-- otherwise it has 'liftedTypeKind'.
typeKind m (tyView -> FunTy _arg res)
    | isSuperKind m k = k
    | otherwise = liftedTypeKind
  where k = typeKind m res
-- Instantiate the constructor's kind with each argument in turn.
typeKind m (tyView -> TyConApp tc args) = foldl' kindFunResult (tyConKind (m HashMap.! nameOcc tc)) args
typeKind m (AppTy fun arg) = kindFunResult (typeKind m fun) arg
typeKind _ (ConstTy ct) = error $ $(curLoc) ++ "typeKind: naked ConstTy: " ++ show ct
-- | The kind obtained by applying something of kind @k@ to @arg@.
--
-- For a function kind this is its result kind. For a forall kind it is the
-- body with the bound variable instantiated to @arg@. Any other kind is an
-- error.
kindFunResult :: Kind -> KindOrType -> Kind
kindFunResult k arg = case k of
  -- 'tyView' never reports a 'ForAllTy' as a 'FunTy', so these two
  -- alternatives are disjoint.
  ForAllTy b ->
    let (kv, ki) = runFreshM (unbind b)
    in  substKindWith [(nameOcc (varName kv), arg)] ki
  (tyView -> FunTy _ res) -> res
  _ -> error $ $(curLoc) ++ "kindFunResult: " ++ show (k, arg)
-- | Is the type a forall, or a function type whose result is (recursively)
-- polymorphic? Does not look through 'coreView'.
isPolyTy :: Type -> Bool
isPolyTy ty = case ty of
  ForAllTy _ -> True
  _ -> case tyView ty of
    FunTy _ res -> isPolyTy res
    _           -> False
-- | Split a function type into argument and result types, first expanding
-- with 'coreView' as far as possible. Gives 'Nothing' if the type is not a
-- function type.
splitFunTy :: HashMap TyConOccName TyCon -> Type -> Maybe (Type, Type)
splitFunTy m ty = case coreView m ty of
  Just ty' -> splitFunTy m ty'
  Nothing  -> case tyView ty of
    FunTy arg res -> Just (arg, res)
    _             -> Nothing
-- | Split off all argument types of a function type, looking through
-- 'coreView'.
--
-- The returned result type is the type as it was *before* any trailing
-- 'coreView' expansions that did not reveal another function type.
splitFunTys :: HashMap TyConOccName TyCon -> Type -> ([Type],Type)
splitFunTys m ty0 = collect [] ty0 ty0
  where
    -- revArgs: arguments found so far, most recent first
    -- resTy:   result type to report if no further arrow is found
    -- cur:     the (possibly expanded) type being inspected
    collect revArgs resTy cur = case coreView m cur of
      Just cur' -> collect revArgs resTy cur'
      Nothing   -> case tyView cur of
        FunTy arg res -> collect (arg : revArgs) res res
        _             -> (reverse revArgs, resTy)
-- | Split a type into its leading binders and the remaining type.
--
-- Binders are listed in order: foralls as 'Left' type variables, function
-- arguments as 'Right' types. Does not look through 'coreView'.
splitFunForallTy :: Type -> ([Either TyVar Type],Type)
splitFunForallTy = loop []
  where
    loop acc ty = case ty of
      ForAllTy b ->
        let (tv, body) = runFreshM (unbind b)
        in  loop (Left tv : acc) body
      _ -> case tyView ty of
        FunTy arg res -> loop (Right arg : acc) res
        _             -> (reverse acc, ty)
-- | Like 'splitFunForallTy', but looks through 'coreView' expansions.
--
-- The returned remaining type is the one from before any trailing
-- expansions that did not reveal another binder.
splitCoreFunForallTy :: HashMap TyConOccName TyCon
                     -> Type
                     -> ([Either TyVar Type], Type)
splitCoreFunForallTy tcm ty0 = loop [] ty0 ty0
  where
    loop acc resTy cur = case coreView tcm cur of
      Just cur' -> loop acc resTy cur'
      Nothing -> case cur of
        ForAllTy b ->
          let (tv, body) = runFreshM (unbind b)
          in  loop (Left tv : acc) body body
        _ -> case tyView cur of
          FunTy arg res -> loop (Right arg : acc) res res
          _             -> (reverse acc, resTy)
-- | Does the type start with at least one forall or function argument (see
-- 'splitFunForallTy')?
isPolyFunTy :: Type -> Bool
isPolyFunTy ty = case splitFunForallTy ty of
  ([], _) -> False
  _       -> True
-- | Is the type, after 'coreView' expansion, a function type or a forall?
isPolyFunCoreTy :: HashMap TyConOccName TyCon -> Type -> Bool
isPolyFunCoreTy m ty = case coreView m ty of
  Just ty' -> isPolyFunCoreTy m ty'
  Nothing  -> case tyView ty of
    FunTy {}               -> True
    OtherType (ForAllTy _) -> True
    _                      -> False
-- | Is the type, after 'coreView' expansion, a function type?
isFunTy :: HashMap TyConOccName TyCon -> Type -> Bool
isFunTy m ty = case splitFunTy m ty of
  Just _  -> True
  Nothing -> False
-- | The result type of applying a function type to an argument.
--
-- Looks through 'coreView'. The argument itself is not inspected. It is an
-- error if the type is not a function type.
applyFunTy :: HashMap TyConOccName TyCon -> Type -> Type -> Type
applyFunTy m ty arg = case coreView m ty of
  Just ty' -> applyFunTy m ty' arg
  Nothing  -> case tyView ty of
    FunTy _ resTy -> resTy
    _             -> error $ $(curLoc) ++ "Report as bug: not a FunTy"
-- | Instantiate a forall type with a kind/type argument.
--
-- Looks through 'coreView'. It is an error if the type is not a forall.
applyTy :: Fresh m
        => HashMap TyConOccName TyCon
        -> Type
        -> KindOrType
        -> m Type
applyTy tcm ty arg = case coreView tcm ty of
  Just ty' -> applyTy tcm ty' arg
  Nothing  -> case ty of
    ForAllTy b -> do
      (tv, body) <- unbind b
      return (substTy (nameOcc (varName tv)) arg body)
    _ -> error ($(curLoc) ++ "applyTy: not a forall type:\n" ++ show ty ++ "\nArg:\n" ++ show arg)
-- | Split a type application into its head and its arguments, in
-- application order. Gives 'Nothing' if the type is not an 'AppTy'.
splitTyAppM :: Type -> Maybe (Type, [Type])
splitTyAppM ty = case ty of
  AppTy {} -> Just (collect ty [])
  _        -> Nothing
  where
    -- Walk down the left spine, pushing arguments onto the front so that
    -- they come out in application order.
    collect (AppTy fun arg) acc = collect fun (arg : acc)
    collect hd              acc = (hd, acc)
-- | Try a type family's equations in order.
--
-- Returns the result of the first equation whose left-hand side matches the
-- arguments (see 'funSubsts'), or 'Nothing' if none does.
findFunSubst :: HashMap TyConOccName TyCon -> [([Type],Type)] -> [Type] -> Maybe Type
findFunSubst tcm tcSubsts args = foldr tryEqn Nothing tcSubsts
  where
    -- Later equations are only tried when this one fails to match.
    tryEqn eqn rest = maybe rest Just (funSubsts tcm eqn args)
-- | Try to apply one type-family equation @(lhs patterns, rhs)@ to @args@.
--
-- Matching proceeds pattern-by-pattern with 'funSubst'. On success the RHS
-- is instantiated with the resulting substitution, and any arguments beyond
-- the patterns are applied to it. Gives 'Nothing' if matching fails, or if
-- there are fewer arguments than patterns.
funSubsts :: HashMap TyConOccName TyCon -> ([Type],Type) -> [Type] -> Maybe Type
funSubsts tcm (tcSubstLhs,tcSubstRhs) args
  -- 'zip' would silently drop the unmatched patterns, leaving their pattern
  -- variables unbound in the instantiated RHS; refuse such applications.
  | length args < length tcSubstLhs = Nothing
  | otherwise = do
      tySubts <- foldl' (funSubst tcm) (Just []) (zip tcSubstLhs matchedArgs)
      -- Folding over an empty list of extra arguments yields the RHS itself.
      return (foldl' AppTy (substTys tySubts tcSubstRhs) extraArgs)
  where
    (matchedArgs, extraArgs) = splitAt (length tcSubstLhs) args
-- | Match one @(pattern, argument)@ pair of a type-family equation.
--
-- On success the accumulated substitution (pattern-variable occurrence name
-- to type) is extended. A previous failure ('Nothing') is propagated
-- unchanged. The clauses of @go@ are tried in order:
--
-- * A pattern variable binds to the argument. If it is already bound, the
--   argument must be '==' (alpha-equivalent) to the earlier binding.
-- * Otherwise, if the argument reduces with 'reduceTypeFamily', matching
--   continues against the reduct.
-- * A literal pattern must equal the argument.
-- * Applications of the same type constructor are matched argument-wise.
-- * Anything else fails.
funSubst
  :: HashMap TyConOccName TyCon
  -> Maybe [(TyOccName,Type)]
  -> (Type,Type)
  -> Maybe [(TyOccName,Type)]
funSubst _ Nothing = const Nothing
funSubst tcm (Just s) = uncurry go
  where
    go (VarTy _ (nameOcc -> nmF)) ty = case lookup nmF s of
      Nothing -> Just ((nmF,ty):s)
      Just ty' | ty' == ty -> Just s
      _ -> Nothing
    -- Only reached when the pattern is not a variable.
    go ty1 (reduceTypeFamily tcm -> Just ty2) = go ty1 ty2
    go ty1@(LitTy _) ty2 = if ty1 == ty2 then Just s else Nothing
    go (tyView -> TyConApp tc argTys) (tyView -> TyConApp tc' argTys')
      | tc == tc'
      = foldl' (funSubst tcm) (Just s) (zip argTys argTys')
    go _ _ = Nothing
-- | Try to reduce a type-family application by one step.
--
-- Some families are evaluated directly when two of their arguments reduce
-- to numeric literals (see 'litView'):
--
-- * the type-level natural operators @+@, @*@, @^@, @-@ and @<=?@, from
--   @GHC.TypeNats@ on GHC >= 8.2 and from @GHC.TypeLits@ before that;
-- * the @GHC.TypeLits.Extra@ families @FLog@, @CLog@, @Log@, @GCD@, @LCM@,
--   @Div@ and @Mod@.
--
-- Any other 'FunTyCon' is reduced using its equations via 'findFunSubst'.
-- Returns 'Nothing' when no reduction applies, including when an operation
-- is undefined for the given literals.
reduceTypeFamily :: HashMap TyConOccName TyCon -> Type -> Maybe Type
reduceTypeFamily tcm (tyView -> TyConApp tc tys)
  | name2String tc == natOp "+"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 + i2)))

  | name2String tc == natOp "*"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 * i2)))

  | name2String tc == natOp "^"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 ^ i2)))

  -- Natural-number subtraction is only defined for a non-negative result;
  -- GHC leaves e.g. @1 - 2@ stuck, so do not reduce it either. A negative
  -- 'NumTy' is not a valid Nat, and used as the exponent of a later @^@ it
  -- would make '^' fail with "Negative exponent".
  | name2String tc == natOp "-"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i1 >= i2
  = Just (LitTy (NumTy (i1 - i2)))

  -- The result kind (a Bool-like type) is found from the family's kind. Its
  -- first data constructor is taken as False and the second as True. A
  -- pattern guard (instead of an irrefutable @let@) makes an unexpected
  -- constructor count fall through rather than crash when the result is
  -- forced.
  | name2String tc == natOp "<=?"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , Just (FunTyCon {tyConKind = tck}) <- HashMap.lookup (nameOcc tc) tcm
  , (_,tyView -> TyConApp boolTcNm []) <- splitFunTys tcm tck
  , Just boolTc <- HashMap.lookup (nameOcc boolTcNm) tcm
  , [falseTc,trueTc] <- map (coerceName . dcName) (tyConDataCons boolTc)
  = Just (mkTyConApp (if i1 <= i2 then trueTc else falseTc) [])

  -- Floor of the logarithm of i2 in base i1.
  | name2String tc == "GHC.TypeLits.Extra.FLog"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i1 > 1
  , i2 > 0
  = Just (LitTy (NumTy (smallInteger (integerLogBase# i1 i2))))

  | name2String tc == "GHC.TypeLits.Extra.CLog"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , Just k <- clogBase i1 i2
  = Just (LitTy (NumTy (toInteger k)))

  -- Exact logarithm: only reduces when i2 is an exact power of i1. The
  -- floors of log(i2) and log(i2-1) differ exactly in that case.
  | name2String tc == "GHC.TypeLits.Extra.Log"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i1 > 1
  , i2 > 0
  = if i2 == 1
       then Just (LitTy (NumTy 0))
       else let z1 = integerLogBase# i1 i2
                z2 = integerLogBase# i1 (i2-1)
            in  if isTrue# (z1 ==# z2)
                   then Nothing
                   else Just (LitTy (NumTy (smallInteger z1)))

  | name2String tc == "GHC.TypeLits.Extra.GCD"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 `gcd` i2)))

  | name2String tc == "GHC.TypeLits.Extra.LCM"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  = Just (LitTy (NumTy (i1 `lcm` i2)))

  | name2String tc == "GHC.TypeLits.Extra.Div"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i2 > 0
  = Just (LitTy (NumTy (i1 `div` i2)))

  | name2String tc == "GHC.TypeLits.Extra.Mod"
  , [i1, i2] <- mapMaybe (litView tcm) tys
  , i2 > 0
  = Just (LitTy (NumTy (i1 `mod` i2)))

  -- Any other type family: use its equations.
  | Just (FunTyCon {tyConSubst = tcSubst}) <- HashMap.lookup (nameOcc tc) tcm
  = findFunSubst tcm tcSubst tys
  where
    -- Fully-qualified name of a built-in type-level natural operator; the
    -- defining module moved to GHC.TypeNats in GHC 8.2.
#if MIN_VERSION_ghc(8,2,0)
    natOp op = "GHC.TypeNats." ++ op
#else
    natOp op = "GHC.TypeLits." ++ op
#endif
reduceTypeFamily _ _ = Nothing
-- | View a type as a numeric type-level literal. Type families are reduced
-- (repeatedly, with 'reduceTypeFamily') to get there.
litView :: HashMap TyConOccName TyCon -> Type -> Maybe Integer
litView m ty = case ty of
  LitTy (NumTy i) -> Just i
  _               -> reduceTypeFamily m ty >>= litView m
-- | The type @forall a . a@, where @a@ is a system name of kind
-- 'liftedTypeKind'.
undefinedTy :: Type
undefinedTy = ForAllTy (bind aTv (VarTy liftedTypeKind aNm))
  where
    aNm = string2SystemName "a"
    aTv = TyVar aNm (embed liftedTypeKind)
-- | Is the type a bare type constructor whose name starts with
-- @GHC.Integer.Type.Integer@?
isIntegerTy :: Type -> Bool
isIntegerTy ty = case ty of
  ConstTy (TyCon nm) -> "GHC.Integer.Type.Integer" `isPrefixOf` name2String nm
  _                  -> False
-- | Normalise a type:
--
-- * @Clash.Signal.Internal.Signal dom a@ becomes (the normalised) @a@;
-- * applications of the Clash sized types @Bit@, @BitVector@, @Index@,
--   @Signed@ and @Unsigned@ are kept, with their arguments normalised;
-- * other newtypes are unwrapped when enough arguments are supplied;
-- * type-family applications are reduced after normalising their
--   arguments, and the reduct is normalised again;
-- * function types and foralls are normalised structurally.
--
-- Any other type is returned unchanged. Non-'Signal', non-sized type
-- constructors must be present in the map ('HashMap.!').
normalizeType :: HashMap TyConOccName TyCon -> Type -> Type
normalizeType tcMap = go
  where
    go ty = case tyView ty of
      TyConApp tcNm args
        | name2String tcNm == "Clash.Signal.Internal.Signal"
        , [_,elTy] <- args
        -> go elTy
        -- Kept even though they are implemented as newtypes; presumably
        -- because the constructor itself determines the representation (bit
        -- width) — NOTE(review): confirm.
        | name2String tcNm `elem` keptTyCons
        -> mkTyConApp tcNm (map go args)
        | otherwise
        -> case tcMap HashMap.! nameOcc tcNm of
             AlgTyCon {algTcRhs = (NewTyCon _ nt)}
               -> case newTyConInstRhs nt args of
                    Just ty' -> go ty'
                    Nothing  -> ty
             _ -> let args' = map go args
                      ty'   = mkTyConApp tcNm args'
                  in  case reduceTypeFamily tcMap ty' of
                        -- The reduct may itself be a 'Signal', a newtype, or
                        -- another type-family application, so normalise it
                        -- as well instead of returning it as-is.
                        Just ty'' -> go ty''
                        Nothing   -> ty'
      FunTy ty1 ty2 -> mkFunTy (go ty1) (go ty2)
      (OtherType (ForAllTy (unsafeUnbind -> (tyvar,ty'))))
        -> ForAllTy (bind tyvar (go ty'))
      _ -> ty

    keptTyCons =
      [ "Clash.Sized.Internal.BitVector.Bit"
      , "Clash.Sized.Internal.BitVector.BitVector"
      , "Clash.Sized.Internal.Index.Index"
      , "Clash.Sized.Internal.Signed.Signed"
      , "Clash.Sized.Internal.Unsigned.Unsigned"
      ]