{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Pattern.TH (
mkPattern,
mkPatterns,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Type
import Control.Monad
import Data.Bits
import Data.Char
import Data.List ( (\\), foldl' )
import Language.Haskell.TH hiding ( Exp, Match, match, tupP, tupE )
import Language.Haskell.TH.Extra
import Numeric
import Text.Printf
import qualified Language.Haskell.TH as TH
import GHC.Stack
-- | Generate pattern synonyms for each of the given type names (see
-- 'mkPattern') and concatenate all resulting declarations, in input order.
mkPatterns :: [Name] -> DecsQ
mkPatterns = fmap concat . traverse mkPattern
-- | Generate pattern synonyms for the constructors of the data type or
-- newtype with the given name. The name is reified; anything that is not a
-- type-constructor declaration makes the splice fail.
mkPattern :: Name -> DecsQ
mkPattern nm = reify nm >>= fromInfo
  where
    fromInfo (TyConI dec) = mkDec dec
    fromInfo _            = fail "mkPatterns: expected the name of a newtype or datatype"
-- | Dispatch on a reified type declaration. Only @data@ and @newtype@
-- declarations are supported; any other declaration fails the splice.
mkDec :: Dec -> DecsQ
mkDec (DataD _ nm tv _ cs _)   = mkDataD nm tv cs
mkDec (NewtypeD _ nm tv _ c _) = mkNewtypeD nm tv c
mkDec _                        = fail "mkPatterns: expected the name of a newtype or datatype"
-- | A newtype is handled exactly like a data type with a single constructor.
mkNewtypeD :: Name -> [TyVarBndr] -> Con -> DecsQ
mkNewtypeD tn tvs = mkDataD tn tvs . pure
-- | Generate pattern synonyms for every constructor of a data type, preceded
-- by a @COMPLETE@ pragma that lists all of them.
--
-- A type with exactly one constructor is handled by 'mkConP'. With several
-- constructors, each one is handled by 'mkConS'. 'mkConS' receives:
--
-- * the field types of the constructors declared before it, nearest first;
-- * the field types of the constructors declared after it;
-- * the tag computed by @ctags@ for its position.
--
-- Empty data declarations are rejected.
mkDataD :: Name -> [TyVarBndr] -> [Con] -> DecsQ
mkDataD tn tvs cs = do
  generated <- synonyms cs
  let (pats, decss) = unzip generated
  complete <- pragCompleteD pats Nothing
  return (complete : concat decss)
  where
    synonyms []  = fail "mkPatterns: empty data declarations not supported"
    synonyms [c] = (: []) <$> mkConP tn tvs c
    synonyms _   = walk [] (map fieldTys cs) ctags cs

    -- Visit the constructors left to right. We carry the field types
    -- already visited (nearest first) and those still to come, so each
    -- constructor knows the layout of all the others.
    walk before (here : after) (tag : tags) (con : cons) =
      (:) <$> mkConS tn tvs before after tag con
          <*> walk (here : before) after tags cons
    walk _ [] [] [] = return []
    walk _ _  _  _  = fail "mkPatterns: unexpected error"

    -- Field types of a constructor. For unsupported constructor forms this
    -- is 'fail' in the list monad, i.e. []. The error surfaces when
    -- 'mkConS' reaches that constructor.
    fieldTys con =
      case con of
        NormalC _ fs   -> map snd fs
        RecC _ fs      -> [ ty | (_, _, ty) <- fs ]
        InfixC lf _ rf -> [snd lf, snd rf]
        _              -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"

    -- Tag value for each constructor, in declaration order.
    --
    -- The first half of the constructors get the bit paths
    -- [F], [T,F], [T,T,F], ...
    -- and the remainder get
    -- [T], [T,T], [T,T,T], ...
    -- Each path is folded most-significant-bit first.
    --
    -- NOTE(review): presumably this mirrors the tags assigned by the
    -- type's 'Elt' instance; confirm. Tags are 'Word8' (see 'mkConS'), so
    -- large constructor counts may wrap.
    ctags = map toTag (lowerPaths ++ upperPaths)
      where
        total      = length cs
        half       = total `quot` 2
        lowerPaths = take half           (iterate (True :) [False])
        upperPaths = take (total - half) (iterate (True :) [True])
        toTag      = foldl' push 0
        push acc flag
          | flag      = setBit (acc `shiftL` 1) 0
          | otherwise = acc `shiftL` 1
-- | Generate the pattern synonym for the sole constructor of a
-- single-constructor type.
--
-- The synonym is implicitly bidirectional and is defined as
-- @Pattern (x1, ..., xn)@ over the constructor's fields. Its signature is
-- @forall tvs. (Elt tv1, ...) => Exp f1 -> ... -> Exp fn -> Exp (T tvs)@.
--
-- Returns the synonym's name together with its signature and definition.
-- For an infix constructor with a declared fixity, a matching fixity
-- declaration is also returned.
mkConP :: Name -> [TyVarBndr] -> Con -> Q (Name, [Dec])
mkConP tn' tvs' con' = do
  checkExts [ PatternSynonyms ]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn vars (map snd fs)
    RecC cn fs    -> mkRecC tn' cn vars (map (rename . fst3) fs) (map thd3 fs)
    InfixC l cn r -> mkInfixC tn' cn vars [snd l, snd r]
    _             -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    vars = map tyVarBndrName tvs'

    -- Signature shared by all three constructor forms:
    --   forall tvs. (Elt tv1, ...) => Exp f1 -> ... -> Exp fn -> Exp (T tvs)
    conSig tn tvs fs =
      forallT (map plainTV tvs)
              (cxt (map (\v -> [t| Elt $(varT v) |]) tvs))
              (foldr (\arg res -> [t| $arg -> $res |]) resultTy argTys)
      where
        resultTy = [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
        argTys   = map (\ty -> [t| Exp $(return ty) |]) fs

    mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkNormalC tn cn tvs fs = do
      xs <- replicateM (length fs) (newName "_x")
      let pat = rename cn
      decs <- sequence
        [ patSynSigD pat (conSig tn tvs fs)
        , patSynD pat (prefixPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |]
        ]
      return (pat, decs)

    mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec])
    mkRecC tn cn tvs xs fs = do
      let pat = rename cn
      decs <- sequence
        [ patSynSigD pat (conSig tn tvs fs)
        , patSynD pat (recordPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |]
        ]
      return (pat, decs)

    -- Infix constructors keep their operator name (prefixed with ':') and
    -- inherit the constructor's fixity when one is declared.
    mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkInfixC tn cn tvs fs = do
      fixity <- reifyFixity cn
      lhs    <- newName "_a"
      rhs    <- newName "_b"
      let pat = mkName (':' : nameBase cn)
      decs <- sequence
        [ patSynSigD pat (conSig tn tvs fs)
        , patSynD pat (infixPatSyn lhs rhs) implBidir [p| Pattern $(tupP [varP lhs, varP rhs]) |]
        ]
      return (pat, maybe decs (\f -> InfixD f pat : decs) fixity)
-- | Generate a pattern synonym for one constructor of a type that has
-- several constructors.
--
-- Parameters:
--
-- * @prev'@: field types of the constructors declared before this one,
--   nearest first, as accumulated by the caller.
-- * @next'@: field types of the constructors declared after it.
-- * @tag'@: the tag value that selects this constructor.
--
-- Besides the explicitly bidirectional pattern synonym, two helper functions
-- are generated per constructor:
--
-- * @_build<Con>@ constructs a value. The result pairs the 'Word8' tag
--   constant with a nested-pair tuple holding the fields of /every/
--   constructor; fields of the other constructors are filled with 'undef'.
--
-- * @_match<Con>@ is used as the synonym's view pattern. Given a 'Match'
--   term carrying this constructor's tag, it returns 'Just' the fields. For
--   a 'Match' with any other tag it returns 'Nothing'. For any other term it
--   calls 'error' with instructions to use 'match'.
--
-- Returns the synonym's name and all generated declarations.
mkConS :: Name -> [TyVarBndr] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec])
mkConS tn' tvs' prev' next' tag' con' = do
  checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next'
    RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next'
    InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next'
    _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    -- Prefix constructor: synonym name from 'rename'; helpers named after
    -- the constructor.
    mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkNormalC tn cn tag tvs ps fs ns = do
        let pat = rename cn
        (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
        (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
        dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match
        return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Record constructor: like 'mkNormalC', but the synonym is a record
    -- pattern synonym with the (renamed) field names @xs@.
    mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkRecC tn cn tag tvs xs ps fs ns = do
        let pat = rename cn
        (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
        (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
        dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match
        return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Infix constructor: the synonym is the operator name prefixed with ':'.
    -- The operator is Z-encoded to form valid helper identifiers and shown
    -- in parentheses in the error message.
    mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkInfixC tn cn tag tvs ps fs ns = do
        let pat = mkName (':' : nameBase cn)
        (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns
        (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns
        dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match
        return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Prefix synonym: @pattern P x1 .. xn <- (match -> Just (x1, .., xn)) where P = build@
    mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkNormalC_pattern tn pat tvs fs build match = do
        xs <- replicateM (length fs) (newName "_x")
        r <- sequence [ patSynSigD pat sig
                      , patSynD pat
                          (prefixPatSyn xs)
                          (explBidir [clause [] (normalB (varE build)) []])
                          (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
                      ]
        return r
      where
        sig = forallT
                (map plainTV tvs)
                (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
                (foldr (\t ts -> [t| $t -> $ts |])
                  [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
                  (map (\t -> [t| Exp $(return t) |]) fs))

    -- Record synonym: same shape as the prefix case, using the field names.
    mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkRecC_pattern tn pat tvs xs fs build match = do
        r <- sequence [ patSynSigD pat sig
                      , patSynD pat
                          (recordPatSyn xs)
                          (explBidir [clause [] (normalB (varE build)) []])
                          (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
                      ]
        return r
      where
        sig = forallT
                (map plainTV tvs)
                (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
                (foldr (\t ts -> [t| $t -> $ts |])
                  [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
                  (map (\t -> [t| Exp $(return t) |]) fs))

    -- Infix synonym over two fields. A fixity declaration is added when the
    -- original constructor has one.
    mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkInfixC_pattern tn cn pat tvs fs build match = do
        mf <- reifyFixity cn
        _a <- newName "_a"
        _b <- newName "_b"
        r <- sequence [ patSynSigD pat sig
                      , patSynD pat
                          (infixPatSyn _a _b)
                          (explBidir [clause [] (normalB (varE build)) []])
                          (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |])
                      ]
        r' <- case mf of
          Nothing -> return r
          Just f -> return (InfixD f pat : r)
        return r'
      where
        sig = forallT
                (map plainTV tvs)
                (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
                (foldr (\t ts -> [t| $t -> $ts |])
                  [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
                  (map (\t -> [t| Exp $(return t) |]) fs))

    -- Builder: takes this constructor's fields and produces
    --   Pair (Const tag) (((Nil `Pair` f1) `Pair` f2) ...)
    -- where the tuple contains every constructor's fields in declaration
    -- order. @fs0@ is nearest-first, hence the 'reverse'. Fields of other
    -- constructors are 'undef' at their type.
    mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkBuild tn cn tvs tag fs0 fs fs1 = do
        fun <- newName ("_build" ++ cn)
        xs <- replicateM (length fs) (newName "_x")
        let
          vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |]
             $ map (\t -> [| unExp (undef @ $(return t)) |] ) (concat (reverse fs0))
            ++ map varE xs
            ++ map (\t -> [| unExp (undef @ $(return t)) |] ) (concat fs1)
          tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |]
          body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) []
        r <- sequence [ sigD fun sig
                      , funD fun [body]
                      ]
        return (fun, r)
      where
        sig = forallT
                (map plainTV tvs)
                (cxt (map (\t -> [t| Elt $(varT t) |]) tvs))
                (foldr (\t ts -> [t| $t -> $ts |])
                  [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
                  (map (\t -> [t| Exp $(return t) |]) fs))

    -- Matcher: the view function used by the pattern synonym.
    -- @pn@ is the synonym name as shown in the error message; @cn@ is used
    -- to name the generated function.
    mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkMatch tn pn cn tvs tag fs0 fs fs1 = do
        fun <- newName ("_match" ++ cn)
        e <- newName "_e"
        x <- newName "_x"
        (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] []
        unbind <- isExtEnabled RebindableSyntax
        let
          -- Under RebindableSyntax the literal tag pattern would use whatever
          -- (==) is in scope, so bind it to the standard one locally.
          eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id
          lhs = [p| (Exp $(varP e)) |]
          body = normalB $ eqE $ caseE (varE e)
            [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) []
            , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) []
            , TH.match wildP (normalB [| error $error_msg |]) []
            ]
        r <- sequence [ sigD fun sig
                      , funD fun [clause [lhs] body []]
                      ]
        return (fun, r)
      where
        sig = forallT
                (map plainTV tvs)
                (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
                [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |]

        -- Tag-representation pattern: this constructor's tag paired with a
        -- left-nested TagRpair of per-field patterns (variables for this
        -- constructor's fields, wildcards elsewhere).
        matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |]
          where
            pat = [p| $(foldl' (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |]

        -- Walk the fields from last to first (@vs@ is reversed). At each step
        -- the current field is the right projection of the running term; the
        -- remaining fields are its left projection. Patterns and extracted
        -- expressions are prepended, so both lists end up in declaration
        -- order. Each extracted field is re-wrapped in 'Match' with the tag
        -- pattern variable bound for it.
        extract [] _ ps es = return (ps, es)
        extract (u:us) x ps es = do
          _u <- newName "_u"
          let x' = [| Prj PairIdxLeft (SmartExp $x) |]
          if not u
            then extract us x' (wildP:ps) es
            else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es)

        -- One flag per field of the whole type: True for this constructor's
        -- own fields. Reversed to match the right-to-left walk in 'extract'.
        vs = reverse
           $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ]

        -- Message for using the synonym outside of 'match'. The example lines
        -- use variable names a, b, ... and are aligned column-wise.
        error_msg =
          let pv = unwords
                 $ take (length fs + 1)
                 $ concatMap (map reverse)
                 $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""]
          in stringE $ unlines
            [ "Embedded pattern synonym used outside 'match' context."
            , ""
            , "To use case statements in the embedded language the case statement must"
            , "be applied as an n-ary function to the 'match' operator. For single"
            , "argument case statements this can be done inline using LambdaCase, for"
            , "example:"
            , ""
            , "> x & match \\case"
            , printf "> %s%s -> ..." pn pv
            , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ')
            ]
-- | First component of a triple.
fst3 :: (a,b,c) -> a
fst3 triple = case triple of (first, _, _) -> first
-- | Third component of a triple.
thd3 :: (a,b,c) -> c
thd3 triple = case triple of (_, _, third) -> third
-- | Name of the pattern synonym generated for a constructor or field.
-- Only the base (unqualified) name is used. A trailing underscore is
-- dropped if present; otherwise one is appended. For example, @Foo@
-- becomes @Foo_@ and @Bar_@ becomes @Bar@.
rename :: Name -> Name
rename nm =
  case reverse base of
    '_' : rest -> mkName (reverse rest)
    _          -> mkName (base ++ "_")
  where
    base = nameBase nm
-- | Fail the splice if any of the required language extensions is not
-- enabled in the module being compiled. The error message lists a
-- @LANGUAGE@ pragma for each missing extension.
checkExts :: [Extension] -> Q ()
checkExts required = do
  missing <- (required \\) <$> extsEnabled
  unless (null missing) $
    fail (unlines (header : map pragma missing))
  where
    header   = "You must enable the following language extensions to generate pattern synonyms:"
    pragma e = printf " {-# LANGUAGE %s #-}" (show e)
-- | A Z-encoded string (cf. GHC's @zEncodeString@).
type EncodedString = [Char]

-- | Z-encode a string so that it can form part of an ordinary identifier.
-- This is used for the generated helper names of infix constructors.
--
-- The first character goes through 'encode_digit', so a leading digit is
-- escaped. Every other character goes through 'encode_ch'.
zencode :: String -> EncodedString
zencode str =
  case str of
    []           -> []
    first : rest -> encode_digit first ++ concatMap encode_ch rest
-- | Characters that Z-encoding passes through unchanged: alphanumerics,
-- except @z@ and @Z@, which serve as escape characters.
unencoded_char :: Char -> Bool
unencoded_char c = c `notElem` "zZ" && isAlphaNum c
-- | Encode the first character of a string. Digits are always escaped,
-- since an identifier cannot start with one; anything else is encoded as
-- usual.
encode_digit :: Char -> EncodedString
encode_digit c =
  if isDigit c
    then encode_as_unicode_char c
    else encode_ch c
-- | Z-encode a single character, with these cases:
--
-- * Unencoded characters are kept as-is.
-- * Common punctuation has a two-letter escape: upper-case 'Z' for
--   constructor-ish characters, lower-case 'z' for operator characters.
-- * Everything else is written as its hexadecimal code point.
encode_ch :: Char -> EncodedString
encode_ch c
  | unencoded_char c = [c]
  | otherwise        = maybe (encode_as_unicode_char c) id (lookup c escapes)
  where
    escapes =
      [ ('(' , "ZL"), (')' , "ZR"), ('[' , "ZM"), (']' , "ZN")
      , (':' , "ZC"), ('Z' , "ZZ"), ('z' , "zz"), ('&' , "za")
      , ('|' , "zb"), ('^' , "zc"), ('$' , "zd"), ('=' , "ze")
      , ('>' , "zg"), ('#' , "zh"), ('.' , "zi"), ('<' , "zl")
      , ('-' , "zm"), ('!' , "zn"), ('+' , "zp"), ('\'', "zq")
      , ('\\', "zr"), ('/' , "zs"), ('*' , "zt"), ('_' , "zu")
      , ('%' , "zv")
      ]
-- | Encode a character as @z@, then its hexadecimal code point, then @U@
-- (e.g. @'\\x2218'@ becomes @z2218U@).
--
-- If the hex digits would start with a letter, a @0@ is inserted so the
-- encoding always begins with a decimal digit after the @z@. Pattern
-- matching replaces the partial 'head'; 'showHex' never yields an empty
-- string here, so the result is unchanged.
encode_as_unicode_char :: Char -> EncodedString
encode_as_unicode_char c = 'z' : leadWithDigit hex_str
  where
    hex_str = showHex (ord c) "U"
    leadWithDigit s@(d : _) | isDigit d = s
    leadWithDigit s                     = '0' : s