{-# LANGUAGE TypeSynonymInstances, FlexibleInstances, RankNTypes,
TemplateHaskell, GeneralizedNewtypeDeriving,
MultiParamTypeClasses, StandaloneDeriving,
UndecidableInstances, MagicHash, UnboxedTuples,
LambdaCase, NoMonomorphismRestriction #-}
module Data.Singletons.Util where
import Prelude hiding ( exp, foldl, concat, mapM, any, pred )
import Language.Haskell.TH.Syntax hiding ( lift )
import Language.Haskell.TH.Desugar
import Data.Char
import Control.Monad hiding ( mapM )
import Control.Monad.Writer hiding ( mapM )
import Control.Monad.Reader hiding ( mapM )
import qualified Data.Map as Map
import Data.List.NonEmpty (NonEmpty(..))
import Data.Map ( Map )
import qualified Data.Monoid as Monoid
import Data.Semigroup as Semigroup
import Data.Foldable
import Data.Functor.Identity
import Data.Traversable
import Data.Generics
import Data.Maybe
import Data.Void
import Control.Monad.Fail ( MonadFail )
-- | Names of the basic data types 'Maybe', lists, 'Either', 'NonEmpty' and
-- 'Void', followed by every name in 'boundedBasicTypes'.
basicTypes :: [Name]
basicTypes =
  ''Maybe : ''[] : ''Either : ''NonEmpty : ''Void : boundedBasicTypes
-- | The tuple type constructors of arity 2 through 7 and 'Identity',
-- followed by every name in 'enumBasicTypes'.
boundedBasicTypes :: [Name]
boundedBasicTypes = tupleTyCons ++ [''Identity] ++ enumBasicTypes
  where
    tupleTyCons =
      [ ''(,), ''(,,), ''(,,,), ''(,,,,), ''(,,,,,), ''(,,,,,,) ]
-- | Names of 'Bool', 'Ordering' and the unit type @()@.
enumBasicTypes :: [Name]
enumBasicTypes =
  [ ''Bool
  , ''Ordering
  , ''()
  ]
-- | Names of the newtype wrappers from "Data.Semigroup" (including the
-- ones it re-exports from "Data.Monoid"). @First@ and @Last@ here are the
-- "Data.Semigroup" versions; see 'monoidBasicTypes' for the others.
semigroupBasicTypes :: [Name]
semigroupBasicTypes = monoidNewtypes ++ semigroupNewtypes
  where
    monoidNewtypes =
      [''Dual, ''All, ''Any, ''Sum, ''Product]
    semigroupNewtypes =
      [''Min, ''Max, ''Semigroup.First, ''Semigroup.Last, ''WrappedMonoid]
-- | The "Data.Monoid" flavours of @First@ and @Last@ (distinct from the
-- "Data.Semigroup" ones listed in 'semigroupBasicTypes').
monoidBasicTypes :: [Name]
monoidBasicTypes = [''Monoid.First, ''Monoid.Last]
-- | Report a message via 'qReport' with the error flag set to 'False'.
qReportWarning :: Quasi q => String -> q ()
qReportWarning msg = qReport False msg
-- | Report a message via 'qReport' with the error flag set to 'True'.
qReportError :: Quasi q => String -> q ()
qReportError msg = qReport True msg
-- | Obtain a fresh unique by generating a new name and extracting the
-- unique from its 'NameU' flavour. Any other flavour is treated as an
-- internal error.
qNewUnique :: DsMonad q => q Int
qNewUnique = do
  fresh <- qNewName "x"
  case fresh of
    Name _ (NameU uniq) -> return uniq
    _ -> error "Internal error: `qNewName` didn't return a NameU"
-- | Fail (via 'fail') if any of the given names has base name @Rep@;
-- otherwise do nothing.
checkForRep :: Quasi q => [Name] -> q ()
checkForRep names
  | any isRep names = fail repMsg
  | otherwise       = return ()
  where
    isRep nm = nameBase nm == "Rep"
    repMsg = "A data type named <<Rep>> is a special case.\n" ++
             "Promoting it will not work as expected.\n" ++
             "Please choose another name for your data type."
-- | Run 'checkForRep' over every name occurring anywhere in the given
-- declarations (as collected by 'allNamesIn').
checkForRepInDecls :: Quasi q => [DDec] -> q ()
checkForRepInDecls = checkForRep . allNamesIn
-- | The field types of a constructor, in declaration order, for both normal
-- and record syntax.
tysOfConFields :: DConFields -> [DType]
tysOfConFields fields = case fields of
  DNormalC _ bangTys -> [ ty | (_, ty) <- bangTys ]
  DRecC recTys       -> [ ty | (_, _, ty) <- recTys ]
-- | A constructor's name together with the number of its fields.
extractNameArgs :: DCon -> (Name, Int)
extractNameArgs con = case extractNameTypes con of
  (conName, fieldTys) -> (conName, length fieldTys)
-- | A constructor's name together with its field types.
extractNameTypes :: DCon -> (Name, [DType])
extractNameTypes con = case con of
  DCon _tvbs _ctx conName conFields _resTy ->
    (conName, tysOfConFields conFields)
-- | The name of a constructor.
extractName :: DCon -> Name
extractName con = case con of
  DCon _tvbs _ctx conName _fields _resTy -> conName
-- | Does the string start with a colon (the marker of an infix data
-- constructor)? The empty string yields 'False'.
isInfixDataCon :: String -> Bool
isInfixDataCon str = take 1 str == ":"
-- | Does the name look like a data constructor, i.e. does its base name
-- start with an uppercase letter or a colon?
--
-- A name whose base is empty yields 'False' rather than crashing on
-- 'head' as it did before.
isDataConName :: Name -> Bool
isDataConName n = case nameBase n of
  first : _ -> isUpper first || first == ':'
  []        -> False
-- | Does the name's base start with an uppercase letter?
--
-- A name whose base is empty yields 'False' rather than crashing on
-- 'head' as it did before.
isUpcase :: Name -> Bool
isUpcase n = case nameBase n of
  first : _ -> isUpper first
  []        -> False
-- | Build a new name from 'toUpcaseStr' applied with no prefixes.
upcase :: Name -> Name
upcase n = mkName (toUpcaseStr noPrefix n)
-- | Produce an "upcased" string from a name's base.
--
-- If the base starts with a Haskell letter (see 'isHsLetter'), that letter
-- is capitalised and the first prefix is prepended. Otherwise the base is
-- kept verbatim and the second prefix is prepended. An empty base takes the
-- second route (yielding just that prefix) instead of crashing on
-- 'head'/'tail'.
toUpcaseStr :: (String, String)  -- ^ (prefix for alphabetic names, prefix for symbolic names)
            -> Name
            -> String
toUpcaseStr (alpha, symb) n =
  case nameBase n of
    first : rest | isHsLetter first -> alpha ++ (toUpper first : rest)
    str                             -> symb ++ str
-- | A pair of empty prefixes, for use with 'toUpcaseStr'.
noPrefix :: (String, String)
noPrefix = ([], [])
-- | Prefix a constructor name. For a symbolic name starting with @:@, the
-- second prefix is inserted right after the colon; any other base simply
-- gets the first prefix prepended.
prefixConName :: String -> String -> Name -> Name
prefixConName pre tyPre n = mkName $
  case nameBase n of
    ':' : rest -> ':' : (tyPre ++ rest)
    other      -> pre ++ other
-- | Prepend a prefix to a name's base: the first prefix when the base
-- starts with a Haskell letter (see 'isHsLetter'), the second otherwise.
--
-- An empty base is treated as non-alphabetic (giving just the second
-- prefix) instead of crashing on 'head'.
prefixName :: String -> String -> Name -> Name
prefixName pre tyPre n = mkName $
  case nameBase n of
    str@(first : _) | isHsLetter first -> pre ++ str
    str                                -> tyPre ++ str
-- | Append a suffix to a name's base: the first suffix when the base
-- starts with a Haskell letter (see 'isHsLetter'), the second otherwise.
--
-- An empty base is treated as non-alphabetic (giving just the second
-- suffix) instead of crashing on 'head'.
suffixName :: String -> String -> Name -> Name
suffixName ident symb n = mkName $
  case nameBase n of
    str@(first : _) | isHsLetter first -> str ++ ident
    str                                -> str ++ symb
-- | Build a pair of prefixes made unique by an integer.
--
-- The alphabetic prefix gets the decimal rendering of the integer appended.
-- The symbolic prefix gets the same digits, each replaced by a symbol
-- character (0..9 map to @! # $ % & * + . / >@), so that only symbol
-- characters are added to it. A character that is not a decimal digit
-- (e.g. the @-@ of a negative number) raises an error when forced.
uniquePrefixes :: String   -- ^ alphabetic prefix
               -> String   -- ^ symbolic prefix
               -> Int
               -> (String, String)
uniquePrefixes alpha symb n =
  (alpha ++ digits, symb ++ map toSymbol digits)
  where
    digits = show n
    symbolTable = zip ['0' .. '9'] "!#$%&*+./>"
    toSymbol d = maybe (error "non-digit in show #") id (lookup d symbolTable)
-- | The explicit kind of a type-variable binder, if it has one.
extractTvbKind :: DTyVarBndr -> Maybe DKind
extractTvbKind tvb = case tvb of
  DKindedTV _ kind -> Just kind
  DPlainTV _       -> Nothing
-- | The variable name bound by a type-variable binder.
extractTvbName :: DTyVarBndr -> Name
extractTvbName tvb = case tvb of
  DKindedTV nm _ -> nm
  DPlainTV nm    -> nm
-- | The type variable bound by a binder, as a 'DVarT' (any kind annotation
-- is dropped).
tvbToType :: DTyVarBndr -> DType
tvbToType tvb = DVarT (extractTvbName tvb)
-- | Make a binder for the given name: kinded when a kind is supplied,
-- plain otherwise.
inferMaybeKindTV :: Name -> Maybe DKind -> DTyVarBndr
inferMaybeKindTV n = maybe (DPlainTV n) (DKindedTV n)
-- | The result kind stated by a type-family result signature, if any:
-- either a direct kind signature or the kind of a kinded result binder.
resultSigToMaybeKind :: DFamilyResultSig -> Maybe DKind
resultSigToMaybeKind sig = case sig of
  DKindSig kind -> Just kind
  DTyVarSig tvb -> extractTvbKind tvb
  DNoSig        -> Nothing
-- | Build the curried function type @a1 -> a2 -> ... -> res@ from a list of
-- argument types and a result type.
ravel :: [DType] -> DType -> DType
ravel args res = foldr arrow res args
  where
    arrow argTy acc = DAppT (DAppT DArrowT argTy) acc
-- | Convert a predicate to the equivalent type, constructor by constructor
-- (foralls, applications, signatures, variables, constructors, wildcards).
predToType :: DPred -> DType
predToType = \case
  DForallPr binders ctx body -> DForallT binders ctx (predToType body)
  DAppPr fun arg             -> predToType fun `DAppT` arg
  DSigPr inner kind          -> predToType inner `DSigT` kind
  DVarPr v                   -> DVarT v
  DConPr c                   -> DConT c
  DWildCardPr                -> DWildCardT
-- | The length of the argument-type list that 'unravel' (defined elsewhere)
-- reports for a type — presumably the number of arrow arguments.
countArgs :: DType -> Int
countArgs ty = case unravel ty of
  (_, _, argTys, _) -> length argTys
-- | Replace every exact ('NameU') name appearing in type-variable binders,
-- 'DVarT' types and injectivity annotations (anywhere inside the value)
-- with a plain 'mkName' name whose string is the occurrence name followed
-- by the unique. All other names are left alone.
noExactTyVars :: Data a => a -> a
noExactTyVars = everywhere (mkT onTvb `extT` onType `extT` onInjAnn)
  where
    plainName :: Name -> Name
    plainName nm = case nm of
      Name (OccName occ) (NameU uniq) -> mkName (occ ++ show uniq)
      _                               -> nm

    onTvb :: DTyVarBndr -> DTyVarBndr
    onTvb (DPlainTV nm)      = DPlainTV (plainName nm)
    onTvb (DKindedTV nm kind) = DKindedTV (plainName nm) kind

    onType :: DType -> DType
    onType (DVarT nm) = DVarT (plainName nm)
    onType other      = other

    onInjAnn :: InjectivityAnn -> InjectivityAnn
    onInjAnn (InjectivityAnn lhs rhs) =
      InjectivityAnn (plainName lhs) (map plainName rhs)
-- | Substitute into a kind. Kinds are types here, so this is 'substType'.
substKind :: Map Name DKind -> DKind -> DKind
substKind subst kind = substType subst kind
-- | Apply a substitution (type-variable name to replacement type) to a type.
--
-- An empty substitution returns the type unchanged immediately.
--
-- Under a 'DForallT', the binders are processed left to right with
-- 'mapAccumL' and 'subst_tvb': each bound name is deleted from the
-- substitution (so bound occurrences are not replaced), and each binder's
-- kind is substituted with the map in effect at that binder. The context
-- and body then use the fully shrunk map.
--
-- NOTE(review): no renaming is done, so free variables of the replacement
-- types could be captured by a 'DForallT' binder of the same name — confirm
-- callers never depend on capture-avoidance.
substType :: Map Name DType -> DType -> DType
substType subst ty | Map.null subst = ty
substType subst (DForallT tvbs cxt inner_ty)
  = DForallT tvbs' cxt' inner_ty'
  where
    (subst', tvbs') = mapAccumL subst_tvb subst tvbs
    cxt' = map (substPred subst') cxt
    inner_ty' = substType subst' inner_ty
substType subst (DAppT ty1 ty2) = substType subst ty1 `DAppT` substType subst ty2
substType subst (DSigT ty ki) = substType subst ty `DSigT` substType subst ki
-- A variable is replaced when it is in the map; otherwise kept as is.
substType subst (DVarT n) =
  case Map.lookup n subst of
    Just ki -> ki
    Nothing -> DVarT n
-- Leaves contain no type variables, so they are returned unchanged.
substType _ ty@(DConT {}) = ty
substType _ ty@(DArrowT) = ty
substType _ ty@(DLitT {}) = ty
substType _ ty@DWildCardT = ty
-- | Apply a substitution (type-variable name to replacement type) to a
-- predicate.
--
-- An empty substitution returns the predicate unchanged immediately. A
-- 'DForallPr' shrinks the map over its binders exactly as 'substType' does
-- for 'DForallT' (left to right via 'mapAccumL' and 'subst_tvb'). Types
-- appearing as arguments go through 'substType', kinds through 'substKind'.
--
-- Note that a 'DVarPr' is never substituted, even if its name is in the
-- map; only variables inside argument types and kinds are replaced.
-- NOTE(review): confirm that predicate-head variables are intentionally
-- left alone.
substPred :: Map Name DType -> DPred -> DPred
substPred subst pred | Map.null subst = pred
substPred subst (DForallPr tvbs cxt inner_pred)
  = DForallPr tvbs' cxt' inner_pred'
  where
    (subst', tvbs') = mapAccumL subst_tvb subst tvbs
    cxt' = map (substPred subst') cxt
    inner_pred' = substPred subst' inner_pred
substPred subst (DAppPr pred ty) =
  DAppPr (substPred subst pred) (substType subst ty)
substPred subst (DSigPr pred ki) =
  DSigPr (substPred subst pred) (substKind subst ki)
substPred _ pred@(DVarPr {}) = pred
substPred _ pred@(DConPr {}) = pred
substPred _ pred@DWildCardPr = pred
-- | Process one binder of a quantifier during substitution.
--
-- Returns the substitution with the bound name removed (it is shadowed
-- from here on) and the binder with its kind, if any, substituted using the
-- map as it was before the removal.
subst_tvb :: Map Name DKind -> DTyVarBndr -> (Map Name DKind, DTyVarBndr)
subst_tvb s tvb = case tvb of
  DPlainTV bound        -> (Map.delete bound s, tvb)
  DKindedTV bound kind  -> (Map.delete bound s, DKindedTV bound (substKind s kind))
-- | Give a plain binder the explicit kind @Type@ ('typeKindName'); binders
-- that already carry a kind are returned unchanged. (Presumably this is so
-- that declarations get a complete user-specified kind.)
cuskify :: DTyVarBndr -> DTyVarBndr
cuskify tvb = case tvb of
  DPlainTV tvname -> DKindedTV tvname (DConT typeKindName)
  DKindedTV {}    -> tvb
-- | Apply a type to a list of argument types, left-nested:
-- @foldType f [a, b] = (f a) b@.
foldType :: DType -> [DType] -> DType
foldType fun args = case args of
  []           -> fun
  arg : others -> foldType (DAppT fun arg) others
-- | Apply a type to the variables bound by a list of binders.
foldTypeTvbs :: DType -> [DTyVarBndr] -> DType
foldTypeTvbs ty tvbs = foldType ty (map tvbToType tvbs)
-- | Apply a predicate to a list of argument types, left-nested with
-- 'DAppPr'.
foldPred :: DPred -> [DType] -> DPred
foldPred hd args = case args of
  []           -> hd
  arg : others -> foldPred (DAppPr hd arg) others
-- | Apply a predicate to the variables bound by a list of binders.
foldPredTvbs :: DPred -> [DTyVarBndr] -> DPred
foldPredTvbs pr tvbs = foldPred pr (map tvbToType tvbs)
-- | Split a type into its head and its arguments (in order), looking
-- through kind signatures and foralls wherever they occur along the spine.
unfoldType :: DType -> NonEmpty DType
unfoldType = peel []
  where
    peel :: [DType] -> DType -> NonEmpty DType
    peel args ty = case ty of
      DAppT fun arg      -> peel (arg : args) fun
      DSigT inner _      -> peel args inner
      DForallT _ _ inner -> peel args inner
      _                  -> ty :| args
-- | Append to a data type's binders the extra binders that
-- 'mkExtraDKindBinders' produces for its kind (@Type@ when no kind is
-- given).
buildDataDTvbs :: DsMonad q => [DTyVarBndr] -> Maybe DKind -> q [DTyVarBndr]
buildDataDTvbs tvbs mk =
  (tvbs ++) <$> mkExtraDKindBinders (fromMaybe (DConT typeKindName) mk)
-- | Apply an expression to a list of arguments, left-nested with 'DAppE'.
foldExp :: DExp -> [DExp] -> DExp
foldExp fun args = case args of
  []           -> fun
  arg : others -> foldExp (DAppE fun arg) others
-- | Is the type syntactically an arrow type @a -> b@ or a forall type?
isFunTy :: DType -> Bool
isFunTy ty = case ty of
  DForallT {}               -> True
  DAppT (DAppT DArrowT _) _ -> True
  _                         -> False
-- | The first list, unless it is empty, in which case the second.
orIfEmpty :: [a] -> [a] -> [a]
orIfEmpty primary fallback = if null primary then fallback else primary
-- | Match several scrutinees against several patterns at once by tupling
-- both sides into a single-alternative 'DCaseE'. With no scrutinees and no
-- patterns, the body is returned directly.
multiCase :: [DExp] -> [DPat] -> DExp -> DExp
multiCase scruts pats body
  | null scruts, null pats = body
  | otherwise =
      DCaseE (mkTupleDExp scruts) [DMatch (mkTupleDPat pats) body]
-- | Lift a transformation on desugared syntax to one on Template Haskell
-- syntax: desugar the input, run the transformation (which also sees the
-- original), and sweeten the result back.
wrapDesugar :: (Desugar th ds, DsMonad q) => (th -> ds -> q ds) -> th -> q th
wrapDesugar f th = sweeten <$> (desugar th >>= f th)
-- | A computation in an underlying monad @q@ (used with 'Quasi' monads in
-- this module) that additionally accumulates auxiliary output of type @m@
-- through a 'WriterT'. The listed classes are derived from the wrapped
-- @'WriterT' m q@ via newtype deriving.
newtype QWithAux m q a = QWA { runQWA :: WriterT m q a }
  deriving ( Functor, Applicative, Monad, MonadTrans
           , MonadWriter m, MonadReader r
           , MonadFail, MonadIO )
-- | Every 'Quasi' operation except 'qRecover' is simply run in the
-- underlying monad via 'lift', producing no auxiliary output.
instance (Quasi q, Monoid m) => Quasi (QWithAux m q) where
  qNewName            = lift `comp1` qNewName
  qReport             = lift `comp2` qReport
  qLookupName         = lift `comp2` qLookupName
  qReify              = lift `comp1` qReify
  qReifyInstances     = lift `comp2` qReifyInstances
  qLocation           = lift qLocation
  qRunIO              = lift `comp1` qRunIO
  qAddDependentFile   = lift `comp1` qAddDependentFile
  qReifyRoles         = lift `comp1` qReifyRoles
  qReifyAnnotations   = lift `comp1` qReifyAnnotations
  qReifyModule        = lift `comp1` qReifyModule
  qAddTopDecls        = lift `comp1` qAddTopDecls
  qAddModFinalizer    = lift `comp1` qAddModFinalizer
  qGetQ               = lift qGetQ
  qPutQ               = lift `comp1` qPutQ
  qReifyFixity        = lift `comp1` qReifyFixity
  qReifyConStrictness = lift `comp1` qReifyConStrictness
  qIsExtEnabled       = lift `comp1` qIsExtEnabled
  qExtsEnabled        = lift qExtsEnabled
  qAddForeignFilePath = lift `comp2` qAddForeignFilePath
  qAddTempFile        = lift `comp1` qAddTempFile
  qAddCorePlugin      = lift `comp1` qAddCorePlugin

  -- 'qRecover' takes the error handler FIRST and the action that may fail
  -- SECOND (the parameters were previously named the other way round).
  -- Both arguments are 'QWithAux' computations, so each is run in @q@ with
  -- 'evalForPair' to return its auxiliary output alongside its result. The
  -- underlying 'qRecover' yields the pair of whichever one completed, and
  -- that output is re-emitted with 'tell'. If the action fails, whatever it
  -- had written is discarded; only the handler's output survives.
  qRecover handler action = do
    (result, aux) <- lift $ qRecover (evalForPair handler) (evalForPair action)
    tell aux
    return result
-- | Local declarations are those of the underlying monad; fetching them
-- writes no auxiliary output.
instance (DsMonad q, Monoid m) => DsMonad (QWithAux m q) where
  localDeclarations = QWA (lift localDeclarations)
-- | Post-compose a function onto a one-argument function (plain '.').
comp1 :: (b -> c) -> (a -> b) -> a -> c
comp1 f g = \x -> f (g x)
-- | Post-compose a function onto a two-argument function.
comp2 :: (c -> d) -> (a -> b -> c) -> a -> b -> d
comp2 f g = \a -> f . g a
-- | Run a 'QWithAux' computation in the underlying monad, discarding the
-- auxiliary output.
evalWithoutAux :: Quasi q => QWithAux m q a -> q a
evalWithoutAux comp = fst <$> runWriterT (runQWA comp)
-- | Run a 'QWithAux' computation in the underlying monad, keeping only the
-- auxiliary output.
evalForAux :: Quasi q => QWithAux m q a -> q m
evalForAux comp = snd <$> runWriterT (runQWA comp)
-- | Run a 'QWithAux' computation in the underlying monad, returning both
-- the result and the auxiliary output.
evalForPair :: QWithAux m q a -> q (a, m)
evalForPair (QWA writer) = runWriterT writer
-- | Record a single key/value binding in a map-valued auxiliary output.
-- Maps are combined with 'Map.union', which is left-biased, so for a
-- repeated key the binding recorded first is kept.
addBinding :: (Quasi q, Ord k) => k -> v -> QWithAux (Map.Map k v) q ()
addBinding k = tell . Map.singleton k
-- | Append one element to a list-valued auxiliary output.
addElement :: Quasi q => elt -> QWithAux [elt] q ()
addElement elt = tell (pure elt)
-- | Look up a type by the base (unqualified) string of the given name,
-- taking local declarations into account, and reify whatever it resolves
-- to. Returns 'Nothing' when the lookup finds nothing.
dsReifyTypeNameInfo :: DsMonad q => Name -> q (Maybe DInfo)
dsReifyTypeNameInfo ty_name =
  lookupTypeNameWithLocals (nameBase ty_name) >>= maybe (pure Nothing) dsReify
-- | Run a monoid-producing action on every element (left to right) and
-- combine the results with the monoid.
concatMapM :: (Monad monad, Monoid monoid, Traversable t)
           => (a -> monad monoid) -> t a -> monad monoid
concatMapM fn list = fold <$> mapM fn list
-- | A singleton list.
listify :: a -> [a]
listify x = [x]
-- | The first component of a triple.
fstOf3 :: (a,b,c) -> a
fstOf3 triple = case triple of
  (first, _, _) -> first
-- | Map a function over the first component of a pair.
liftFst :: (a -> b) -> (a, c) -> (b, c)
liftFst f pair = case pair of
  (x, other) -> (f x, other)
-- | Map a function over the second component of a pair.
liftSnd :: (a -> b) -> (c, a) -> (c, b)
liftSnd f pair = case pair of
  (other, x) -> (other, f x)
-- | Split a non-empty list into everything but the last element, and the
-- last element. Calls 'error' on the empty list.
snocView :: [a] -> ([a], a)
snocView xs = case xs of
  []       -> error "snocView nil"
  [lastEl] -> ([], lastEl)
  y : ys   -> case snocView ys of
                (front, lastEl) -> (y : front, lastEl)
-- | Classify every element with the given function and split the results
-- into the 'Left's and the 'Right's, each in original order.
partitionWith :: (a -> Either b c) -> [a] -> ([b], [c])
partitionWith f = loop ([], [])
  where
    loop (ls, rs) []       = (reverse ls, reverse rs)
    loop (ls, rs) (x : xs) =
      loop (either (\l -> (l : ls, rs)) (\r -> (ls, r : rs)) (f x)) xs
-- | Monadic 'partitionWith': run the classifier on each element in order
-- and split the results into the 'Left's and the 'Right's, each in
-- original order.
partitionWithM :: Monad m => (a -> m (Either b c)) -> [a] -> m ([b], [c])
partitionWithM f = loop ([], [])
  where
    loop (ls, rs) []       = return (reverse ls, reverse rs)
    loop (ls, rs) (x : xs) = f x >>= \res ->
      loop (either (\l -> (l : ls, rs)) (\r -> (ls, r : rs)) res) xs
-- | Separate 'DLetDec' declarations (unwrapped) from all other
-- declarations, preserving order within each group.
partitionLetDecs :: [DDec] -> ([DLetDec], [DDec])
partitionLetDecs = partitionWith classify
  where
    classify dec = case dec of
      DLetDec ld -> Left ld
      _          -> Right dec
{-# INLINEABLE zipWith3M #-}
-- | Monadic zip of TWO lists (despite the name), running the actions left
-- to right and stopping at the end of the shorter list.
zipWith3M :: Monad m => (a -> b -> m c) -> [a] -> [b] -> m [c]
zipWith3M f xs ys = sequence (zipWith f xs ys)
-- | Run an action producing a triple on every element (left to right) and
-- unzip the results into three lists.
mapAndUnzip3M :: Monad m => (a -> m (b,c,d)) -> [a] -> m ([b],[c],[d])
mapAndUnzip3M f = foldr step (return ([], [], []))
  where
    step x rest = do
      (b, c, d) <- f x
      (bs, cs, ds) <- rest
      return (b : bs, c : cs, d : ds)
-- | Is the character an underscore or a letter (per 'isLetter')?
isHsLetter :: Char -> Bool
isHsLetter c = c == '_' || isLetter c