module MXNet.Core.Base.Internal.TH where
import Data.Char
import Data.List
import Data.Monoid
import Language.Haskell.TH
import MXNet.Core.NNVM.Internal
import MXNet.Core.Base.Internal
-- | Generate the imperative (NDArray) binding declarations for every
-- operator name reported by 'mxListAllOpNames'.
--
-- For each name, the operator handle is looked up with 'nnGetOpHandle', its
-- description and argument names/types are read with
-- 'mxSymbolGetAtomicSymbolInfo', and 'makeNDArrayFunc' builds the
-- declarations. The results for all operators are concatenated. The first
-- component of each C-call result tuple is ignored (presumably a status
-- code -- not checked here).
registerNDArrayOps :: Bool     -- ^ generate the variants that take output handles
                   -> Q [Dec]
registerNDArrayOps mutable = runIO $ do
    (_, opNames) <- mxListAllOpNames
    perOp <- traverse declareOp opNames
    return (concat perOp)
  where
    -- Declarations for a single operator.
    declareOp opName = do
        (_, opHandle) <- nnGetOpHandle opName
        (_, _, opDesc, _, argNames, argTypes, _, _, _) <- mxSymbolGetAtomicSymbolInfo opHandle
        makeNDArrayFunc mutable opName opDesc argNames argTypes
-- | Generate the symbolic binding declarations for every operator name
-- reported by 'mxListAllOpNames'.
--
-- Each operator's description and argument names/types are fetched via
-- 'nnGetOpHandle' and 'mxSymbolGetAtomicSymbolInfo' and handed to
-- 'makeSymbolFunc'. The per-operator declaration lists are concatenated.
registerSymbolOps :: Q [Dec]
registerSymbolOps = runIO $ do
    (_, opNames) <- mxListAllOpNames
    fmap concat (traverse declareOp opNames)
  where
    -- Declarations for a single operator.
    declareOp opName = do
        (_, opHandle) <- nnGetOpHandle opName
        (_, _, opDesc, _, argNames, argTypes, _, _, _) <- mxSymbolGetAtomicSymbolInfo opHandle
        makeSymbolFunc opName opDesc argNames argTypes
-- | Generate the Template Haskell declarations for one imperative (NDArray)
-- operator: a type signature, a function definition and a @SPECIALISE@
-- pragma.
--
-- The generated function's body, in order:
--
--   * builds the key/value parameter map with @add'@, starting from the
--     caller's @varargs@ map (or @nil@ when there are no optional arguments);
--   * flattens it with @dump@ and unzips it into keys and values;
--   * looks the operator up with @nnGetOpHandle@;
--   * calls @mxImperativeInvoke@ with the input handles, the zipped key/value
--     lists and an optional output list;
--   * converts the result with @toResult@.
--
-- The generated function takes its parameters in this order:
--
--   * array arguments (type description starting with @NDArray@ or @Symbol@);
--   * the remaining explicit arguments;
--   * an @HMap kvs@ of optional arguments, only when there are optional
--     arguments with a default;
--   * when @mutable@ is 'True', an @[NDArrayHandle]@ passed as the outputs.
--
-- Returns @[]@ for skipped operators. These are operators with no arguments
-- or no explicit arguments, descriptions starting with @DEPRECATED@,
-- @Softmax@, the aliases @Concat@/@Pad@/@Flatten@/@Reshape@, and the names
-- @_NDArray@, @_Native@, @_arange@, @cast@, @crop@ and @take@.
makeNDArrayFunc :: Bool      -- ^ generate the variant taking output handles (name gets a trailing @'@)
                -> String    -- ^ operator name as registered in MXNet
                -> String    -- ^ operator description
                -> [String]  -- ^ argument names
                -> [String]  -- ^ argument type descriptions, parallel to the names
                -> IO [Dec]
makeNDArrayFunc mutable _name desc argv argtype = do
    let deprecated = desc `startWith` "DEPRECATED" || _name == "Softmax"
        alias = _name `elem` ["Concat", "Pad", "Flatten", "Reshape"]
        -- Haskell-side name: '_'-prefixed names are kept verbatim, the
        -- keyword "where" becomes "where_", anything else is lowercased.
        -- 'isPrefixOf' keeps this total even for an empty operator name.
        baseName
            | "_" `isPrefixOf` _name = _name
            | _name == "where"       = "where_"
            | otherwise              = toLower <$> _name
        name = if mutable then baseName <> "'" else baseName
    let explicitArg = getExplicitArg argv argtype
        -- Array-typed arguments become operator inputs; the others go
        -- through the key/value parameter map.
        (ndarrayArg, ordinaryArg) = partition (isArrayType . snd) explicitArg
        implicitArg = getImplicitArg argv argtype
        hasImplicit = (not . null) implicitArg
    let forallArgT = makeForallArgT implicitArg
        explicitArgT = (makeHsType . snd) <$> explicitArg
        implicitArgT = if hasImplicit
                           then [AppT (ConT (mkName "HMap")) (VarT (mkName "kvs"))]
                           else error "Impossible: no implicit available."
    -- Parameter patterns and references use the "arg'" prefix so argument
    -- names cannot clash with Haskell keywords or the generated locals.
    let argVarP = VarP . mkName . ("arg'" <>) . fst
        argVarE = VarE . mkName . ("arg'" <>)
        ndarrayArgP = argVarP <$> ndarrayArg
        ordinaryArgP = argVarP <$> ordinaryArg
        implicitArgP = [VarP (mkName "varargs") | hasImplicit]
        returnArgP = [VarP (mkName "outputs") | mutable]
        -- Input-handle list expression: a single handle is consed on, a
        -- list of handles is appended.
        consArg (v, t) rest = case makeHsType t of
            ConT _       -> UInfixE (argVarE v) (ConE (mkName ":")) rest
            AppT ListT _ -> UInfixE (argVarE v) (VarE (mkName "++")) rest
            _            -> error "Impossible: not a valid haskell type representation."
        ndargs = foldr consArg (ListE []) ndarrayArg
        dictargs = UInfixE (VarE (mkName "varArgK")) (VarE (mkName "zip")) (VarE (mkName "varArgV"))
    -- add' (Proxy :: Proxy "argName") (arg'argName :: T) acc
    let addArg (argName, t) acc =
            AppE (AppE (AppE (VarE (mkName "add'"))
                             (SigE (ConE (mkName "Proxy"))
                                   (AppT (ConT (mkName "Proxy")) (LitT (StrTyLit argName)))))
                       (SigE (argVarE argName) (makeHsType t)))
                 acc
        allArgsE = foldr addArg
                         (VarE (mkName $ if hasImplicit then "varargs" else "nil"))
                         ordinaryArg
        func = NormalB . DoE $
            [ LetS [ ValD (VarP (mkName "allArgs")) (NormalB allArgsE) [] ]
            , LetS [ ValD (VarP (mkName "args"))
                          (NormalB $ AppE (VarE (mkName "dump")) (VarE (mkName "allArgs")))
                          []
                   , ValD (TupP [VarP (mkName "varArgK"), VarP (mkName "varArgV")])
                          (NormalB $ AppE (VarE (mkName "unzip")) (VarE (mkName "args")))
                          []
                   , ValD (VarP (mkName "outArg"))
                          (NormalB $ if mutable
                                         then AppE (ConE (mkName "Just")) (VarE (mkName "outputs"))
                                         else ConE (mkName "Nothing"))
                          []
                   ]
            , BindS (TupP [VarP (mkName "_"), VarP (mkName "op")]) $
                  AppE (VarE (mkName "nnGetOpHandle")) (LitE (StringL _name))
            , BindS (TupP [VarP (mkName "_"), VarP (mkName "res")]) $
                  AppE (AppE (AppE (AppE (VarE (mkName "mxImperativeInvoke"))
                                         (VarE (mkName "op")))
                                   ndargs)
                             dictargs)
                       (VarE (mkName "outArg"))
            , NoBindS $
                  AppE (VarE (mkName "return")) $
                      AppE (VarE (mkName "toResult")) (VarE (mkName "res"))
            ]
    let argT = explicitArgT <> (if hasImplicit then implicitArgT else [])
                            <> (if mutable then [AppT ListT (ConT (mkName "NDArrayHandle"))] else [])
        -- argT_1 -> ... -> argT_n -> result, wrapped in the kvs quantifier
        -- when the operator has optional arguments.
        funT result = forallArgT (foldr (\a b -> ArrowT `AppT` a `AppT` b) result argT)
        sig = SigD (mkName name) $
                  ForallT [PlainTV (mkName "r")]
                          [AppT (ConT (mkName "NDArrayOpResult")) (VarT (mkName "r"))]
                          (funT (AppT (ConT (mkName "IO")) (VarT (mkName "r"))))
        -- Specialise the result type r to NDArrayHandle, marked INLINE.
        pragma = PragmaD $
                     SpecialiseP (mkName name)
                                 (funT (AppT (ConT (mkName "IO")) (ConT (mkName "NDArrayHandle"))))
                                 (Just Inline)
                                 AllPhases
        fun = FunD (mkName name)
                   [Clause (ndarrayArgP <> ordinaryArgP <> implicitArgP <> returnArgP) func []]
        skipped = null argv || deprecated || alias
               || null explicitArg
               || _name `elem` ["_NDArray", "_Native", "_arange", "cast", "crop", "take"]
    return $ if skipped then [] else [sig, fun, pragma]
  where
    -- Type descriptions treated as array (handle) arguments.
    isArrayType t = t `startWith` "NDArray" || t `startWith` "Symbol"

    -- Map an MXNet argument type description to the Haskell type used in
    -- the generated signature; array types map to NDArrayHandle. An
    -- unrecognised description yields a type constructor named
    -- "unknown type name: ...", which presumably fails to resolve when the
    -- signature is spliced.
    makeHsType :: String -> Type
    makeHsType s = case s of
        "boolean"             -> ConT . mkName $ "Bool"
        "float"               -> ConT . mkName $ "Float"
        "double"              -> ConT . mkName $ "Double"
        "real_t"              -> ConT . mkName $ "Float"
        'i':'n':'t':_         -> ConT . mkName $ "Int"
        'l':'o':'n':'g':_     -> ConT . mkName $ "Int"
        "string"              -> ConT . mkName $ "String"
        "NDArray"             -> ConT . mkName $ "NDArrayHandle"
        "NDArray-or-Symbol"   -> ConT . mkName $ "NDArrayHandle"
        "NDArray-or-Symbol[]" -> AppT ListT . ConT . mkName $ "NDArrayHandle"
        "Symbol"              -> ConT . mkName $ "NDArrayHandle"
        "NDArray[]"           -> AppT ListT . ConT . mkName $ "NDArrayHandle"
        "Symbol[]"            -> AppT ListT . ConT . mkName $ "NDArrayHandle"
        "Symbol or Symbol[]"  -> AppT ListT . ConT . mkName $ "NDArrayHandle"
        '{':_                 -> ConT . mkName $ "String"
        "Shape(tuple)"        -> ConT . mkName $ "String"
        "tuple of <float>"    -> AppT ListT . ConT . mkName $ "Float"
        "tuple of <double>"   -> AppT ListT . ConT . mkName $ "Double"
        other                 -> ConT . mkName $ "unknown type name: " <> other

    -- Promoted type-level list  '[ "k1" ':= T1, "k2" ':= T2, ... ]  built from
    -- the optional arguments' names and types (their defaults are unused).
    makeKVListT :: [(String, String, String)] -> Type
    makeKVListT args = foldr combineKV PromotedNilT ((\(v, t, _) -> makeKV v t) <$> args)
      where
        makeKV v t = AppT (AppT (PromotedT (mkName ":=")) (LitT (StrTyLit v))) (makeHsType t)
        combineKV a acc = AppT (AppT (PromotedT (mkName ":")) a) acc

    -- With optional arguments, wrap a type in
    --   forall (kvs :: [KV *]). (ShowKV kvs, MatchKVList kvs <kv list>) => ...
    -- otherwise leave it unchanged.
    makeForallArgT :: [(String, String, String)] -> (Type -> Type)
    makeForallArgT [] = id
    makeForallArgT optionalArgs =
        ForallT [KindedTV (mkName "kvs") (AppT ListT (AppT (ConT (mkName "KV")) StarT))]
                [ AppT (ConT (mkName "ShowKV")) (VarT (mkName "kvs"))
                , AppT (AppT (ConT (mkName "MatchKVList")) (VarT (mkName "kvs")))
                       (makeKVListT optionalArgs)
                ]
-- | Generate the Template Haskell declarations (type signature and function
-- definition) for one symbolic operator.
--
-- The generated function takes, in order:
--
--   * a symbol name ('String');
--   * the array arguments (type description starting with @NDArray@ or
--     @Symbol@);
--   * the remaining explicit arguments;
--   * an @HMap kvs@ of optional arguments, only when there are optional
--     arguments with a default.
--
-- Its body, in order:
--
--   * builds the key/value parameters with @add'@ and @dump@;
--   * looks the operator up with @nnGetOpHandle@;
--   * creates an atomic symbol with @mxSymbolCreateAtomicSymbol@;
--   * composes it with the array inputs via @nnSymbolCompose@;
--   * returns the created 'SymbolHandle'.
--
-- Returns @[]@ for skipped operators. These are operators with no arguments
-- or no explicit arguments, descriptions starting with @DEPRECATED@,
-- @Softmax@, the aliases @Concat@/@Pad@/@Flatten@/@Reshape@, and the names
-- @_NDArray@, @_Native@, @_arange@, @cast@, @crop@, @take@ and @where@.
makeSymbolFunc :: String    -- ^ operator name as registered in MXNet
               -> String    -- ^ operator description
               -> [String]  -- ^ argument names
               -> [String]  -- ^ argument type descriptions, parallel to the names
               -> IO [Dec]
makeSymbolFunc _name desc argv argtype = do
    let deprecated = desc `startWith` "DEPRECATED" || _name == "Softmax"
        alias = _name `elem` ["Concat", "Pad", "Flatten", "Reshape"]
        -- Haskell-side name: '_'-prefixed names are kept verbatim, the
        -- keyword "where" becomes "where_" (though "where" is skipped below),
        -- anything else is lowercased. 'isPrefixOf' keeps this total.
        name
            | "_" `isPrefixOf` _name = _name
            | _name == "where"       = "where_"
            | otherwise              = toLower <$> _name
    let explicitArg = getExplicitArg argv argtype
        -- Array-typed arguments are composed as symbol inputs; the others go
        -- through the key/value parameter map.
        (ndarrayArg, ordinaryArg) = partition (isArrayType . snd) explicitArg
        implicitArg = getImplicitArg argv argtype
        hasImplicit = (not . null) implicitArg
    let forallArgT = makeForallArgT implicitArg
        explicitArgT = (makeHsType . snd) <$> explicitArg
        implicitArgT = if hasImplicit
                           then [AppT (ConT (mkName "HMap")) (VarT (mkName "kvs"))]
                           else error "Impossible: no implicit available."
    -- Parameter patterns and references use the "arg'" prefix so argument
    -- names cannot clash with Haskell keywords or the generated locals.
    let argVarP = VarP . mkName . ("arg'" <>) . fst
        argVarE = VarE . mkName . ("arg'" <>)
        nameArgP = [VarP (mkName "name")]
        ndarrayArgP = argVarP <$> ndarrayArg
        ordinaryArgP = argVarP <$> ordinaryArg
        implicitArgP = [VarP (mkName "varargs") | hasImplicit]
        -- Input-handle list expression: a single handle is consed on, a
        -- list of handles is appended.
        consArg (v, t) rest = case makeHsType t of
            ConT _       -> UInfixE (argVarE v) (ConE (mkName ":")) rest
            AppT ListT _ -> UInfixE (argVarE v) (VarE (mkName "++")) rest
            _            -> error "Impossible: not a valid haskell type representation."
        ndargs = foldr consArg (ListE []) ndarrayArg
    -- add' (Proxy :: Proxy "argName") (arg'argName :: T) acc
    let addArg (argName, t) acc =
            AppE (AppE (AppE (VarE (mkName "add'"))
                             (SigE (ConE (mkName "Proxy"))
                                   (AppT (ConT (mkName "Proxy")) (LitT (StrTyLit argName)))))
                       (SigE (argVarE argName) (makeHsType t)))
                 acc
        allArgsE = foldr addArg
                         (VarE (mkName $ if hasImplicit then "varargs" else "nil"))
                         ordinaryArg
        func = NormalB . DoE $
            [ LetS [ ValD (VarP (mkName "allArgs")) (NormalB allArgsE) [] ]
            , LetS [ ValD (VarP (mkName "args"))
                          (NormalB $ AppE (VarE (mkName "dump")) (VarE (mkName "allArgs")))
                          []
                   , ValD (TupP [VarP (mkName "varArgK"), VarP (mkName "varArgV")])
                          (NormalB $ AppE (VarE (mkName "unzip")) (VarE (mkName "args")))
                          []
                   ]
            , BindS (TupP [VarP (mkName "_"), VarP (mkName "op")]) $
                  AppE (VarE (mkName "nnGetOpHandle")) (LitE (StringL _name))
            , LetS [ ValD (VarP (mkName "nargs"))
                          (NormalB (AppE (VarE (mkName "fromIntegral"))
                                         (AppE (VarE (mkName "length"))
                                               (VarE (mkName "varArgK")))))
                          []
                   ]
            , BindS (TupP [VarP (mkName "_"), VarP (mkName "sym")]) $
                  AppE (AppE (AppE (AppE (VarE (mkName "mxSymbolCreateAtomicSymbol"))
                                         (VarE (mkName "op")))
                                   (VarE (mkName "nargs")))
                             (VarE (mkName "varArgK")))
                       (VarE (mkName "varArgV"))
            , BindS (VarP (mkName "_")) $
                  AppE (AppE (AppE (AppE (VarE (mkName "nnSymbolCompose"))
                                         (VarE (mkName "sym")))
                                   (VarE (mkName "name")))
                             (ListE []))
                       ndargs
            , NoBindS $
                  AppE (VarE (mkName "return")) (VarE (mkName "sym"))
            ]
    let argT = (ConT . mkName $ "String") : explicitArgT <> (if hasImplicit then implicitArgT else [])
        -- String -> argT_1 -> ... -> IO SymbolHandle, wrapped in the kvs
        -- quantifier when the operator has optional arguments.
        sig = SigD (mkName name) $
                  forallArgT (foldr (\a b -> ArrowT `AppT` a `AppT` b)
                                    (AppT (ConT (mkName "IO")) (ConT (mkName "SymbolHandle")))
                                    argT)
        fun = FunD (mkName name)
                   [Clause (nameArgP <> ndarrayArgP <> ordinaryArgP <> implicitArgP) func []]
        skipped = null argv || deprecated || alias
               || null explicitArg
               || _name `elem` ["_NDArray", "_Native", "_arange", "cast", "crop", "take", "where"]
    return $ if skipped then [] else [sig, fun]
  where
    -- Type descriptions treated as array (handle) arguments.
    isArrayType t = t `startWith` "NDArray" || t `startWith` "Symbol"

    -- Map an MXNet argument type description to the Haskell type used in
    -- the generated signature; array types map to SymbolHandle. An
    -- unrecognised description yields a type constructor named
    -- "unknown type name: ...", which presumably fails to resolve when the
    -- signature is spliced.
    makeHsType :: String -> Type
    makeHsType s = case s of
        "boolean"             -> ConT . mkName $ "Bool"
        "float"               -> ConT . mkName $ "Float"
        "double"              -> ConT . mkName $ "Double"
        "real_t"              -> ConT . mkName $ "Float"
        'i':'n':'t':_         -> ConT . mkName $ "Int"
        'l':'o':'n':'g':_     -> ConT . mkName $ "Int"
        "string"              -> ConT . mkName $ "String"
        "NDArray"             -> ConT . mkName $ "SymbolHandle"
        "Symbol"              -> ConT . mkName $ "SymbolHandle"
        "NDArray-or-Symbol"   -> ConT . mkName $ "SymbolHandle"
        "NDArray-or-Symbol[]" -> AppT ListT . ConT . mkName $ "SymbolHandle"
        "NDArray[]"           -> AppT ListT . ConT . mkName $ "SymbolHandle"
        "Symbol[]"            -> AppT ListT . ConT . mkName $ "SymbolHandle"
        "Symbol or Symbol[]"  -> AppT ListT . ConT . mkName $ "SymbolHandle"
        '{':_                 -> ConT . mkName $ "String"
        "Shape(tuple)"        -> ConT . mkName $ "String"
        "tuple of <float>"    -> AppT ListT . ConT . mkName $ "Float"
        "tuple of <double>"   -> AppT ListT . ConT . mkName $ "Double"
        other                 -> ConT . mkName $ "unknown type name: " <> other

    -- Promoted type-level list  '[ "k1" ':= T1, "k2" ':= T2, ... ]  built from
    -- the optional arguments' names and types (their defaults are unused).
    makeKVListT :: [(String, String, String)] -> Type
    makeKVListT args = foldr combineKV PromotedNilT ((\(v, t, _) -> makeKV v t) <$> args)
      where
        makeKV v t = AppT (AppT (PromotedT (mkName ":=")) (LitT (StrTyLit v))) (makeHsType t)
        combineKV a acc = AppT (AppT (PromotedT (mkName ":")) a) acc

    -- With optional arguments, wrap a type in
    --   forall (kvs :: [KV *]). (ShowKV kvs, MatchKVList kvs <kv list>) => ...
    -- otherwise leave it unchanged.
    makeForallArgT :: [(String, String, String)] -> (Type -> Type)
    makeForallArgT [] = id
    makeForallArgT optionalArgs =
        ForallT [KindedTV (mkName "kvs") (AppT ListT (AppT (ConT (mkName "KV")) StarT))]
                [ AppT (ConT (mkName "ShowKV")) (VarT (mkName "kvs"))
                , AppT (AppT (ConT (mkName "MatchKVList")) (VarT (mkName "kvs")))
                       (makeKVListT optionalArgs)
                ]
-- | @s `startWith` prefix@ is 'True' when @prefix@ is a prefix of @s@.
--
-- The string being tested comes first, so the function reads naturally in
-- infix form, e.g. @desc `startWith` "DEPRECATED"@. It delegates to
-- 'isPrefixOf', which stops at the first mismatch instead of measuring the
-- prefix's length up front.
startWith :: String -> String -> Bool
startWith s t = t `isPrefixOf` s
-- | Extend an association list with those entries of a second list whose
-- keys are not present yet.
--
-- Entries already in the first list always win. Each newly added entry is
-- consed onto the front, so later additions appear first. A key that occurs
-- several times in the second list is added only once, at its first
-- occurrence.
updateMap :: [(String, String)] -> [(String, String)] -> [(String, String)]
updateMap = foldl' insertMissing
  where
    insertMissing acc entry@(key, _)
        | key `elem` map fst acc = acc
        | otherwise              = entry : acc
-- | Split an MXNet argument type description on commas. For example,
-- @"int, optional, default='1'"@ becomes
-- @["int", "optional", "default='1'"]@.
--
-- Leading spaces of every field are dropped and trailing spaces are kept.
-- Commas inside brackets are not treated specially. Splitting stops at the
-- first empty field.
splitArgType :: String -> [String]
splitArgType input
    | null field = []
    | otherwise  = field : maybe [] splitArgType (afterComma rest)
  where
    (field, rest) = break (== ',') (dropWhile (== ' ') input)
    -- 'rest' is either empty or starts with the separating comma.
    afterComma (',' : more) = Just more
    afterComma _            = Nothing
-- | Select the required (non-@optional@) arguments of an operator, paired
-- with the first field of their type description.
--
-- An argument whose type description splits into no fields is reported as
-- @"tuple of <float>"@.
getExplicitArg :: [String]  -- ^ argument names
               -> [String]  -- ^ argument type descriptions, parallel to the names
               -> [(String, String)]
getExplicitArg argv argtype = concatMap classify (zip argv argtype)
  where
    classify (argName, typeDesc) = case splitArgType typeDesc of
        fields | "optional" `elem` fields -> []
        []                                -> [(argName, "tuple of <float>")]
        (firstField : _)                  -> [(argName, firstField)]
-- | Collect the @optional@ arguments of an operator that declare a default.
--
-- Returns @(name, type, default)@ triples, where @type@ is the first field
-- of the type description and @default@ is the text after the first
-- @default=@ field. Arguments not marked @optional@ are ignored. Optional
-- arguments without a @default=@ field are also skipped; they no longer
-- crash on @head []@.
getImplicitArg :: [String]  -- ^ argument names
               -> [String]  -- ^ argument type descriptions, parallel to the names
               -> [(String, String, String)]
getImplicitArg argv argtype = [arg | Just arg <- resolve <$> zip argv argtype]
  where
    -- "optional" being a field guarantees the field list is non-empty, so
    -- the first field can be bound by pattern instead of 'head'.
    resolve (v, t) = case splitArgType t of
        ts@(ty : _) | "optional" `elem` ts -> (\dflt -> (v, ty, dflt)) <$> getDefault ts
        _                                  -> Nothing
    -- Text after the first "default=" field, if any.
    getDefault ts = find ("default=" `isPrefixOf`) ts >>= stripPrefix "default="