{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Menoh
(
Dims
, DType (..)
, Error (..)
, ModelData (..)
, makeModelDataFromONNX
, optimizeModelData
, VariableProfileTable (..)
, makeVariableProfileTable
, vptGetDType
, vptGetDims
, Model (..)
, makeModel
, makeModelWithConfig
, run
, getDType
, getDims
, ToBuffer (..)
, FromBuffer (..)
, writeBuffer
, readBuffer
, unsafeGetBuffer
, withBuffer
, HasDType (..)
, writeBufferFromVector
, writeBufferFromStorableVector
, readBufferToVector
, readBufferToStorableVector
, version
, bindingVersion
, VariableProfileTableBuilder (..)
, makeVariableProfileTableBuilder
, addInputProfileDims2
, addInputProfileDims4
, addOutputProfile
, buildVariableProfileTable
, ModelBuilder (..)
, makeModelBuilder
, attachExternalBuffer
, buildModel
, buildModelWithConfig
) where
import Control.Applicative
import Control.Concurrent
import Control.Monad
import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp)
import Control.Monad.IO.Class
import Control.Exception
import qualified Data.Aeson as J
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Proxy
import Data.Typeable
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Data.Vector.Unboxed as VU
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Version
import Foreign
import Foreign.C
import qualified Menoh.Base as Base
import qualified Paths_menoh
#include "MachDeps.h"
-- | Exceptions thrown by this module.
--
-- Each constructor (except where noted by use) corresponds to a Menoh C
-- error code and carries the error message text; see 'runMenoh' for the
-- code-to-constructor mapping. 'ErrorDimensionMismatch' and
-- 'ErrorInvalidDType' are also thrown directly by the Haskell-side buffer
-- checks in this module, with a descriptive message.
data Error
  = ErrorStdError String
  | ErrorUnknownError String
  | ErrorInvalidFilename String
  | ErrorONNXParseError String
  | ErrorInvalidDType String
  | ErrorInvalidAttributeType String
  | ErrorUnsupportedOperatorAttribute String
  | ErrorDimensionMismatch String
  | ErrorVariableNotFound String
  | ErrorIndexOutOfRange String
  | ErrorJSONParseError String
  | ErrorInvalidBackendName String
  | ErrorUnsupportedOperator String
  | ErrorFailedToConfigureOperator String
  | ErrorBackendError String
  | ErrorSameNamedVariableAlreadyExist String
  deriving (Eq, Ord, Show, Read, Typeable)

-- | Lets 'Error' be thrown with 'throwIO' and caught with 'catch'.
instance Exception Error
-- | Run a raw Menoh C API call and turn a non-success status code into a
-- Haskell 'Error' exception.
--
-- On failure the message from @menoh_get_last_error_message@ is attached
-- to the exception. Known error codes map to their specific constructor
-- via 'menohErrorTable'; any other code becomes 'ErrorUnknownError' with
-- the numeric code appended.
--
-- The call and the message lookup both run inside 'runInBoundThread''.
-- NOTE(review): presumably this keeps them on the same OS thread because
-- the C library stores the last error per thread — confirm against the
-- Menoh C API.
runMenoh :: IO Base.MenohErrorCode -> IO ()
runMenoh m = runInBoundThread' $ do
  e <- m
  unless (e == Base.menohErrorCodeSuccess) $ do
    s <- peekCString =<< Base.menoh_get_last_error_message
    throwIO $ case IntMap.lookup (fromIntegral e) menohErrorTable of
      Just ex -> ex s
      Nothing ->
        -- Separate the C message from the code suffix with a space.
        ErrorUnknownError $
          s ++ " (error code: " ++ show (fromIntegral e :: Int) ++ ")"

-- | Mapping from Menoh C error codes to the matching 'Error' constructor.
--
-- Defined at top level so the map is built once, not on every failing
-- call.
menohErrorTable :: IntMap (String -> Error)
menohErrorTable = IntMap.fromList
  [ (fromIntegral code, ctor)
  | (code, ctor) <-
      [ (Base.menohErrorCodeStdError                     , ErrorStdError)
      , (Base.menohErrorCodeUnknownError                 , ErrorUnknownError)
      , (Base.menohErrorCodeInvalidFilename              , ErrorInvalidFilename)
      , (Base.menohErrorCodeOnnxParseError               , ErrorONNXParseError)
      , (Base.menohErrorCodeInvalidDtype                 , ErrorInvalidDType)
      , (Base.menohErrorCodeInvalidAttributeType         , ErrorInvalidAttributeType)
      , (Base.menohErrorCodeUnsupportedOperatorAttribute , ErrorUnsupportedOperatorAttribute)
      , (Base.menohErrorCodeDimensionMismatch            , ErrorDimensionMismatch)
      , (Base.menohErrorCodeVariableNotFound             , ErrorVariableNotFound)
      , (Base.menohErrorCodeIndexOutOfRange              , ErrorIndexOutOfRange)
      , (Base.menohErrorCodeJsonParseError               , ErrorJSONParseError)
      , (Base.menohErrorCodeInvalidBackendName           , ErrorInvalidBackendName)
      , (Base.menohErrorCodeUnsupportedOperator          , ErrorUnsupportedOperator)
      , (Base.menohErrorCodeFailedToConfigureOperator    , ErrorFailedToConfigureOperator)
      , (Base.menohErrorCodeBackendError                 , ErrorBackendError)
      , (Base.menohErrorCodeSameNamedVariableAlreadyExist, ErrorSameNamedVariableAlreadyExist)
      ]
  ]
-- | Run an action in a bound OS thread when the RTS supports bound
-- threads ('rtsSupportsBoundThreads'); otherwise run it directly in the
-- current thread.
runInBoundThread' :: IO a -> IO a
runInBoundThread' act =
  if rtsSupportsBoundThreads
    then runInBoundThread act
    else act
-- | Element type of a Menoh variable.
--
-- Only float is given its own constructor; any other raw dtype code coming
-- from the C API is kept verbatim in 'DTypeUnknown'.
data DType
  = DTypeFloat
  | DTypeUnknown !Base.MenohDType
  deriving (Eq, Ord, Show, Read)

-- | Converts between 'DType' and the integer dtype code used by the C API.
instance Enum DType where
  toEnum code =
    if code == fromIntegral Base.menohDtypeFloat
      then DTypeFloat
      else DTypeUnknown (fromIntegral code)
  fromEnum dt = case dt of
    DTypeFloat       -> fromIntegral Base.menohDtypeFloat
    DTypeUnknown raw -> fromIntegral raw
-- | Size in bytes of one element of the given 'DType'.
--
-- Calls 'error' for 'DTypeUnknown', because the element size of an
-- unrecognised dtype is not known to this binding.
dtypeSize :: DType -> Int
dtypeSize dt = case dt of
  DTypeFloat     -> sizeOf (undefined :: CFloat)
  DTypeUnknown _ -> error "Menoh.dtypeSize: unknown DType"
{-# DEPRECATED HasDType "use FromBuffer/ToBuffer instead" #-}
-- | Associates a 'Storable' element type with its Menoh 'DType'.
class Storable a => HasDType a where
  -- | The 'DType' for @a@; the proxy only fixes the type.
  dtypeOf :: Proxy a -> DType

instance HasDType CFloat where
  dtypeOf = const DTypeFloat

#if SIZEOF_HSFLOAT == SIZEOF_FLOAT
-- Haskell 'Float' maps to Menoh float only when it has the same size as
-- C float on this platform.
instance HasDType Float where
  dtypeOf = const DTypeFloat
#endif
-- | Shape of a variable as a list of dimension sizes, outermost first
-- (e.g. @[num, channel, height, width]@ for a 4-D input).
type Dims = [Int]
-- | Handle to Menoh model data (e.g. loaded with 'makeModelDataFromONNX').
-- The underlying C object is released by a 'ForeignPtr' finalizer.
newtype ModelData = ModelData (ForeignPtr Base.MenohModelData)
-- | Load model data from an ONNX file.
--
-- Throws an 'Error' if the C call fails. The returned handle is freed by
-- @menoh_delete_model_data@ when it is garbage-collected.
makeModelDataFromONNX :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNX fpath = liftIO $
  withCString fpath $ \cpath ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_data_from_onnx cpath out)
      raw <- peek out
      ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr raw
-- | Run Menoh's model-data optimisation (@menoh_model_data_optimize@)
-- using the given 'VariableProfileTable'. Throws an 'Error' on failure.
optimizeModelData :: MonadIO m => ModelData -> VariableProfileTable -> m ()
optimizeModelData (ModelData md) (VariableProfileTable table) = liftIO $
  withForeignPtr md $ \mdPtr ->
    withForeignPtr table $ \tablePtr ->
      runMenoh (Base.menoh_model_data_optimize mdPtr tablePtr)
-- | Mutable builder used to declare input/output variables before calling
-- 'buildVariableProfileTable'. The C object is released by a finalizer.
newtype VariableProfileTableBuilder
  = VariableProfileTableBuilder (ForeignPtr Base.MenohVariableProfileTableBuilder)
-- | Create an empty 'VariableProfileTableBuilder'.
--
-- Throws an 'Error' if the C call fails.
makeVariableProfileTableBuilder :: MonadIO m => m VariableProfileTableBuilder
makeVariableProfileTableBuilder = liftIO $ alloca $ \out -> do
  runMenoh (Base.menoh_make_variable_profile_table_builder out)
  raw <- peek out
  fmap VariableProfileTableBuilder
    (newForeignPtr Base.menoh_delete_variable_profile_table_builder_funptr raw)
-- | Register an input variable, dispatching on the rank of @dims@.
--
-- Rank 2 goes to 'addInputProfileDims2' and rank 4 to
-- 'addInputProfileDims4'. Any other rank throws 'ErrorDimensionMismatch'.
addInputProfileDims :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> Dims -> m ()
addInputProfileDims vpt name dtype dims
  | [num, size] <- dims =
      addInputProfileDims2 vpt name dtype (num, size)
  | [num, channel, height, width] <- dims =
      addInputProfileDims4 vpt name dtype (num, channel, height, width)
  | otherwise =
      liftIO $ throwIO $ ErrorDimensionMismatch $
        "Menoh.addInputProfileDims: cannot handle dims of length " ++ show (length dims)
-- | Register a 2-D input variable @(num, size)@ on the builder.
--
-- Throws an 'Error' if Menoh rejects it.
addInputProfileDims2
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int)
  -> m ()
addInputProfileDims2 (VariableProfileTableBuilder vpt) name dtype (num, size) = liftIO $
  withForeignPtr vpt $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_2
          builderPtr
          cname
          ((fromIntegral . fromEnum) dtype)
          (fromIntegral num)
          (fromIntegral size)
-- | Register a 4-D input variable @(num, channel, height, width)@ on the
-- builder.
--
-- Throws an 'Error' if Menoh rejects it.
addInputProfileDims4
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int, Int, Int)
  -> m ()
addInputProfileDims4 (VariableProfileTableBuilder vpt) name dtype (num, channel, height, width) = liftIO $
  withForeignPtr vpt $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_4
          builderPtr
          cname
          ((fromIntegral . fromEnum) dtype)
          (fromIntegral num)
          (fromIntegral channel)
          (fromIntegral height)
          (fromIntegral width)
-- | Declare a required output variable and its 'DType' on the builder.
--
-- Throws an 'Error' if Menoh rejects it.
addOutputProfile :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> m ()
addOutputProfile (VariableProfileTableBuilder vpt) name dtype = liftIO $
  withForeignPtr vpt $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_output_profile
          builderPtr
          cname
          ((fromIntegral . fromEnum) dtype)
-- | Build a 'VariableProfileTable' from a builder and model data.
--
-- Throws an 'Error' on failure. The result is freed by a finalizer.
buildVariableProfileTable
  :: MonadIO m
  => VariableProfileTableBuilder
  -> ModelData
  -> m VariableProfileTable
buildVariableProfileTable (VariableProfileTableBuilder b) (ModelData m) = liftIO $
  withForeignPtr b $ \builderPtr ->
    withForeignPtr m $ \modelDataPtr ->
      alloca $ \out -> do
        runMenoh (Base.menoh_build_variable_profile_table builderPtr modelDataPtr out)
        raw <- peek out
        VariableProfileTable <$> newForeignPtr Base.menoh_delete_variable_profile_table_funptr raw
-- | Resolved table of variable dtypes and dims, produced by
-- 'buildVariableProfileTable'. The C object is released by a finalizer.
newtype VariableProfileTable
  = VariableProfileTable (ForeignPtr Base.MenohVariableProfileTable)
-- | Convenience wrapper that builds a 'VariableProfileTable' in one step.
--
-- It takes a list of inputs @(name, dtype, dims)@, a list of required
-- outputs @(name, dtype)@, and the model data. All steps run inside a
-- single 'runInBoundThread''. Throws an 'Error' on any failure.
makeVariableProfileTable
  :: MonadIO m
  => [(String, DType, Dims)]
  -> [(String, DType)]
  -> ModelData
  -> m VariableProfileTable
makeVariableProfileTable input_name_and_dims_pair_list required_output_name_list model_data =
  liftIO $ runInBoundThread' $ do
    builder <- makeVariableProfileTableBuilder
    mapM_ (\(name, dtype, dims) -> addInputProfileDims builder name dtype dims)
          input_name_and_dims_pair_list
    mapM_ (uncurry (addOutputProfile builder)) required_output_name_list
    buildVariableProfileTable builder model_data
-- | Look up the 'DType' of a named variable in a 'VariableProfileTable'.
--
-- Throws an 'Error' if the C call fails (e.g. unknown variable name).
vptGetDType :: MonadIO m => VariableProfileTable -> String -> m DType
vptGetDType (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- Query the variable's dtype. Calling the dims-size function here
    -- would return the rank, not the dtype.
    runMenoh $ Base.menoh_variable_profile_table_get_dtype vpt' name' ret
    (toEnum . fromIntegral) <$> peek ret
-- | Look up the 'Dims' of a named variable in a 'VariableProfileTable'.
--
-- It reads the rank first, then each dimension in order. Throws an
-- 'Error' if any C call fails.
vptGetDims :: MonadIO m => VariableProfileTable -> String -> m Dims
vptGetDims (VariableProfileTable vpt) name = liftIO $ runInBoundThread' $
  withForeignPtr vpt $ \tablePtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        runMenoh (Base.menoh_variable_profile_table_get_dims_size tablePtr cname out)
        rank <- peek out
        let dimAt i = do
              runMenoh (Base.menoh_variable_profile_table_get_dims_at tablePtr cname (fromIntegral i) out)
              fromIntegral <$> peek out
        mapM dimAt [0 .. rank - 1]
-- | Builder that turns a 'VariableProfileTable' plus 'ModelData' into a
-- 'Model'. The C object is released by a finalizer (see 'makeModelBuilder').
newtype ModelBuilder = ModelBuilder (ForeignPtr Base.MenohModelBuilder)
-- | Create a 'ModelBuilder' for the given 'VariableProfileTable'.
--
-- Throws an 'Error' on failure. The result is freed by a finalizer.
makeModelBuilder :: MonadIO m => VariableProfileTable -> m ModelBuilder
makeModelBuilder (VariableProfileTable vpt) = liftIO $
  withForeignPtr vpt $ \tablePtr ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_builder tablePtr out)
      raw <- peek out
      ModelBuilder <$> newForeignPtr Base.menoh_delete_model_builder_funptr raw
-- | Attach caller-provided memory as the buffer of the named variable.
--
-- NOTE(review): presumably Menoh uses this memory in place rather than
-- copying it, so the caller must keep @buf@ valid for the resulting
-- model's lifetime. Confirm against the C API docs.
-- Throws an 'Error' on failure.
attachExternalBuffer :: MonadIO m => ModelBuilder -> String -> Ptr a -> m ()
attachExternalBuffer (ModelBuilder m) name buf = liftIO $
  withCString name $ \cname ->
    withForeignPtr m $ \builderPtr ->
      runMenoh (Base.menoh_model_builder_attach_external_buffer builderPtr cname buf)
-- | Build a 'Model' for the named backend with an empty backend
-- configuration string.
buildModel
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String
  -> m Model
buildModel builder m backend =
  liftIO (withCString "" (buildModelWithConfigString builder m backend))
-- | Build a 'Model' for the named backend.
--
-- The backend configuration is any JSON-encodable value; it is passed to
-- Menoh as a JSON string.
buildModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => ModelBuilder
  -> ModelData
  -> String
  -> a
  -> m Model
buildModelWithConfig builder m backend backend_config = liftIO $ do
  let configJson = BL.toStrict (J.encode backend_config)
  BS.useAsCString configJson (buildModelWithConfigString builder m backend)
-- | Build a 'Model' from a builder, model data, backend name and a raw
-- C-string backend configuration.
--
-- Throws an 'Error' on failure. The result is freed by a finalizer.
buildModelWithConfigString
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String
  -> CString
  -> m Model
buildModelWithConfigString (ModelBuilder builder) (ModelData m) backend backend_config = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withForeignPtr m $ \modelDataPtr ->
      withCString backend $ \cbackend ->
        alloca $ \out -> do
          runMenoh (Base.menoh_build_model builderPtr modelDataPtr cbackend backend_config out)
          raw <- peek out
          Model <$> newForeignPtr Base.menoh_delete_model_funptr raw
-- | A built model that can be 'run'. The C object is released by a
-- finalizer (see 'buildModelWithConfigString').
newtype Model = Model (ForeignPtr Base.MenohModel)
-- | Execute the model once (@menoh_model_run@). Throws an 'Error' on
-- failure.
run :: MonadIO m => Model -> m ()
run (Model model) =
  liftIO (withForeignPtr model (runMenoh . Base.menoh_model_run))
-- | Look up the 'DType' of a named variable in a 'Model'.
--
-- Throws an 'Error' if the C call fails.
getDType :: MonadIO m => Model -> String -> m DType
getDType (Model m) name = liftIO $
  withForeignPtr m $ \modelPtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dtype modelPtr cname out)
        toEnum . fromIntegral <$> peek out
-- | Look up the 'Dims' of a named variable in a 'Model'.
--
-- It reads the rank first, then each dimension in order. Throws an
-- 'Error' if any C call fails.
getDims :: MonadIO m => Model -> String -> m Dims
getDims (Model m) name = liftIO $ runInBoundThread' $
  withForeignPtr m $ \modelPtr ->
    withCString name $ \cname ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dims_size modelPtr cname out)
        rank <- peek out
        let dimAt i = do
              runMenoh (Base.menoh_model_get_variable_dims_at modelPtr cname (fromIntegral i) out)
              fromIntegral <$> peek out
        mapM dimAt [0 .. rank - 1]
-- | Return the raw pointer to a named variable's buffer.
--
-- Unsafe: nothing keeps the 'Model' alive once this returns, so the
-- pointer may dangle. Prefer 'withBuffer'. Throws an 'Error' on failure.
unsafeGetBuffer :: MonadIO m => Model -> String -> m (Ptr a)
unsafeGetBuffer (Model m) name = liftIO $
  withForeignPtr m $ \modelPtr ->
    withCString name $ \cname ->
      alloca $ \out ->
        runMenoh (Base.menoh_model_get_variable_buffer_handle modelPtr cname out)
          >> peek out
-- | Run an action with the raw pointer to a named variable's buffer.
--
-- The 'Model' is kept alive (via 'withForeignPtr') for the duration of the
-- callback. The pointer is only guaranteed valid while the callback runs.
-- Works in any monad that can lift and restore 'IO' control
-- ('MonadBaseControl'). Throws an 'Error' if the buffer lookup fails.
withBuffer :: forall m r a. (MonadIO m, MonadBaseControl IO m) => Model -> String -> (Ptr a -> m r) -> m r
withBuffer (Model m) name f =
  liftBaseOp (withForeignPtr m) $ \m' ->
  -- Explicit annotation (uses ScopedTypeVariables' m and r); presumably
  -- needed to resolve an otherwise ambiguous StM type for liftBaseOp.
  (liftBaseOp (withCString name) :: (CString -> m r) -> m r) $ \name' ->
  liftBaseOp alloca $ \ret -> do
    p <- liftIO $ do
      runMenoh $ Base.menoh_model_get_variable_buffer_handle m' name' ret
      peek ret
    f p
-- | Values that can be written into a Menoh variable buffer.
--
-- @basicWriteBuffer dtype dims ptr x@ writes @x@ to @ptr@, given the
-- variable's 'DType' and 'Dims' as reported by the model. The instances in
-- this module throw 'ErrorInvalidDType' / 'ErrorDimensionMismatch' when @x@
-- does not match.
class ToBuffer a where
  basicWriteBuffer :: DType -> Dims -> Ptr () -> a -> IO ()

-- | Values that can be read out of a Menoh variable buffer, given the
-- variable's 'DType' and 'Dims'.
class FromBuffer a where
  basicReadBuffer :: DType -> Dims -> Ptr () -> IO a
-- | Read a named variable out of the model using its 'FromBuffer' instance.
readBuffer :: (FromBuffer a, MonadIO m) => Model -> String -> m a
readBuffer model name = liftIO $ withBuffer model name $ \buf ->
  getDType model name >>= \dtype ->
  getDims model name >>= \dims ->
  basicReadBuffer dtype dims buf
-- | Write a value into a named variable of the model using its 'ToBuffer'
-- instance.
writeBuffer :: (ToBuffer a, MonadIO m) => Model -> String -> a -> m ()
writeBuffer model name a = liftIO $ withBuffer model name $ \buf ->
  getDType model name >>= \dtype ->
  getDims model name >>= \dims ->
  basicWriteBuffer dtype dims buf a
-- | Write a generic vector element-by-element into a buffer.
--
-- @dtype0@ is the element dtype of the vector and @dtype@ is the
-- variable's dtype. The dtypes must match and the vector length must equal
-- @product dims@; otherwise 'ErrorInvalidDType' / 'ErrorDimensionMismatch'
-- is thrown.
basicWriteBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> v a -> IO ()
basicWriteBufferGenericVectorStorable dtype0 dtype dims p vec = do
  let count = product dims
      dst = castPtr p :: Ptr a
  checkDTypeAndSize "Menoh.basicWriteBufferGenericVectorStorable" (dtype, count) (dtype0, VG.length vec)
  -- The length check above guarantees exactly `count` elements are poked.
  zipWithM_ (pokeElemOff dst) [0 .. count - 1] (VG.toList vec)
-- | Read @product dims@ elements from a buffer into a generic vector.
--
-- @dtype0@ is the requested element dtype and @dtype@ is the variable's
-- dtype. Throws 'ErrorInvalidDType' if they differ.
basicReadBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (v a)
basicReadBufferGenericVectorStorable dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferGenericVectorStorable" dtype dtype0
  VG.generateM (product dims) (peekElemOff (castPtr p :: Ptr a))
-- | Copy a storable vector into a buffer in one block copy.
--
-- The dtypes must match and the vector length must equal @product dims@;
-- otherwise 'ErrorInvalidDType' / 'ErrorDimensionMismatch' is thrown.
basicWriteBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> VS.Vector a -> IO ()
basicWriteBufferStorableVector dtype0 dtype dims p vec = do
  let count = product dims
  checkDTypeAndSize "Menoh.basicWriteBufferStorableVector" (dtype, count) (dtype0, VG.length vec)
  VS.unsafeWith vec (\src -> copyArray (castPtr p) src count)
-- | Copy @product dims@ elements from a buffer into a new storable vector.
--
-- Throws 'ErrorInvalidDType' if the requested and actual dtypes differ.
basicReadBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (VS.Vector a)
basicReadBufferStorableVector dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferStorableVector" dtype dtype0
  let count = product dims
  VSM.new count >>= \mvec ->
    VSM.unsafeWith mvec (\dst -> copyArray dst (castPtr p) count)
      >> VS.unsafeFreeze mvec
-- Boxed, unboxed and storable 'Float' vectors are read/written as
-- 'DTypeFloat'. The variable's dtype must match, and for writes the length
-- must equal the product of its dims.

instance ToBuffer (V.Vector Float) where
  basicWriteBuffer dtype dims ptr vec =
    basicWriteBufferGenericVectorStorable DTypeFloat dtype dims ptr vec

instance FromBuffer (V.Vector Float) where
  basicReadBuffer dtype dims ptr =
    basicReadBufferGenericVectorStorable DTypeFloat dtype dims ptr

instance ToBuffer (VU.Vector Float) where
  basicWriteBuffer dtype dims ptr vec =
    basicWriteBufferGenericVectorStorable DTypeFloat dtype dims ptr vec

instance FromBuffer (VU.Vector Float) where
  basicReadBuffer dtype dims ptr =
    basicReadBufferGenericVectorStorable DTypeFloat dtype dims ptr

-- Storable vectors use the block-copy helpers.
instance ToBuffer (VS.Vector Float) where
  basicWriteBuffer dtype dims ptr vec =
    basicWriteBufferStorableVector DTypeFloat dtype dims ptr vec

instance FromBuffer (VS.Vector Float) where
  basicReadBuffer dtype dims ptr =
    basicReadBufferStorableVector DTypeFloat dtype dims ptr
-- | A list covers the outermost dimension. Each element is written
-- at a byte offset of (product of inner dims * element size). Throws
-- 'ErrorDimensionMismatch' for empty dims or if the list length differs
-- from the outer dimension.
instance ToBuffer a => ToBuffer [a] where
  basicWriteBuffer _dtype [] _p _xs =
    throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: empty dims"
  basicWriteBuffer dtype (dim : dims) p xs = do
    when (dim /= length xs) $
      throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: dimension mismatch"
    let stride = product dims * dtypeSize dtype
    zipWithM_
      (\offset x -> basicWriteBuffer dtype dims (p `plusPtr` offset) x)
      [0, stride ..]
      xs
-- | Reads the outermost dimension as a list of @dim@ elements. Each element
-- is read from a byte offset of (product of inner dims * element size).
-- Throws 'ErrorDimensionMismatch' for empty dims.
instance FromBuffer a => FromBuffer [a] where
  basicReadBuffer _dtype [] _p =
    throwIO $ ErrorDimensionMismatch $ "FromBuffer{[a]}.basicReadBuffer: empty dims"
  basicReadBuffer dtype (dim : dims) p =
    let stride = product dims * dtypeSize dtype
        rowAt i = basicReadBuffer dtype dims (p `plusPtr` (i * stride))
    in mapM rowAt [0 .. dim - 1]
-- | Throw 'ErrorInvalidDType' (prefixed with the caller's name) when the
-- two dtypes differ; otherwise do nothing.
checkDType :: String -> DType -> DType -> IO ()
checkDType name dtype1 dtype2 =
  when (dtype1 /= dtype2) $
    throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
-- | Check that two (dtype, element count) pairs agree.
--
-- A dtype mismatch throws 'ErrorInvalidDType' and takes precedence. A
-- count mismatch throws 'ErrorDimensionMismatch'. Both messages are
-- prefixed with the caller's name.
checkDTypeAndSize :: String -> (DType,Int) -> (DType,Int) -> IO ()
checkDTypeAndSize name (dtype1, n1) (dtype2, n2) =
  if dtype1 /= dtype2
    then throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
    else when (n1 /= n2) $
      throwIO $ ErrorDimensionMismatch $ name ++ ": dimension mismatch"
{-# DEPRECATED writeBufferFromVector, writeBufferFromStorableVector "Use ToBuffer class and writeBuffer instead" #-}

-- | Copy a generic vector into the named variable. The element dtype
-- comes from 'HasDType' and is checked against the model's dtype; the
-- length is checked against its dims.
writeBufferFromVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> v a -> m ()
writeBufferFromVector model name vec = liftIO $ withBuffer model name $ \buf -> do
  let expected = dtypeOf (Proxy :: Proxy a)
  actual <- getDType model name
  shape <- getDims model name
  basicWriteBufferGenericVectorStorable expected actual shape buf vec

-- | Storable-vector counterpart of 'writeBufferFromVector' (block copy).
writeBufferFromStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> VS.Vector a -> m ()
writeBufferFromStorableVector model name vec = liftIO $ withBuffer model name $ \buf -> do
  let expected = dtypeOf (Proxy :: Proxy a)
  actual <- getDType model name
  shape <- getDims model name
  basicWriteBufferStorableVector expected actual shape buf vec
{-# DEPRECATED readBufferToVector, readBufferToStorableVector "Use FromBuffer class and readBuffer instead" #-}

-- | Read the named variable into a generic vector. Throws
-- 'ErrorInvalidDType' if the requested element dtype (from 'HasDType')
-- differs from the model's.
readBufferToVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> m (v a)
readBufferToVector model name = liftIO $ withBuffer model name $ \buf -> do
  let expected = dtypeOf (Proxy :: Proxy a)
  actual <- getDType model name
  shape <- getDims model name
  basicReadBufferGenericVectorStorable expected actual shape buf

-- | Storable-vector counterpart of 'readBufferToVector' (block copy).
readBufferToStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> m (VS.Vector a)
readBufferToStorableVector model name = liftIO $ withBuffer model name $ \buf -> do
  let expected = dtypeOf (Proxy :: Proxy a)
  actual <- getDType model name
  shape <- getDims model name
  basicReadBufferStorableVector expected actual shape buf
-- | Convenience wrapper: create a 'ModelBuilder' for the table and build a
-- 'Model' for the named backend with an empty configuration.
makeModel
  :: MonadIO m
  => VariableProfileTable
  -> ModelData
  -> String
  -> m Model
makeModel vpt model_data backend_name = liftIO $
  makeModelBuilder vpt >>= \builder ->
    buildModel builder model_data backend_name
-- | Like 'makeModel', but passes a JSON-encodable backend configuration.
makeModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => VariableProfileTable
  -> ModelData
  -> String
  -> a
  -> m Model
makeModelWithConfig vpt model_data backend_name backend_config = liftIO $
  makeModelBuilder vpt >>= \builder ->
    buildModelWithConfig builder model_data backend_name backend_config
-- | Version of the Menoh C library, built from the major/minor/patch
-- constants exposed by "Menoh.Base".
version :: Version
#if MIN_VERSION_base(4,8,0)
version = makeVersion [Base.menoh_major_version, Base.menoh_minor_version, Base.menoh_patch_version]
#else
-- 'makeVersion' is not available before base 4.8, so build the record
-- directly with no tags.
version = Version [Base.menoh_major_version, Base.menoh_minor_version, Base.menoh_patch_version] []
#endif
-- | Version of this Haskell binding package, taken from the Cabal-generated
-- @Paths_menoh@ module.
bindingVersion :: Version
bindingVersion = Paths_menoh.version