module Data.Binding.Hobbits.QQ (nuP, clP, clNuP) where
import qualified Control.Monad.Fail as Fail
import Control.Monad.Writer (runWriterT, tell)
import qualified Data.Generics as SYB
import Data.Monoid (Any(..))
import Language.Haskell.TH.Ppr as TH
import Language.Haskell.TH.Quote
import Language.Haskell.TH.Syntax as TH

import qualified Data.Binding.Hobbits.Internal.Utilities as IU
import Data.Binding.Hobbits.Internal.Closed
import Data.Binding.Hobbits.Internal.Mb
import Data.Binding.Hobbits.NuMatching
import Data.Binding.Hobbits.PatternParser (parsePattern)
-- | Apply a head expression to a list of argument expressions, nesting
-- 'AppE' to the left: @appEMulti f [x, y]@ is @AppE (AppE f x) y@.
-- With no arguments the head expression is returned unchanged.
appEMulti :: Exp -> [Exp] -> Exp
appEMulti acc []           = acc
appEMulti acc (arg : rest) = appEMulti (AppE acc arg) rest
-- | Build the expression @(.) f g@, i.e. the function composition of two
-- expressions, as a Template Haskell syntax tree.
compose :: Exp -> Exp -> Exp
compose f g = AppE (AppE (VarE '(.)) f) g
-- | Build a pattern-only 'QuasiQuoter'.
--
-- The first argument is the quoter's name, used only in error messages; the
-- second is the action that turns the quoted text into a pattern.
--
-- Using the quoter in an expression, type or declaration context is a user
-- error, so it is reported through the quotation monad's 'fail' rather than
-- by placing a bottom ('error') in the corresponding quoter slot.
patQQ :: String -> (String -> Q Pat) -> QuasiQuoter
patQQ n pat =
  QuasiQuoter
    { quoteExp  = err "Exp"
    , quotePat  = pat
    , quoteType = err "Type"
    , quoteDec  = err "Decs"
    }
  where
    -- The quoted text is irrelevant: any use outside a pattern is rejected.
    err s _ = fail $ "QQ `" ++ n ++ "' is for patterns, not " ++ s ++ "."
-- | The ingredients 'wrapVars' uses to rewrite a parsed pattern.
data WrapKit = WrapKit
  { _varView :: Exp
    -- ^ View function placed (via a view pattern) around every variable
    -- pattern and every as-pattern.
  , _asXform :: Pat -> Pat
    -- ^ Rewrites the sub-pattern of an as-pattern.
  , _topXform :: Bool -> Pat -> Pat
    -- ^ Rewrites the whole pattern once traversal is done; the 'Bool' says
    -- whether the view function was inserted anywhere.
  }
-- | Layer two kits: the first argument is the outer kit, the second the
-- inner one.
--
-- * The view expressions are joined with 'compose' (outer after inner).
-- * The as-pattern transforms are composed so the inner one runs first.
-- * The top-level transforms receive the same flag, inner running first.
combineWrapKits :: WrapKit -> WrapKit -> WrapKit
combineWrapKits (WrapKit outerView outerAs outerTop) (WrapKit innerView innerAs innerTop) =
  WrapKit
    { _varView  = compose outerView innerView
    , _asXform  = \p -> outerAs (innerAs p)
    , _topXform = \used p -> outerTop used (innerTop used p)
    }
-- | Rewrite a parsed pattern according to a 'WrapKit'.
--
-- The pattern is traversed with 'IU.everywhereButM'; the stop-query is 'True'
-- exactly on 'Exp' nodes (presumably so that expressions embedded in the
-- pattern are not descended into -- confirm against 'IU.everywhereButM').
-- At each 'Pat' node:
--
-- * a variable pattern is wrapped in a view pattern using the kit's view
--   expression;
-- * an as-pattern has its sub-pattern rewritten by the kit's as-transform
--   and is then wrapped in the same view pattern;
-- * a view pattern whose function is a single name @n@ has that function
--   replaced by the application of 'unClosed' to @n@;
-- * any other view pattern is rejected with 'Fail.fail';
-- * every other pattern is left unchanged.
--
-- Finally the kit's top-level transform is applied, told whether the view
-- expression was inserted at least once.
--
-- The constraint is 'Fail.MonadFail' rather than 'Monad' because the
-- rejection case calls 'Fail.fail', which is no longer a method of 'Monad'
-- in current versions of base.
wrapVars :: Fail.MonadFail m => WrapKit -> Pat -> m Pat
wrapVars (WrapKit {_varView = varView, _asXform = asXform, _topXform = topXform}) pat = do
  (pat', Any usedVarView) <- runWriterT m
  return $ topXform usedVarView pat'
  where
    -- The writer layer only records whether 'hit' ever fired.
    m = IU.everywhereButM (SYB.mkQ False isExp) (SYB.mkM w) pat
      where
        isExp :: Exp -> Bool
        isExp _ = True

    -- Record that the view expression was used, then yield the new pattern.
    hit x = tell (Any True) >> return x

    w p@VarP{} = hit $ ViewP varView p
    w (AsP v p) = hit $ ViewP varView $ AsP v $ asXform p
    w (ViewP (VarE n) p) = return $ ViewP (VarE 'unClosed `AppE` VarE n) p
    w (ViewP e _) = Fail.fail $ "view function must be a single name: `" ++ show (TH.ppr e) ++ "'"
    w p = return p
-- | Parse the text of a quasi-quotation as a pattern, passing the file name
-- of the current splice location to 'parsePattern'.
--
-- A parse failure is a mistake in the quoted text, so it is reported through
-- 'fail' in the 'Q' monad instead of being thrown as an exception with
-- 'error'. The message text is unchanged.
parseHere :: String -> Q Pat
parseHere s = do
  fn <- loc_filename `fmap` location
  case parsePattern fn s of
    Left e ->
      fail $ "Parse error: `" ++ e ++
             "'\n\n\t when parsing pattern: `" ++ s ++ "'."
    Right p -> return p
-- | Return the second argument unchanged. The first argument is ignored at
-- run time; it only forces both arguments to share the type index @ctx@.
same_ctx :: Mb ctx a -> Mb ctx b -> Mb ctx b
same_ctx _ = id
-- | The 'WrapKit' behind the @nuP@ quoter.
--
-- @topVar@ names the whole matched value and @namesVar@ the first component
-- of the pair that 'ensureFreshPair' yields at the top of the pattern.
--
-- * View expression: @same_ctx topVar . MkMbPair nuMatchingProof namesVar@.
-- * As-pattern transform: @(ensureFreshPair -> (_, p))@.
-- * Top-level transform: when the view expression was used,
--   @topVar\@(ensureFreshPair -> (namesVar, p))@; otherwise the same result
--   as the as-pattern transform.
nuKit :: TH.Name -> TH.Name -> WrapKit
nuKit topVar namesVar =
  WrapKit {_varView = viewExp, _asXform = freshAs, _topXform = freshTop}
  where
    -- View pattern through 'ensureFreshPair', shared by both transforms.
    viaFreshPair = ViewP (VarE 'ensureFreshPair)

    viewExp =
      compose
        (AppE (VarE 'same_ctx) (VarE topVar))
        (appEMulti (ConE 'MkMbPair) [VarE 'nuMatchingProof, VarE namesVar])

    freshAs inner = viaFreshPair (TupP [WildP, inner])

    freshTop usedView inner
      | usedView  = AsP topVar (viaFreshPair (TupP [VarP namesVar, inner]))
      | otherwise = freshAs inner
-- | Pattern quasi-quoter named @nuP@: parses the quoted text as a pattern and
-- rewrites it with 'nuKit', using two freshly generated names.
nuP :: QuasiQuoter
nuP = patQQ "nuP" build
  where
    build src = do
      mbName    <- newName "topMb"
      namesName <- newName "topNames"
      parsed    <- parseHere src
      wrapVars (nuKit mbName namesName) parsed
-- | The 'WrapKit' behind the @clP@ quoter: the view expression is the
-- 'Closed' constructor, and both the as-pattern transform and the top-level
-- transform (whatever its flag) wrap the pattern in a 'Closed' constructor
-- pattern.
--
-- NOTE(review): this uses the two-argument form of 'ConP'; template-haskell
-- 2.18 and later add a @[Type]@ argument -- confirm the supported version
-- range.
clKit :: WrapKit
clKit =
  WrapKit
    { _varView  = ConE 'Closed
    , _asXform  = underClosed
    , _topXform = \_ -> underClosed
    }
  where
    underClosed inner = ConP 'Closed [inner]
-- | Pattern quasi-quoter named @clP@: parses the quoted text as a pattern and
-- rewrites it with 'clKit'.
clP :: QuasiQuoter
clP = patQQ "clP" $ \src -> parseHere src >>= wrapVars clKit
-- | Pattern quasi-quoter named @clNuP@: parses the quoted text as a pattern
-- and rewrites it with 'clKit' layered over 'nuKit' (via 'combineWrapKits'),
-- using two freshly generated names for the 'nuKit' part.
clNuP :: QuasiQuoter
clNuP = patQQ "clNuP" build
  where
    build src = do
      mbName    <- newName "topMb"
      namesName <- newName "topNames"
      parsed    <- parseHere src
      wrapVars (combineWrapKits clKit (nuKit mbName namesName)) parsed