{-# LANGUAGE TemplateHaskell  #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module      : Data.Array.Accelerate.Pattern.TH
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Pattern.TH (

  mkPattern,
  mkPatterns,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Type

import Control.Monad
import Data.Bits
import Data.Char
import Data.List                                          ( (\\), foldl' )
import Language.Haskell.TH                                hiding ( Exp, Match, match, tupP, tupE )
import Language.Haskell.TH.Extra
import Numeric
import Text.Printf
import qualified Language.Haskell.TH                      as TH

import GHC.Stack


-- | As 'mkPattern', but for a list of types
--
mkPatterns :: [Name] -> DecsQ
mkPatterns nms = concat <$> mapM mkPattern nms

-- | Generate pattern synonyms for the given simple (Haskell'98) sum or
-- product data type.
--
-- Constructor and record selectors are renamed to add a trailing
-- underscore if it does not exist, or to remove it if it does. For infix
-- constructors, the name is prepended with a colon ':'. For example:
--
-- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float }
-- >   deriving (Generic, Elt)
--
-- Will create the pattern synonym:
--
-- > Point_ :: Exp Float -> Exp Float -> Exp Point
--
-- together with the selector functions
--
-- > xcoord :: Exp Point -> Exp Float
-- > ycoord :: Exp Point -> Exp Float
--
mkPattern :: Name -> DecsQ
mkPattern nm = do
  info <- reify nm
  case info of
    TyConI dec -> mkDec dec
    _          -> fail "mkPatterns: expected the name of a newtype or datatype"

-- Dispatch on the kind of declaration that was reified. Only @data@ and
-- @newtype@ declarations are accepted.
--
mkDec :: Dec -> DecsQ
mkDec dec =
  case dec of
    DataD    _ nm tv _ cs _ -> mkDataD nm tv cs
    NewtypeD _ nm tv _ c  _ -> mkNewtypeD nm tv c
    _                       -> fail "mkPatterns: expected the name of a newtype or datatype"

-- A newtype is treated as a data type with exactly one constructor.
--
mkNewtypeD :: Name -> [TyVarBndr] -> Con -> DecsQ
mkNewtypeD tn tvs c = mkDataD tn tvs [c]

-- Generate the pattern synonyms for every constructor of the data type,
-- together with a COMPLETE pragma covering all of the generated patterns.
--
mkDataD :: Name -> [TyVarBndr] -> [Con] -> DecsQ
mkDataD tn tvs cs = do
  (pats, decs) <- unzip <$> go cs
  comp         <- pragCompleteD pats Nothing
  return $ comp : concat decs
  where
    -- For single-constructor types we create the pattern synonym for the
    -- type directly in terms of Pattern
    go []  = fail "mkPatterns: empty data declarations not supported"
    go [c] = return <$> mkConP tn tvs c
    go _   = do
      -- Collect the field types in Q so that an unsupported constructor
      -- is reported here, before any declarations are generated.
      tys <- mapM fieldTys cs
      go' [] tys ctags cs

    -- For sum-types, when creating the pattern for an individual
    -- constructor we need to know about the types of the fields all other
    -- constructors as well
    go' prev (this:next) (tag:tags) (con:cons) = do
      r  <- mkConS tn tvs prev next tag con
      rs <- go' (this:prev) next tags cons
      return (r : rs)
    go' _ [] [] [] = return []
    go' _ _  _  _  = fail "mkPatterns: unexpected error"

    -- This must run in Q: in the list monad 'fail' just returns [], which
    -- would silently treat an unsupported constructor as having no fields.
    fieldTys :: Con -> Q [Type]
    fieldTys (NormalC _ fs) = return (map snd fs)
    fieldTys (RecC _ fs)    = return (map thd3 fs)
    fieldTys (InfixC a _ b) = return [snd a, snd b]
    fieldTys _              = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"

    -- TODO: The GTags class demonstrates a way to generate the tags for
    -- a given constructor, rather than backwards-engineering the structure
    -- as we've done here. We should use that instead!
    --
    ctags =
      let n = length cs
          m = n `quot` 2
          l = take m     (iterate (True:) [False])
          r = take (n-m) (iterate (True:) [True])
          --
          -- Interpret a list of bits, most significant bit first
          bitsToTag = foldl' f 0
            where
              f i False =         i `shiftL` 1
              f i True  = setBit (i `shiftL` 1) 0
      in
      map bitsToTag (l ++ r)


-- Type signature shared by the generated pattern synonyms and builder
-- functions for a constructor of type @tn tvs@ with field types @fs@:
--
-- > forall tvs. (extra, Elt tv ...) => Exp f1 -> ... -> Exp fn -> Exp (tn tvs)
--
-- The @extra@ constraints are placed in front of the 'Elt' constraints.
--
mkConstructorSig :: [TypeQ] -> Name -> [Name] -> [Type] -> TypeQ
mkConstructorSig extra tn tvs fs =
  forallT
    (map plainTV tvs)
    (cxt (extra ++ map (\t -> [t| Elt $(varT t) |]) tvs))
    (foldr (\t ts -> [t| $t -> $ts |])
           [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
           (map (\t -> [t| Exp $(return t) |]) fs))


-- Generate the pattern synonym for the only constructor of a product type
-- (or newtype). The synonym is implicitly bidirectional and is defined
-- directly in terms of 'Pattern'. Returns the name of the generated
-- pattern together with its declarations.
--
mkConP :: Name -> [TyVarBndr] -> Con -> Q (Name, [Dec])
mkConP tn' tvs' con' = do
  checkExts [ PatternSynonyms ]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs)
    RecC cn fs    -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs)
    InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b]
    _             -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    -- Prefix constructor: fresh variable names for each field
    mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkNormalC tn cn tvs fs = do
      xs <- replicateM (length fs) (newName "_x")
      r  <- sequence
              [ patSynSigD pat (mkConstructorSig [] tn tvs fs)
              , patSynD pat
                  (prefixPatSyn xs)
                  implBidir
                  [p| Pattern $(tupP (map varP xs)) |]
              ]
      return (pat, r)
      where
        pat = rename cn

    -- Record constructor: the (already renamed) selector names are used
    -- as the pattern variables
    mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec])
    mkRecC tn cn tvs xs fs = do
      r <- sequence
             [ patSynSigD pat (mkConstructorSig [] tn tvs fs)
             , patSynD pat
                 (recordPatSyn xs)
                 implBidir
                 [p| Pattern $(tupP (map varP xs)) |]
             ]
      return (pat, r)
      where
        pat = rename cn

    -- Infix constructor: the pattern name is the constructor name with
    -- a ':' prepended, and inherits the constructor's fixity (if any)
    mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkInfixC tn cn tvs fs = do
      mf <- reifyFixity cn
      _a <- newName "_a"
      _b <- newName "_b"
      r  <- sequence
              [ patSynSigD pat (mkConstructorSig [] tn tvs fs)
              , patSynD pat
                  (infixPatSyn _a _b)
                  implBidir
                  [p| Pattern $(tupP [varP _a, varP _b]) |]
              ]
      return (pat, maybe r (\f -> InfixD f pat : r) mf)
      where
        pat = mkName (':' : nameBase cn)


-- Generate the pattern synonym for one constructor of a sum type.
--
-- Each pattern is explicitly bidirectional: a @_build@ function constructs
-- the tagged expression, and a @_match@ function, applied as a view
-- pattern, deconstructs it. The arguments @prev'@ and @next'@ are the
-- field types of the constructors before and after this one (with @prev'@
-- in reverse order), and @tag'@ is the tag of this constructor. Returns
-- the name of the generated pattern together with all declarations.
--
mkConS :: Name -> [TyVarBndr] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec])
mkConS tn' tvs' prev' next' tag' con' = do
  checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next'
    RecC cn fs    -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next'
    InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next'
    _             -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkNormalC tn cn tag tvs ps fs ns = do
      let pat = rename cn
      (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
      dec_pat                <- mkNormalC_pattern tn pat tvs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkRecC tn cn tag tvs xs ps fs ns = do
      let pat = rename cn
      (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
      dec_pat                <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- The operator name is z-encoded so that it can be used as part of
    -- the names of the generated build and match functions
    mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkInfixC tn cn tag tvs ps fs ns = do
      let pat = mkName (':' : nameBase cn)
      (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns
      dec_pat                <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Pattern synonym for a prefix constructor, in terms of the given
    -- build and match functions
    mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkNormalC_pattern tn pat tvs fs build match = do
      xs <- replicateM (length fs) (newName "_x")
      sequence
        [ patSynSigD pat (mkConstructorSig [ [t| HasCallStack |] ] tn tvs fs)
        , patSynD pat
            (prefixPatSyn xs)
            (explBidir [clause [] (normalB (varE build)) []])
            (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
        ]

    -- As 'mkNormalC_pattern', but for a record constructor with (renamed)
    -- selectors @xs@
    mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkRecC_pattern tn pat tvs xs fs build match =
      sequence
        [ patSynSigD pat (mkConstructorSig [ [t| HasCallStack |] ] tn tvs fs)
        , patSynD pat
            (recordPatSyn xs)
            (explBidir [clause [] (normalB (varE build)) []])
            (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
        ]

    -- As 'mkNormalC_pattern', but for an infix constructor. The fixity of
    -- the original constructor (if any) is copied to the pattern.
    mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkInfixC_pattern tn cn pat tvs fs build match = do
      mf <- reifyFixity cn
      _a <- newName "_a"
      _b <- newName "_b"
      r  <- sequence
              [ patSynSigD pat (mkConstructorSig [ [t| HasCallStack |] ] tn tvs fs)
              , patSynD pat
                  (infixPatSyn _a _b)
                  (explBidir [clause [] (normalB (varE build)) []])
                  (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |])
              ]
      return (maybe r (\f -> InfixD f pat : r) mf)

    -- Generate the function which constructs a value of this constructor.
    --
    -- The result is the constructor tag (as a Word8 constant) paired with
    -- a nested pair containing a slot for every field of every
    -- constructor of the type. Slots belonging to the other constructors
    -- (@fs0@ before, @fs1@ after) are filled with 'undef'.
    --
    mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkBuild tn cn tvs tag fs0 fs fs1 = do
      fun <- newName ("_build" ++ cn)
      xs  <- replicateM (length fs) (newName "_x")
      let
        vs     = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |]
               $  map (\t -> [| unExp (undef @ $(return t)) |] ) (concat (reverse fs0))
               ++ map varE xs
               ++ map (\t -> [| unExp (undef @ $(return t)) |] ) (concat fs1)

        tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |]
        body   = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) []

      r <- sequence
             [ sigD fun (mkConstructorSig [] tn tvs fs)
             , funD fun [body]
             ]
      return (fun, r)

    -- Generate the function which deconstructs a value of this
    -- constructor, for use as a view pattern.
    --
    -- The generated function inspects its argument: a 'Match' term whose
    -- tag pattern is for this constructor yields 'Just' the fields; any
    -- other 'Match' term yields 'Nothing'; anything else is an error
    -- explaining that the pattern was used outside of a 'match' context.
    --
    mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkMatch tn pn cn tvs tag fs0 fs fs1 = do
      fun     <- newName ("_match" ++ cn)
      e       <- newName "_e"
      x       <- newName "_x"
      (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] []
      unbind  <- isExtEnabled RebindableSyntax
      let
        -- With RebindableSyntax the literal pattern below desugars to
        -- whatever (==) is in scope, so bind it to the standard one
        eqE   = if unbind
                  then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]]
                  else id
        lhs   = [p| (Exp $(varP e)) |]
        body  = normalB $ eqE $ caseE (varE e)
          [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) []
          , TH.match (conP 'SmartExp [(recP 'Match [])])                  (normalB [| Nothing |])         []
          , TH.match wildP                                                (normalB [| error $error_msg |]) []
          ]

      r <- sequence
             [ sigD fun sig
             , funD fun [clause [lhs] body []]
             ]
      return (fun, r)
      where
        sig =
          forallT
            (map plainTV tvs)
            (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
            [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |]

        -- Pattern on the tag structure: this constructor's tag, followed
        -- by the sub-patterns for each field slot
        matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |]
          where
            pat = [p| $(foldl' (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |]

        -- Walk the field slots from right to left. Slots belonging to
        -- this constructor bind the tag sub-pattern and project out the
        -- corresponding value; all other slots are ignored.
        extract []     _ ps es = return (ps, es)
        extract (u:us) x ps es = do
          _u <- newName "_u"
          let x' = [| Prj PairIdxLeft (SmartExp $x) |]
          if not u
             then extract us x' (wildP:ps)  es
             else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es)

        -- Which slots (right to left) belong to this constructor
        vs = reverse
           $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ]

        error_msg =
          let -- placeholder variable names " a b c ..." for the example
              pv = unwords
                 $ take (length fs + 1)
                 $ concatMap (map reverse)
                 $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""]
           in stringE $ unlines
                [ "Embedded pattern synonym used outside 'match' context."
                , ""
                , "To use case statements in the embedded language the case statement must"
                , "be applied as an n-ary function to the 'match' operator. For single"
                , "argument case statements this can be done inline using LambdaCase, for"
                , "example:"
                , ""
                , "> x & match \\case"
                , printf "> %s%s -> ..." pn pv
                , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ')
                ]

-- First component of a triple
fst3 :: (a,b,c) -> a
fst3 (a,_,_) = a

-- Third component of a triple
thd3 :: (a,b,c) -> c
thd3 (_,_,c) = c

-- Toggle the trailing underscore of a name: remove it if the name ends
-- with one, otherwise append one.
--
rename :: Name -> Name
rename nm =
  let
      -- Separate the last character from everything before it
      split acc []     = (reverse acc, '\0')  -- shouldn't happen
      split acc [l]    = (reverse acc, l)
      split acc (l:ls) = split (l:acc) ls
      --
      nm'              = nameBase nm
      (base, suffix)   = split [] nm'
   in
   case suffix of
     '_' -> mkName base
     _   -> mkName (nm' ++ "_")

-- Fail with a message listing every required language extension which is
-- not enabled at the splice site.
--
checkExts :: [Extension] -> Q ()
checkExts req = do
  enabled <- extsEnabled
  let missing = req \\ enabled
  unless (null missing) . fail . unlines
    $ printf "You must enable the following language extensions to generate pattern synonyms:"
    : map (printf " {-# LANGUAGE %s #-}" . show) missing

-- A simplified version of that stolen from GHC/Utils/Encoding.hs
--
type EncodedString = String

-- Encode a string so that it contains only alphanumeric characters. The
-- first character is additionally encoded if it is a digit.
--
zencode :: String -> EncodedString
zencode []       = []
zencode (h:rest) = encode_digit h ++ concatMap encode_ch rest

-- Characters which are passed through unchanged: all alphanumerics other
-- than the two escape characters 'z' and 'Z'.
--
unencoded_char :: Char -> Bool
unencoded_char 'z' = False
unencoded_char 'Z' = False
unencoded_char c   = isAlphaNum c

encode_digit :: Char -> EncodedString
encode_digit c | isDigit c = encode_as_unicode_char c
               | otherwise = encode_ch c

encode_ch :: Char -> EncodedString
encode_ch c | unencoded_char c = [c]     -- Common case first
encode_ch '('  = "ZL"
encode_ch ')'  = "ZR"
encode_ch '['  = "ZM"
encode_ch ']'  = "ZN"
encode_ch ':'  = "ZC"
encode_ch 'Z'  = "ZZ"
encode_ch 'z'  = "zz"
encode_ch '&'  = "za"
encode_ch '|'  = "zb"
encode_ch '^'  = "zc"
encode_ch '$'  = "zd"
encode_ch '='  = "ze"
encode_ch '>'  = "zg"
encode_ch '#'  = "zh"
encode_ch '.'  = "zi"
encode_ch '<'  = "zl"
encode_ch '-'  = "zm"
encode_ch '!'  = "zn"
encode_ch '+'  = "zp"
encode_ch '\'' = "zq"
encode_ch '\\' = "zr"
encode_ch '/'  = "zs"
encode_ch '*'  = "zt"
encode_ch '_'  = "zu"
encode_ch '%'  = "zv"
encode_ch c    = encode_as_unicode_char c

-- Encode a character as 'z', followed by its code point in hexadecimal,
-- followed by 'U'. A leading '0' is inserted if the hexadecimal string
-- would not otherwise begin with a digit.
--
encode_as_unicode_char :: Char -> EncodedString
encode_as_unicode_char c =
  'z' : case hex_str of
          (d:_) | isDigit d -> hex_str
          _                 -> '0' : hex_str
  where
    hex_str = showHex (ord c) "U"