{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Core.DataCon
( DataCon (..)
, DcName
, ConTag
, dataConInstArgTys
)
where
#ifndef MIN_VERSION_unbound_generics
#define MIN_VERSION_unbound_generics(x,y,z)(1)
#endif
import Control.DeepSeq (NFData(..))
import Data.Hashable (Hashable)
import GHC.Generics (Generic)
import Unbound.Generics.LocallyNameless (Alpha(..),Subst(..))
import Unbound.Generics.LocallyNameless.Extra ()
#if MIN_VERSION_unbound_generics(0,3,0)
import Data.Monoid (All (..))
import Unbound.Generics.LocallyNameless (NthPatFind (..),
NamePatFind (..))
#endif
import Clash.Core.Name (Name (..))
import {-# SOURCE #-} Clash.Core.Type (TyName, Type)
import Clash.Util
-- | A data constructor.
--
-- The first three fields are strict; the three list fields are lazy.
data DataCon
  = MkData
  { dcName       :: !DcName
    -- ^ Name of the data constructor
  , dcTag        :: !ConTag
    -- ^ Integer tag of the constructor.
    -- NOTE(review): presumably its position among the constructors of its
    -- type; confirm against the code that builds 'DataCon's.
  , dcType       :: !Type
    -- ^ Type of the data constructor (presumably its full, quantified type;
    -- confirm)
  , dcUnivTyVars :: [TyName]
    -- ^ Type variables instantiated first by 'dataConInstArgTys'
    -- (the name suggests the universally quantified ones)
  , dcExtTyVars  :: [TyName]
    -- ^ Type variables instantiated after 'dcUnivTyVars' by
    -- 'dataConInstArgTys' (presumably the existentially quantified ones;
    -- confirm)
  , dcArgTys     :: [Type]
    -- ^ Argument types of the constructor, as instantiated by
    -- 'dataConInstArgTys'
  } deriving (Generic,NFData,Hashable)
-- | A 'DataCon' is rendered as the 'show' of its 'dcName' alone; none of the
-- other fields appear in the output.
instance Show DataCon where
  show dc = show (dcName dc)
-- | Two 'DataCon's are equal exactly when their 'dcName's are equal; the tag,
-- types and type variables are not compared.
instance Eq DataCon where
  dc1 == dc2 = dcName dc1 == dcName dc2
-- | 'DataCon's are ordered by their 'dcName' only, which agrees with the
-- name-based equality.
instance Ord DataCon where
  compare dc1 dc2 = compare (dcName dc1) (dcName dc2)
-- | A 'Name' that refers to a 'DataCon'.
type DcName = Name DataCon

-- | Integer tag stored in 'dcTag'.
type ConTag = Int
-- | Locally-nameless structure for 'DataCon'.
--
-- Every method treats a 'DataCon' as an opaque constant. It exposes no free
-- names and no pattern binders. It is returned unchanged by close, open,
-- swaps' and freshening. Comparison (aeq' and acompare') looks only at
-- 'dcName'.
--
-- NOTE(review): this presumably relies on the types stored in a 'DataCon'
-- being closed, so none of its type variables can be captured; confirm.
instance Alpha DataCon where
  -- Alpha-equivalence is decided solely by the constructor names.
  aeq' c dc1 dc2 = aeq' c (dcName dc1) (dcName dc2)
  -- The name-visiting function is never applied, so no free names are
  -- reported.
  fvAny' _ _ dc = pure dc
  -- Nothing inside to bind or unbind: closing and opening are the identity.
  close _ _ dc = dc
  open _ _ dc = dc
  -- Reports no pattern names.
  isPat _ = mempty
#if MIN_VERSION_unbound_generics(0,3,0)
  -- unbound-generics >= 0.3.0: the same answers wrapped in the newtypes
  -- All, NthPatFind and NamePatFind.
  isTerm _ = All True
  -- No binder is found: the index is passed through as Left.
  nthPatFind _ = NthPatFind Left
  -- No binder with the requested name is found (Left).
  namePatFind _ = NamePatFind (const (Left 0))
#else
  -- Older unbound-generics: the same answers without the newtype wrappers.
  isTerm _ = True
  nthPatFind _ = Left
  namePatFind _ _ = Left 0
#endif
  -- No names inside: permutations leave it unchanged, and freshening returns
  -- it as-is together with an empty permutation.
  swaps' _ _ dc = dc
  lfreshen' _ dc cont = cont dc mempty
  freshen' _ dc = return (dc,mempty)
  -- Ordering is decided solely by the constructor names, consistent with aeq'.
  acompare' c dc1 dc2 = acompare' c (dcName dc1) (dcName dc2)
-- | Substitution never changes a 'DataCon', whatever is being substituted:
-- both the single and the simultaneous form are the identity.
instance Subst a DataCon where
  subst _ _ = id
  substs _ = id
-- | Instantiate the argument types of a data constructor.
--
-- The type variables of 'dcUnivTyVars' followed by those of 'dcExtTyVars'
-- are paired positionally with the supplied types. That simultaneous
-- substitution is then applied to every type in 'dcArgTys'.
--
-- Returns 'Nothing' when the number of supplied types differs from the
-- combined number of those type variables.
dataConInstArgTys :: DataCon -> [Type] -> Maybe [Type]
dataConInstArgTys dc instTys
  | length boundTvs /= length instTys = Nothing
  | otherwise                         = Just (map instantiate (dcArgTys dc))
  where
    boundTvs = map nameOcc (dcUnivTyVars dc ++ dcExtTyVars dc)

    instantiate :: Type -> Type
    instantiate = substs (zip boundTvs instTys)