{-# LANGUAGE CPP #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ForeignFunctionInterface #-} module Foreign.CUDA.Cublas.TH where import Control.Applicative import Control.Arrow import Control.Monad ((>=>), join, void) import GHC.Exts (groupWith) import Language.Haskell.TH as TH import Language.C import Language.C.System.GCC import Data.List (isInfixOf, isPrefixOf, isSuffixOf) import Data.Char (toLower, toUpper) import Data.Maybe (mapMaybe) import Debug.Trace import qualified Foreign.C.Types as C import qualified Foreign as F import Foreign.Storable.Complex () import Data.Complex (Complex(..)) import System.FilePath.Posix (()) import Foreign.CUDA as FC import qualified Foreign.CUDA.Runtime.Stream as FC import qualified Foreign.CUDA.Cublas.Types as BL import qualified Foreign.CUDA.Cusparse.Types as SP import qualified Foreign.CUDA.Cublas.Error as BL import qualified Foreign.CUDA.Cusparse.Error as SP try :: [(Bool, a)] -> a try ((p,y):conds) = if p then y else try conds try [] = error "TH.try: No match!" data TypeInfo = TI { ctype :: Q Type , hsinput :: Either Convert Create , hsoutput :: Maybe (Either Convert Destroy) } data TypeDat = TD { ct :: Q Type , hst :: Q Type , c2hs :: (Q Exp, ExpType) , hs2c :: (Q Exp, ExpType) } data ExpType = Pure | Monadic data TypeC = VoidC | IntC | FloatC | DoubleC | EnumC String | ComplexC TypeC | ArbStructC String | PtrC TypeC | PhonyC TH.Name | ArrC TypeC deriving (Eq, Show) prim :: Q Type -> Q Type -> Q Exp -> Q Exp -> TypeDat prim ct hst c2hs hs2c = TD ct hst (c2hs, Pure) (hs2c, Pure) simple :: Q Type -> TypeDat simple t = prim t t [| id |] [| id |] bothc :: (a -> b) -> Complex a -> Complex b bothc f (a :+ b) = f a :+ f b typeDat :: TypeC -> TypeDat typeDat (PhonyC n) = simple (varT n) typeDat VoidC = simple [t| () |] typeDat (PtrC t) = prim [t| F.Ptr $(ctype) |] [t| FC.DevicePtr $(ctype) |] [| FC.DevicePtr |] [| FC.useDevicePtr |] where ctype = ct (typeDat t) typeDat (ArrC t) = typeDat (PtrC t) typeDat IntC = prim [t| C.CInt |] [t| Int |] [| fromIntegral |] [| fromIntegral |] typeDat FloatC = simple [t| C.CFloat |] typeDat DoubleC = simple [t| C.CDouble |] typeDat (EnumC str) = prim [t| C.CInt |] x [| toEnum . fromIntegral |] [| fromIntegral . fromEnum |] where x = case str of "cublasStatus_t" -> [t| BL.Status |] "cublasOperation_t" -> [t| BL.Operation |] "cublasSideMode_t" -> [t| BL.SideMode |] "cublasFillMode_t" -> [t| BL.FillMode |] "cublasPointerMode_t" -> [t| BL.PointerMode |] "cublasAtomicsMode_t" -> [t| BL.AtomicsMode |] "cublasDiagType_t" -> [t| BL.DiagType |] "cusparseStatus_t" -> [t| SP.Status |] "cusparseOperation_t" -> [t| SP.Operation |] "cusparseDirection_t" -> [t| SP.Direction |] "cusparseHybPartition_t" -> [t| SP.HybPartition |] "cusparseFillMode_t" -> [t| SP.FillMode |] "cusparsePointerMode_t" -> [t| SP.PointerMode |] "cusparseDiagType_t" -> [t| SP.DiagType |] "cusparseIndexBase_t" -> [t| SP.IndexBase |] "cusparseAction_t" -> [t| SP.Action |] "cusparseMatrixType_t" -> [t| SP.MatrixType |] "cusparseSolvePolicy_t" -> [t| SP.SolvePolicy |] otherwise -> error ("typeDat.EnumC : Missing type: " ++ str) typeDat (ArbStructC str) = case str of "cublasHandle_t" -> prim [t| F.Ptr () |] [t| BL.Handle |] [| BL.Handle |] [| BL.useHandle |] "cusparseHandle_t" -> prim [t| F.Ptr () |] [t| SP.Handle |] [| SP.Handle |] [| SP.useHandle |] "cusparseHybMat_t" -> prim [t| F.Ptr () |] [t| SP.HybMat |] [| SP.HybMat |] [| SP.useHybMat |] "cusparseMatDescr_t" -> prim [t| F.Ptr () |] [t| SP.MatDescr |] [| SP.MatDescr |] [| SP.useMatDescr |] "cusparseSolveAnalysisInfo_t" -> prim [t| F.Ptr () |] [t| SP.SolveAnalysisInfo |] [| SP.SolveAnalysisInfo |] [| SP.useSolveAnalysisInfo |] "csrsv2Info_t" -> prim [t| F.Ptr () |] [t| SP.Csrsv2Info |] [| SP.Csrsv2Info |] [| SP.useCsrsv2Info |] "csric02Info_t" -> prim [t| F.Ptr () |] [t| SP.Csric02Info |] [| SP.Csric02Info |] [| SP.useCsric02Info |] "csrilu02Info_t" -> prim [t| F.Ptr () |] [t| SP.Csrilu02Info |] [| SP.Csrilu02Info |] [| SP.useCsrilu02Info |] "bsrsv2Info_t" -> prim [t| F.Ptr () |] [t| SP.Bsrsv2Info |] [| SP.Bsrsv2Info |] [| SP.useBsrsv2Info |] "bsric02Info_t" -> prim [t| F.Ptr () |] [t| SP.Bsric02Info |] [| SP.Bsric02Info |] [| SP.useBsric02Info |] "bsrilu02Info_t" -> prim [t| F.Ptr () |] [t| SP.Bsrilu02Info |] [| SP.Bsrilu02Info |] [| SP.useBsrilu02Info |] "cudaStream_t" -> prim [t| F.Ptr () |] [t| FC.Stream |] [| FC.Stream |] [| FC.useStream |] typeDat (ComplexC t) = prim [t| Complex $(ctype) |] [t| Complex $(hstype) |] [| bothc $(fromC) |] [| bothc $(toC) |] where TD ctype hstype (fromC, Pure) (toC, Pure) = typeDat t convertT x y = Left (Convert x y) createT = Right . Create destroyT = Right . Destroy data Convert = Convert (Q Type) (Q Exp) newtype Create = Create (Q Exp) newtype Destroy = Destroy (Q Exp) pointerify :: Q Type -> Q Type pointerify x = [t| F.Ptr $(x) |] useT :: TypeC -> TypeInfo useT = useT' . typeDat where useT' (TD ct hst c2hs (hs2c,purity)) = TI ct (convertT hst exp) Nothing where exp = case purity of Pure -> [| return . $(hs2c) |]; Monadic -> hs2c inT :: TypeC -> TypeInfo inT (PtrC t) = inT' (typeDat t) where inT' (TD ct hst c2hs (hs2c,purity)) = TI (pointerify ct) (convertT hst exp) (Just (destroyT [| F.free |])) where exp = case purity of Pure -> [| F.new . $(hs2c) |] ; Monadic -> undefined inT (ArrC t) = inT' (typeDat t) where inT' (TD ct hst c2hs (hs2c,purity)) = TI (pointerify ct) (convertT [t| [ $(hst) ] |] exp) (Just (destroyT [| FC.free . FC.DevicePtr |])) where exp = case purity of Pure -> [| fmap FC.useDevicePtr . FC.newListArray . map $(hs2c) |] ; Monadic -> undefined outT :: TypeC -> TypeInfo outT (PtrC t) = outT' (typeDat t) where outT' (TD ct hst (c2hs,purity) hs2c) = TI ct (createT [| F.malloc |]) (Just (convertT hst [| \p -> do { x <- F.peek p ; F.free p; $(exp) x } |])) where exp = case purity of Pure -> [| return . $(c2hs) |] ; Monadic -> c2hs inOutT :: TypeC -> TypeInfo inOutT (PtrC t) = inOutT' (typeDat t) where inOutT' (TD ct hst (c2hs,purity1) (hs2c,purity2)) = TI (pointerify ct) (convertT hst exp1) (Just (convertT hst [| \p -> do { x <- F.peek p ; F.free p; $(exp2) x } |])) where exp1 = case purity1 of Pure -> [| F.new . $(hs2c) |] ; Monadic -> hs2c exp2 = case purity2 of Pure -> [| return . $(c2hs) |] ; Monadic -> c2hs convert :: CTypeSpecifier a -> TypeC convert (CVoidType _) = VoidC --CCharType a --CShortType a convert (CIntType _) = IntC --CLongType a convert (CFloatType _) = FloatC convert (CDoubleType _) = DoubleC --CSignedType a --CUnsigType a --CBoolType a --CComplexType a convert (CTypeDef ident _) = try [ (s `elem` ["cublasHandle_t", "cusparseHybMat_t", "cusparseHandle_t", "cusparseMatDescr_t", "cusparseSolveAnalysisInfo_t", "cudaStream_t", "csrsv2Info_t", "csric02Info_t", "csrilu02Info_t", "bsrsv2Info_t", "bsric02Info_t", "bsrilu02Info_t" ] , ArbStructC s) , (s=="cuComplex", ComplexC FloatC) , (s=="cuDoubleComplex", ComplexC DoubleC) , (True, EnumC s) ] where s = identToString ident convert _ = VoidC convert' :: [CDeclarationSpecifier a] -> TypeC convert' (CTypeSpec x:_) = convert x convert' (_:xs) = convert' xs convert' [] = error "convert': invalid CDeclarationSpecifier list" typeOf :: (TypeC -> Q Type) -> CDeclaration a -> Q Type typeOf proj (CDecl basetype [(Just (CDeclr (Just ident) ptrs _ _ _), _, _)] _) = foldr f (proj $ convert' basetype) ptrs where f (CPtrDeclr _ _) b = [t| F.Ptr $(b) |] f _ _ = error "haven't implemented other things" pointerification :: CDeclaration a -> (TypeC -> TypeC) pointerification (CDecl _ [(Just (CDeclr _ ptrs _ _ _), _, _)] _) = foldr (.) id $ map f ptrs where f (CPtrDeclr _ _) = PtrC f (CArrDeclr _ _ _) = ArrC f _ = id --possible there are other things that should be here? baseType :: CDeclaration a -> TypeC baseType (CDecl basetype _ _) = convert' basetype cType :: CDeclaration a -> TypeC cType d = (pointerification d) (baseType d) typeInfo :: String -> CVar -> TypeInfo typeInfo fn (n, typec) = ($ typec) $ try [ {- CUBLAS -} (case typec of ArrC _ -> True; otherwise -> False , inT) , (n `elem` ["alpha", "beta", "a", "b", "c", "d1", "d2", "x1", "y1", "s"] , inT) , ( "create" `isPrefixOf` fn || n == "result" , outT) {- CuSPARSE -} , ("DevHostPtr" `isSuffixOf` n , outT) {- End -} , (True , useT) ] declName :: CDeclaration a -> Maybe String declName (CDecl _ [(Just (CDeclr (Just ident) _ _ _ _), _, _)] _) = Just (identToString ident) declName _ = Nothing outMarshall :: TypeC -> (Q Exp, Q Type -> Q Type) outMarshall (EnumC "cublasStatus_t") = ([| BL.resultIfOk |], id) outMarshall (EnumC "cusparseStatus_t") = ([| SP.resultIfOk |], id) outMarshall VoidC = ([| return . snd |], id) outMarshall x = ([| return . fst |], const (hst $ typeDat x)) createf' :: (String, CFunction) -> Q [Dec] createf' (foreignname, cf@(fname, rettype, args)) = do ins <- mapM (safeName "_in") args toCs <- mapM (safeName "_out") args (outstatements, (outtypes, outs)) <- second (unzip . filterMaybes) . unzip <$> collect (zip3 args argsTI toCs) let instatements = map inMarsh (zip3 argsTI ins toCs) ret <- newName "res" let runstatement = bindS (varP ret) (foldl f z toCs) let returnstatement = [| $(checkStatusExp) ( $(outputConv) $(varE ret), $(tupE (map varE outs)) ) |] expr <- doE $ concat [instatements, runstatement:outstatements, [noBindS returnstatement]] let usedins = map snd . filter (isused . fst) $ zip argsTI ins let fdec = FunD fcall [Clause (map VarP usedins) (NormalB expr) []] tdec <- sigD fcall $ funTypeMod checkStatusType argsTI return [tdec, fdec] where safeName :: String -> CVar -> Q TH.Name safeName end (s:str, _) = newName (toLower s : str ++ end) argsTI = functionTypeInfo cf (outputConv, _) = c2hs (typeDat rettype) (checkStatusExp, checkStatusType) = outMarshall rettype isused (TI _ (Left _) _) = True isused _ = False fcall = mkName fname z = varE (mkName foreignname) f x e = appE x (varE e) inMarsh (ti,e,e') = case hsinput ti of Left (Convert t a) -> bindS (varP e') (appE a (varE e)) Right (Create a) -> bindS (varP e') a collect ( (arg, TI _ _ (Just cleanup), e) : xs) = do e' <- safeName "_out" arg let outinfo = case cleanup of Left (Convert t a) -> (bindS (varP e') (appE a (varE e)), Just (t, e')) Right (Destroy a) -> (noBindS (appE a (varE e)), Nothing) ys <- collect xs return (outinfo:ys) collect (_:xs) = collect xs collect [] = return [] collecti (TI _ (Left (Convert t _)) _) = [t] collecti _ = [] cublasFile, cusparseFile :: FilePath cublasFile = CUDA_INCLUDE_DIR "cublas_v2.h" cusparseFile = CUDA_INCLUDE_DIR "cusparse_v2.h" filterMaybes :: [Maybe a] -> [a] filterMaybes [] = [] filterMaybes (Just x:xs) = x : filterMaybes xs filterMaybes (Nothing:xs) = filterMaybes xs funname :: CDeclaration a -> String funname (CDecl _ [(Just (CDeclr (Just ident ) _ _ _ _), _, _)] _) = identToString ident funname _ = "Weird!" desired :: String -> CFunction -> Bool desired prefix (name, _, _) = any (`isPrefixOf` name) $ map (prefix ++) ("Get" : map (:[]) "SDCZX") infol :: Show a => CDerivedDeclarator a -> Maybe [[String]] infol (CFunDeclr (Right (ys,_)) _ _) = Just $ map f ys where f (CDecl specs _ _) = map show specs infol _ = Nothing funArgs :: CDeclarator a -> Maybe [CDeclaration a] funArgs (CDeclr _ [(CFunDeclr (Right (ys,_)) _ _)] _ _ _) = Just ys funArgs _ = Nothing funDecl :: CDeclaration a -> Maybe (CDeclarator a) funDecl (CDecl _ [(Just declarator, _, _)] _) = Just declarator funDecl _ = Nothing maybeFunction :: CDeclaration a -> Maybe (CFunction) maybeFunction d@(CDecl returnType _ _) = do args <- funArgs =<< funDecl d retName <- declName d argNames <- mapM declName args let argTypes = map cType args return (retName, convert' returnType, zip argNames argTypes ) maybeExternalDec :: CExternalDeclaration a -> Maybe (CDeclaration a) maybeExternalDec (CDeclExt d) = Just d maybeExternalDec _ = Nothing type CVar = (String, TypeC) type CFunction = ( String , TypeC , [CVar] ) getFunctions :: FilePath -> IO [CFunction] getFunctions fp = do Right (CTranslUnit xs _) <- parseCFile (newGCC "/usr/bin/gcc") Nothing [] fp return $ mapMaybe (maybeExternalDec >=> maybeFunction) xs createf :: FilePath -> CFunction -> Q Dec createf fp (name, ret, args) = forImpD cCall safe{-unsafe-} (fp ++ ' ':name) (mkName name) cFunType where cFunType = foldr f z (map (ct . typeDat . snd) args) z = [t| IO $(ct . typeDat $ ret) |] f x y = [t| $(x) -> $(y) |] sharedDecs :: String -> [CFunction] -> [(String, [(String, CFunction)])] sharedDecs prefix xs = xs'' where g x@(s,ret,args) = do newname <- dropc <$> goodName prefix s return (s, (newname, ret, args)) xs' = mapMaybe g xs fst3 (s,_,_) = s dropc name = if last name == 'c' then init name else name --for dot, ger, ... xs'' = map ( (fst3 . snd . head) &&& id) . filter sdFilter . groupWith (tail . fst3 . snd) $ xs' sdFilter xs = length xs == 4 && not ( any (`isInfixOf` (fst (head xs))) ["rot_v2", "rotg_v2", "hybsv_analysis", "numericBoost"] ) mkClass :: String -> [CFunction] -> Q Dec mkClass (p:prefix) xs = classD (return []) className [PlainTV typeName] [] decs where className = mkName (toUpper p:prefix) typeName = mkName "a" decs = map (f . phonifyF) xs mkPhony :: TypeC -> TypeC mkPhony (PtrC t) = PtrC (mkPhony t) mkPhony (ArrC t) = ArrC (mkPhony t) mkPhony x = let t' = PhonyC typeName in case x of DoubleC -> t'; FloatC -> t'; ComplexC _ -> t'; y -> y phonifyF :: CFunction -> CFunction phonifyF (name, ret, args) = (name, mkPhony ret, map (second mkPhony) args) f cfunc@(name, _, _) = sigD (mkName (tail name)) (funType $ functionTypeInfo cfunc) mkClassInstances :: String -> [(String, [(String,CFunction)])] -> [Q Dec] mkClassInstances (p:prefix) xs = map (\c -> makeInstance c $ map (f c) xs) "sdcz" where makeInstance c decs = instanceD (return []) classSig (decs) where classSig = appT (return . ConT $ mkName (toUpper p:prefix)) (ct . typeDat $ typeMap c) f c (_, funcs) = (!! 1) <$> createf' (foreignn, (name, ret, args)) where [(foreignn,((_:name), ret, args))] = filter (\(_,((s:_),_,_))-> s==c) funcs typeMap :: Char -> TypeC typeMap 'c' = ComplexC FloatC typeMap 'z' = ComplexC DoubleC typeMap 'd' = DoubleC typeMap 's' = FloatC typeMap _ = error "typeMap: Invalid character" makeClassDecs :: String -> FilePath -> IO (Q [Dec]) makeClassDecs str fp = do sds <- sharedDecs str <$> getFunctions fp return $ sequence (mkClass str (map (snd . head . snd) sds) : mkClassInstances str sds) makeFFIDecs :: String -> FilePath -> IO (Q [Dec]) makeFFIDecs str fp = sequence . map (createf fp) . filter (desired str) <$> getFunctions fp makeAllFuncs :: String -> FilePath -> IO (Q [Dec]) makeAllFuncs str fp = fmap concat . sequence . mapMaybe (fmap createf' . alter). filter (desired str) <$> getFunctions fp where alter (fname, rettype, args) = do newname <- goodName str fname return (fname, (newname, rettype, args)) goodName :: String -> String -> Maybe String goodName prefix = f where v2suff = "_v2" l = length prefix f str = if pre == prefix then Just (toLower x : xs) else Nothing where (pre, name) = splitAt l str (name', v2) = splitAt (length name - length v2suff) name (x : xs) = if v2 == v2suff then name' else name doIO :: IO (Q [a]) -> Q [a] doIO = join . runIO inTypes :: [TypeInfo] -> [Q Type] inTypes = mapMaybe f where f (TI _ (Left (Convert t _)) _) = Just t f _ = Nothing outTypes :: [TypeInfo] -> [Q Type] outTypes = mapMaybe f where f (TI _ _ (Just (Left (Convert t _)))) = Just t f _ = Nothing functionTypeInfo :: CFunction -> [TypeInfo] functionTypeInfo (fname, ret, args) = map (typeInfo fname) args funTypeMod :: (Q Type -> Q Type) -> [TypeInfo] -> Q Type funTypeMod f args = foldr arrow z ins where arrow x y = [t| $(x) -> $(y) |] z = [t| IO $( f $ foldl appT (tupleT (length outs)) outs) |] [ins, outs] = map ($ args) [inTypes, outTypes] funType :: [TypeInfo] -> Q Type funType = funTypeMod id