{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Template Haskell generators for fixed-size vector types (@V2@, @V3@, ...)
-- whose fields all have one element type, together with lifting, dot-product,
-- list-conversion, join/split functions and 'Show' / 'Arbitrary' instances.
module Shapes.Linear.Template where

import Control.Monad
import Language.Haskell.TH
import Test.QuickCheck.Arbitrary

-- TODO: Use a wrapper type to hold multiple sizes of vector?

-- | Names of the element type and its operations, from which the generators
-- in this module build vector types and functions.
--
-- Only '_valueN', '_valueWrap', '_valueBoxed', '_valueAdd' and '_valueMul'
-- are read by the generators in this module; the remaining fields are
-- presumably consumed by callers elsewhere (NOTE(review): confirm).
data ValueInfo = ValueInfo
  { _valueN     :: Name -- ^ element type stored in every constructor field
                        --   (presumably an unlifted type such as @Double#@)
  , _valueWrap  :: Name -- ^ data constructor wrapping one element into a
                        --   '_valueBoxed' value (also used as a pattern)
  , _valueBoxed :: Name -- ^ boxed type used by the list conversions
  , _valueAdd   :: Name -- ^ addition, applied infix by 'dotE'
  , _valueSub   :: Name
  , _valueMul   :: Name -- ^ multiplication, applied infix by 'dotE'
  , _valueDiv   :: Name
  , _valueNeg   :: Name
  , _valueEq    :: Name
  , _valueNeq   :: Name
  , _valueLeq   :: Name
  , _valueGeq   :: Name
  , _valueGt    :: Name
  , _valueLt    :: Name
  }

-- | An INLINE pragma for the given name, function-like, active in all phases.
makeInlineD :: Name -> DecQ
makeInlineD n = pragInlD n Inline FunLike AllPhases

-- | Name of the vector type (and of its single constructor) for a dimension,
-- e.g. @makeVectorN 3@ is @V3@.
makeVectorN :: Int -> Name
makeVectorN dim = mkName $ "V" ++ show dim

-- | Declares the vector type of the given dimension,
-- @data Vn = Vn e e ... e@ with @n@ fields of type '_valueN', followed by
-- everything produced by 'defineLift', 'defineLift2', 'defineDot',
-- 'defineFromList', 'defineToList', 'deriveShow' and 'deriveArbitrary'.
makeVectorType :: ValueInfo -> Int -> DecsQ
makeVectorType vi@ValueInfo{..} dim = do
  -- Fields carry no strictness annotation (presumably because the element
  -- type is unlifted, where a bang would be meaningless -- TODO confirm).
#if MIN_VERSION_template_haskell(2,11,0)
  notStrict_ <- bang noSourceUnpackedness noSourceStrictness
#else
  notStrict_ <- notStrict
#endif
  let vectorN = makeVectorN dim
      constrArg = (notStrict_, ConT _valueN)
      constr = NormalC vectorN (replicate dim constrArg)
      definers =
        [ defineLift
        , defineLift2
        , defineDot
        , defineFromList
        , defineToList
        , deriveShow
        , deriveArbitrary
        ]
  impls <- concat <$> mapM (\define -> define vectorN vi dim) definers
#if MIN_VERSION_template_haskell(2,11,0)
  let vectorD = DataD [] vectorN [] Nothing [constr] []
#else
  let vectorD = DataD [] vectorN [] [constr] []
#endif
  return (vectorD : impls)

-- | A 'Show' instance shaped like a derived one: the constructor name, then
-- each element (boxed with '_valueWrap') shown at application precedence, so
-- e.g. negative elements are parenthesised.  The whole value is parenthesised
-- when it appears as an argument of another constructor (precedence > 10).
deriveShow :: Name -> ValueInfo -> Int -> DecsQ
deriveShow vectorN ValueInfo{..} dim = do
  (vecP, elemVars) <- conPE vectorN "a" dim
  (precP, precV) <- newPE "d"
  let constructorShown = nameBase vectorN
      -- One field: a separating space, then the boxed element at precedence 11.
      showElem v =
        [| showChar ' ' . showsPrec 11 $(appE (conE _valueWrap) v) |]
      showElems =
        foldr (\v rest -> [| $(showElem v) . $(rest) |]) [| id |] elemVars
      showBody = [| showString constructorShown . $(showElems) |]
      showsPrecClause
        -- As with derived instances, a nullary constructor is never wrapped.
        | dim == 0 = simpleClause [wildP, vecP] showBody
        | otherwise =
            simpleClause [precP, vecP] [| showParen ($(precV) > 10) $(showBody) |]
  return <$> instanceD (cxt []) (appT (conT ''Show) (conT vectorN))
    [funD 'showsPrec [showsPrecClause]]

-- | An integer literal expression for the given dimension.
dimE :: Int -> ExpQ
dimE = litE . integerL . fromIntegral

-- | An 'Arbitrary' instance that draws @dim@ arbitrary boxed elements and
-- builds the vector with the generated @fromListVn@ (see 'defineFromList').
deriveArbitrary :: Name -> ValueInfo -> Int -> DecsQ
deriveArbitrary vectorN _ dim = do
  let arbList = [| replicateM $(dimE dim) arbitrary |]
      arbBody = infixApp (fromListE vectorN) (varE '(<$>)) arbList
      arbClause = simpleClause [] arbBody
  return <$> instanceD (cxt []) (appT (conT ''Arbitrary) (conT vectorN))
    [funD 'arbitrary [arbClause]]

-- | Defines @liftVn :: (e -> e) -> Vn -> Vn@, applying the function to every
-- field (with a signature and an INLINE pragma).
defineLift :: Name -> ValueInfo -> Int -> DecsQ
defineLift vectorN ValueInfo{..} dim = do
  (funcP, funcV) <- newPE "f"
  (vecP, elemVars) <- conPE vectorN "a" dim
  let liftName = mkName $ "lift" ++ nameBase vectorN
      valueT = conT _valueN
      vectorT = conT vectorN
      liftType = arrowsT [arrowsT [valueT, valueT], vectorT, vectorT]
      liftBody = appsE (conE vectorN : map (appE funcV) elemVars)
      liftClause = simpleClause [funcP, vecP] liftBody
  inlSigDef liftName liftType [liftClause]

-- | Defines @lift2Vn :: (e -> e -> e) -> Vn -> Vn -> Vn@, combining the two
-- vectors field by field (with a signature and an INLINE pragma).
defineLift2 :: Name -> ValueInfo -> Int -> DecsQ
defineLift2 vectorN ValueInfo{..} dim = do
  (funcP, funcV) <- newPE "f"
  (vecP, elemVars) <- conPE vectorN "a" dim
  (vecP', elemVars') <- conPE vectorN "b" dim
  let liftName = mkName $ "lift2" ++ nameBase vectorN
      valueT = conT _valueN
      vectorT = conT vectorN
      liftType =
        arrowsT [arrowsT [valueT, valueT, valueT], vectorT, vectorT, vectorT]
      applyF x y = appsE [funcV, x, y]
      liftBody = appsE (conE vectorN : zipWith applyF elemVars elemVars')
      liftClause = simpleClause [funcP, vecP, vecP'] liftBody
  inlSigDef liftName liftType [liftClause]

-- | The expression @r1 * c1 + r2 * c2 + ...@ built from '_valueMul' and
-- '_valueAdd' (left-nested sums).  Extra elements of the longer list are
-- ignored.  Fails in 'Q' with a descriptive message when there are no
-- products to sum, since 'ValueInfo' provides no zero value.
dotE :: ValueInfo -> [ExpQ] -> [ExpQ] -> ExpQ
dotE ValueInfo{..} row col =
  case zipWith (infixApp' $ varE _valueMul) row col of
    [] -> fail "dotE: cannot build a dot product of zero-length vectors"
    p : ps -> foldl (infixApp' $ varE _valueAdd) p ps

-- | Defines @dotVn :: Vn -> Vn -> e@ via 'dotE' (with a signature and an
-- INLINE pragma).
defineDot :: Name -> ValueInfo -> Int -> DecsQ
defineDot vectorN vi@ValueInfo{..} dim = do
  (vecP, elemVars) <- conPE vectorN "a" dim
  (vecP', elemVars') <- conPE vectorN "b" dim
  let dotName = mkName $ "dot" ++ nameBase vectorN
      valueT = conT _valueN
      vectorT = conT vectorN
      dotType = arrowsT [vectorT, vectorT, valueT]
      dotClause = simpleClause [vecP, vecP'] (dotE vi elemVars elemVars')
  inlSigDef dotName dotType [dotClause]

-- | For @(l, r)@ defines
-- @joinLvR :: Vl -> Vr -> V(l+r)@ (concatenating fields) and
-- @splitLvR :: V(l+r) -> (Vl, Vr)@ (its inverse), both with signatures and
-- INLINE pragmas.  The three vector types are assumed to exist already.
defineJoinSplit :: ValueInfo -> (Int, Int) -> DecsQ
defineJoinSplit _ (left, right) = do
  let vecN = makeVectorN left
      vecN' = makeVectorN right
      vecN'' = makeVectorN (left + right)
      vecT = conT vecN
      vecT' = conT vecN'
      vecT'' = conT vecN''
      suffix = show left ++ "v" ++ show right
  (vecP, elemVs) <- conPE vecN "a" left
  (vecP', elemVs') <- conPE vecN' "b" right
  (vecP'', elemVs'') <- conPE vecN'' "c" (left + right)
  let joinN = mkName $ "join" ++ suffix
      joinT = arrowsT [vecT, vecT', vecT'']
      joinC = simpleClause [vecP, vecP'] (appsE (conE vecN'' : elemVs ++ elemVs'))
      (leftVs, rightVs) = splitAt left elemVs''
      splitN = mkName $ "split" ++ suffix
      splitT = arrowsT [vecT'', tupT [vecT, vecT']]
      splitE = tupE [appsE (conE vecN : leftVs), appsE (conE vecN' : rightVs)]
      splitC = simpleClause [vecP''] splitE
  joinI <- inlSigDef joinN joinT [joinC]
  splitI <- inlSigDef splitN splitT [splitC]
  return (joinI ++ splitI)

-- | Name of the generated list-to-vector function, e.g. @fromListV3@.
fromListN :: Name -> Name
fromListN = mkName . ("fromList" ++) . nameBase

-- | Variable expression referring to the generated @fromListVn@.
fromListE :: Name -> ExpQ
fromListE = varE . fromListN

-- | Defines @fromListVn :: [Boxed] -> Vn@.  A list of exactly @dim@ elements
-- is unwrapped with the '_valueWrap' pattern; any other list calls 'error'.
defineFromList :: Name -> ValueInfo -> Int -> DecsQ
defineFromList vectorN ValueInfo{..} dim = do
  (pats, vars) <- genPEWith "x" dim (\x -> conP _valueWrap [varP x]) varE
  let vecE = appsE (conE vectorN : vars)
      exactClause = simpleClause [listP pats] vecE
      fallbackClause =
        simpleClause [wildP] [| error "wrong number of elements" |]
      vectorT = conT vectorN
      argT = appT listT (conT _valueBoxed)
      fromListType = arrowsT [argT, vectorT]
  inlSigDef (fromListN vectorN) fromListType [exactClause, fallbackClause]

-- | Defines @toListVn :: Vn -> [Boxed]@, wrapping every field with
-- '_valueWrap' (with a signature and an INLINE pragma).
defineToList :: Name -> ValueInfo -> Int -> DecsQ
defineToList vectorN ValueInfo{..} dim = do
  (vecP, elemVars) <- conPE vectorN "a" dim
  let toListName = mkName $ "toList" ++ nameBase vectorN
      boxedElems = map (appE $ conE _valueWrap) elemVars
      toListClause = simpleClause [vecP] (listE boxedElems)
      toListType = arrowsT [conT vectorN, appT listT (conT _valueBoxed)]
  inlSigDef toListName toListType [toListClause]

-- | 'infixApp' with the operator first: @infixApp' op x y@ builds @x `op` y@.
infixApp' :: ExpQ -> ExpQ -> ExpQ -> ExpQ
infixApp' = flip infixApp

-- | A function's signature, its clauses, and an INLINE pragma for it.
inlSigDef :: Name -> TypeQ -> [ClauseQ] -> DecsQ
inlSigDef funN funT funCs = do
  sigdef <- funSigDef funN funT funCs
  inl <- makeInlineD funN
  return $ sigdef ++ [inl]

-- | A function's type signature followed by its definition.
funSigDef :: Name -> TypeQ -> [ClauseQ] -> DecsQ
funSigDef funN funT funCs = do
  funSig <- sigD funN funT
  funDef <- funD funN funCs
  return [funSig, funDef]

-- | The tuple type of the given component types.
tupT :: [TypeQ] -> TypeQ
tupT ts = foldl appT (tupleT $ length ts) ts

-- | Right-nested function type @t1 -> t2 -> ... -> tn@; errors on @[]@.
arrowsT :: [TypeQ] -> TypeQ
arrowsT [] = error "can't have no type"
arrowsT [t] = t
arrowsT (t:ts) = appT (appT arrowT t) $ arrowsT ts

-- | A fresh variable as a (pattern, expression) pair.
newPE :: String -> Q (PatQ, ExpQ)
newPE x = do
  x' <- newName x
  return (varP x', varE x')

-- | A constructor pattern binding @dim@ fresh variables, and those variables
-- as expressions.
conPE :: Name -> String -> Int -> Q (PatQ, [ExpQ])
conPE conN x dim = do
  (pats, vars) <- genPE x dim
  return (conP conN pats, vars)

-- | @n@ fresh names, turned into patterns with @mkP@ and expressions with
-- @mkE@.
genPEWith :: String -> Int -> (Name -> PatQ) -> (Name -> ExpQ) -> Q ([PatQ], [ExpQ])
genPEWith x n mkP mkE = do
  ids <- replicateM n (newName x)
  return (fmap mkP ids, fmap mkE ids)

-- | @n@ fresh variables as plain variable patterns and expressions.
genPE :: String -> Int -> Q ([PatQ], [ExpQ])
genPE x n = genPEWith x n varP varE

-- | A clause with the given patterns, an unguarded body, and no @where@.
simpleClause :: [PatQ] -> ExpQ -> ClauseQ
simpleClause ps e = clause ps (normalB e) []