{-# LANGUAGE TemplateHaskell #-}
module Data.Singletons.Promote.Defun where
import Language.Haskell.TH.Desugar
import Data.Singletons.Promote.Monad
import Data.Singletons.Promote.Type
import Data.Singletons.Names
import Language.Haskell.TH.Syntax
import Data.Singletons.Util
import Control.Monad
defunInfo :: DInfo -> PrM [DDec]
defunInfo (DTyConI dec _instances) = buildDefunSyms dec
defunInfo (DPrimTyConI _name _numArgs _unlifted) =
fail $ "Building defunctionalization symbols of primitive " ++
"type constructors not supported"
defunInfo (DVarI _name _ty _mdec) =
fail "Building defunctionalization symbols of values not supported"
defunInfo (DTyVarI _name _ty) =
fail "Building defunctionalization symbols of type variables not supported"
defunInfo (DPatSynI {}) =
fail "Building defunctionalization symbols of pattern synonyms not supported"
buildDefunSyms :: DDec -> PrM [DDec]
buildDefunSyms (DDataD _new_or_data _cxt tyName tvbs ctors _derivings) =
buildDefunSymsDataD tyName tvbs ctors
buildDefunSyms (DClosedTypeFamilyD (DTypeFamilyHead name tvbs result_sig _) _) = do
let arg_m_kinds = map extractTvbKind tvbs
defunctionalize name arg_m_kinds (resultSigToMaybeKind result_sig)
buildDefunSyms (DOpenTypeFamilyD (DTypeFamilyHead name tvbs result_sig _)) = do
let arg_kinds = map (default_to_star . extractTvbKind) tvbs
res_kind = default_to_star (resultSigToMaybeKind result_sig)
default_to_star Nothing = Just DStarT
default_to_star (Just k) = Just k
defunctionalize name arg_kinds res_kind
buildDefunSyms (DTySynD name tvbs _type) = do
let arg_m_kinds = map extractTvbKind tvbs
defunctionalize name arg_m_kinds Nothing
buildDefunSyms (DClassD _cxt name tvbs _fundeps _members) = do
let arg_m_kinds = map extractTvbKind tvbs
defunctionalize name arg_m_kinds (Just (DConT constraintName))
buildDefunSyms _ = fail $ "Defunctionalization symbols can only be built for " ++
"type families and data declarations"
buildDefunSymsDataD :: Name -> [DTyVarBndr] -> [DCon] -> PrM [DDec]
buildDefunSymsDataD tyName tvbs ctors = do
let res_ty = foldType (DConT tyName) (map tvbToType tvbs)
res_ki <- promoteType res_ty
concatMapM (promoteCtor res_ki) ctors
where
promoteCtor :: DKind -> DCon -> PrM [DDec]
promoteCtor promotedKind ctor = do
let (name, arg_tys) = extractNameTypes ctor
arg_kis <- mapM promoteType arg_tys
defunctionalize name (map Just arg_kis) (Just promotedKind)
defunctionalize :: Name -> [Maybe DKind] -> Maybe DKind -> PrM [DDec]
defunctionalize name m_arg_kinds' m_res_kind' = do
let (m_arg_kinds, m_res_kind) = eta_expand (noExactTyVars m_arg_kinds')
(noExactTyVars m_res_kind')
num_args = length m_arg_kinds
sat_name = promoteTySym name num_args
tvbNames <- replicateM num_args $ qNewName "t"
let mk_rhs ns = foldType (DConT name) (map DVarT ns)
sat_dec = DTySynD sat_name (zipWith mk_tvb tvbNames m_arg_kinds) (mk_rhs tvbNames)
other_decs <- go (num_args - 1) (reverse m_arg_kinds) m_res_kind mk_rhs
return $ sat_dec : other_decs
where
mk_tvb :: Name -> Maybe DKind -> DTyVarBndr
mk_tvb tvb_name Nothing = DPlainTV tvb_name
mk_tvb tvb_name (Just k) = DKindedTV tvb_name k
eta_expand :: [Maybe DKind] -> Maybe DKind -> ([Maybe DKind], Maybe DKind)
eta_expand m_arg_kinds Nothing = (m_arg_kinds, Nothing)
eta_expand m_arg_kinds (Just res_kind) =
let (_, _, argKs, resultK) = unravel res_kind
in (m_arg_kinds ++ (map Just argKs), Just resultK)
go :: Int -> [Maybe DKind] -> Maybe DKind
-> ([Name] -> DType)
-> PrM [DDec]
go _ [] _ _ = return []
go n (m_arg : m_args) m_result mk_rhs = do
fst_name : rest_names <- replicateM (n + 1) (qNewName "l")
extra_name <- qNewName "arg"
let data_name = promoteTySym name n
next_name = promoteTySym name (n+1)
con_name = prefixName "" ":" $ suffixName "KindInference" "###" data_name
m_tyfun = buildTyFun_maybe m_arg m_result
arg_params = zipWith mk_tvb rest_names (reverse m_args)
tyfun_param = mk_tvb fst_name m_tyfun
arg_names = map extractTvbName arg_params
params = arg_params ++ [tyfun_param]
con_eq_ct = DConPr sameKindName `DAppPr` lhs `DAppPr` rhs
where
lhs = foldType (DConT data_name) (map DVarT arg_names) `apply` (DVarT extra_name)
rhs = foldType (DConT next_name) (map DVarT (arg_names ++ [extra_name]))
con_decl = DCon [DPlainTV extra_name]
[con_eq_ct]
con_name
(DNormalC False [])
Nothing
data_decl = DDataD Data [] data_name params [con_decl] []
app_eqn = DTySynEqn [ foldType (DConT data_name)
(map DVarT rest_names)
, DVarT fst_name ]
(mk_rhs (rest_names ++ [fst_name]))
app_decl = DTySynInstD applyName app_eqn
suppress = DInstanceD Nothing []
(DConT suppressClassName `DAppT` DConT data_name)
[DLetDec $ DFunD suppressMethodName
[DClause []
((DVarE 'snd) `DAppE`
mkTupleDExp [DConE con_name,
mkTupleDExp []])]]
mk_rhs' ns = foldType (DConT data_name) (map DVarT ns)
decls <- go (n - 1) m_args (addStar_maybe (buildTyFun_maybe m_arg m_result)) mk_rhs'
return $ suppress : data_decl : app_decl : decls
buildTyFun :: DKind -> DKind -> DKind
buildTyFun k1 k2 = DConT tyFunName `DAppT` k1 `DAppT` k2
buildTyFun_maybe :: Maybe DKind -> Maybe DKind -> Maybe DKind
buildTyFun_maybe m_k1 m_k2 = do
k1 <- m_k1
k2 <- m_k2
return $ DConT tyFunName `DAppT` k1 `DAppT` k2
tyFunArity :: DKind -> Int
tyFunArity (DArrowT `DAppT` (DConT tyFunNm `DAppT` _ `DAppT` b) `DAppT` DStarT)
| tyFunName == tyFunNm
= 1 + tyFunArity b
tyFunArity _ = 0
isTyFun :: DKind -> Bool
isTyFun (DArrowT `DAppT` (DConT tyFunNm `DAppT` _ `DAppT` _) `DAppT` DStarT)
| tyFunName == tyFunNm
= True
isTyFun _ = False
ravelTyFun :: [DKind] -> DKind
ravelTyFun [] = error "Internal error: TyFun raveling nil"
ravelTyFun [k] = k
ravelTyFun kinds = go tailK (buildTyFun k2 k1)
where (k1 : k2 : tailK) = reverse kinds
go [] acc = addStar acc
go (k:ks) acc = go ks (buildTyFun k (addStar acc))