{-# OPTIONS_GHC -Wall #-} {-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} ----------------------------------------------------------------------------- -- | -- Module : Menoh -- Copyright : Copyright (c) 2018 Preferred Networks, Inc. -- License : MIT (see the file LICENSE) -- -- Maintainer : Masahiro Sakai -- Stability : experimental -- Portability : non-portable -- -- Haskell binding for /Menoh/ DNN inference library. -- -- = Basic usage -- -- 1. Load computation graph from ONNX file using 'makeModelDataFromONNX'. -- -- 2. Specify input variable type/dimentions (in particular batch size) and -- which output variables you want to retrieve. These information is -- represented as 'VariableProfileTable'. -- Simple way to construct 'VariableProfileTable' is to use 'makeVariableProfileTable'. -- -- 3. Optimize 'ModelData' with respect to your 'VariableProfileTable' by using -- 'optimizeModelData'. -- -- 4. Construct a 'Model' using 'makeModel' or 'makeModelWithConfig'. -- If you want to use custom buffers instead of internally allocated ones, -- You need to use more low level 'ModelBuilder'. -- -- 5. Load input data. This can be done conveniently using 'writeBufferFromVector' -- or 'writeBufferFromStorableVector'. There are also more low-level API such as -- 'unsafeGetBuffer' and 'withBuffer'. -- -- 6. Run inference using 'run'. -- -- 7. Retrieve the result data. This can be done conveniently using 'readBufferToVector' -- or 'readBufferToStorableVector'. -- ----------------------------------------------------------------------------- module Menoh ( -- * Basic data types Dims , DType (..) , HasDType (..) , Error (..) -- * ModelData type , ModelData (..) , makeModelDataFromONNX , optimizeModelData -- * Model type , Model (..) 
, makeModel , makeModelWithConfig , run , getDType , getDims , unsafeGetBuffer , withBuffer , writeBufferFromVector , writeBufferFromStorableVector , readBufferToVector , readBufferToStorableVector -- * Misc , version , bindingVersion -- * Low-level API -- ** VariableProfileTable , VariableProfileTable (..) , makeVariableProfileTable , vptGetDType , vptGetDims -- ** Builder for 'VariableProfileTable' , VariableProfileTableBuilder (..) , makeVariableProfileTableBuilder , addInputProfileDims2 , addInputProfileDims4 , addOutputProfile , buildVariableProfileTable -- ** Builder for 'Model' , ModelBuilder (..) , makeModelBuilder , attachExternalBuffer , buildModel , buildModelWithConfig ) where import Control.Concurrent import Control.Monad import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp) import Control.Monad.IO.Class import Control.Exception import qualified Data.Aeson as J import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import Data.Proxy import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import qualified Data.Vector.Generic as VG import Data.IntMap (IntMap) import qualified Data.IntMap as IntMap import Data.Version import Foreign import Foreign.C import qualified Menoh.Base as Base import qualified Paths_menoh #include "MachDeps.h" -- ------------------------------------------------------------------------ -- | Functions in this module can throw this exception type. 
data Error
  = ErrorStdError String
  | ErrorUnknownError String
  | ErrorInvalidFilename String
  | ErrorONNXParseError String
  | ErrorInvalidDType String
  | ErrorInvalidAttributeType String
  | ErrorUnsupportedOperatorAttribute String
  | ErrorDimensionMismatch String
  | ErrorVariableNotFound String
  | ErrorIndexOutOfRange String
  | ErrorJSONParseError String
  | ErrorInvalidBackendName String
  | ErrorUnsupportedOperator String
  | ErrorFailedToConfigureOperator String
  | ErrorBackendError String
  | ErrorSameNamedVariableAlreadyExist String
  deriving (Eq, Ord, Show, Read)

instance Exception Error

-- | Execute a Menoh C API call and convert a failing status code into an
-- 'Error' exception.
--
-- The call is run through 'runInBoundThread''. When the returned code differs
-- from 'Base.menohErrorCodeSuccess', the message obtained from
-- 'Base.menoh_get_last_error_message' is wrapped in the constructor registered
-- for that code; codes missing from the table are reported as
-- 'ErrorUnknownError' with the numeric code appended to the message.
runMenoh :: IO Base.MenohErrorCode -> IO ()
runMenoh call = runInBoundThread' $ do
  code <- call
  unless (code == Base.menohErrorCodeSuccess) $ do
    msg <- peekCString =<< Base.menoh_get_last_error_message
    throwIO $ case IntMap.lookup (fromIntegral code) table of
      Just mkError -> mkError msg
      Nothing ->
        ErrorUnknownError $
          msg ++ "(error code: " ++ show (fromIntegral code :: Int) ++ ")"
  where
    -- Numeric error code -> constructor of the matching 'Error' case.
    table :: IntMap (String -> Error)
    table =
      IntMap.fromList
        [ (fromIntegral errCode, ctor)
        | (errCode, ctor) <-
            [ (Base.menohErrorCodeStdError, ErrorStdError)
            , (Base.menohErrorCodeUnknownError, ErrorUnknownError)
            , (Base.menohErrorCodeInvalidFilename, ErrorInvalidFilename)
            , (Base.menohErrorCodeOnnxParseError, ErrorONNXParseError)
            , (Base.menohErrorCodeInvalidDtype, ErrorInvalidDType)
            , (Base.menohErrorCodeInvalidAttributeType, ErrorInvalidAttributeType)
            , (Base.menohErrorCodeUnsupportedOperatorAttribute, ErrorUnsupportedOperatorAttribute)
            , (Base.menohErrorCodeDimensionMismatch, ErrorDimensionMismatch)
            , (Base.menohErrorCodeVariableNotFound, ErrorVariableNotFound)
            , (Base.menohErrorCodeIndexOutOfRange, ErrorIndexOutOfRange)
            , (Base.menohErrorCodeJsonParseError, ErrorJSONParseError)
            , (Base.menohErrorCodeInvalidBackendName, ErrorInvalidBackendName)
            , (Base.menohErrorCodeUnsupportedOperator, ErrorUnsupportedOperator)
            , (Base.menohErrorCodeFailedToConfigureOperator, ErrorFailedToConfigureOperator)
            , (Base.menohErrorCodeBackendError, ErrorBackendError)
            , (Base.menohErrorCodeSameNamedVariableAlreadyExist, ErrorSameNamedVariableAlreadyExist)
            ]
        ]

-- | Like 'runInBoundThread', but simply runs the action in the current thread
-- when the runtime system does not support bound threads.
runInBoundThread' :: IO a -> IO a
runInBoundThread' =
  if rtsSupportsBoundThreads
    then runInBoundThread
    else id

-- ------------------------------------------------------------------------

-- | Data type of array elements
data DType
  = DTypeFloat
    -- ^ single precision floating point number
  | DTypeUnknown !Base.MenohDType
    -- ^ types that this binding is unaware of
  deriving (Eq, Ord, Show, Read)

instance Enum DType where
  toEnum code =
    if code == fromIntegral Base.menohDtypeFloat
      then DTypeFloat
      else DTypeUnknown (fromIntegral code)

  fromEnum dtype = case dtype of
    DTypeFloat -> fromIntegral Base.menohDtypeFloat
    DTypeUnknown raw -> fromIntegral raw

-- | Haskell types that have associated 'DType' type code.
class Storable a => HasDType a where
  dtypeOf :: Proxy a -> DType

instance HasDType CFloat where
  dtypeOf = const DTypeFloat

#if SIZEOF_HSFLOAT == SIZEOF_FLOAT

instance HasDType Float where
  dtypeOf = const DTypeFloat

#endif

-- ------------------------------------------------------------------------

-- | Dimensions of array
type Dims = [Int]

-- ------------------------------------------------------------------------

-- | @ModelData@ contains model parameters and computation graph structure.
newtype ModelData = ModelData (ForeignPtr Base.MenohModelData)

-- | Load onnx file and make 'ModelData'.
makeModelDataFromONNX :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNX onnxPath = liftIO $
  withCString onnxPath $ \cPath ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_data_from_onnx cPath out)
      raw <- peek out
      -- The finalizer releases the native handle once it becomes unreachable.
      ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr raw

-- | Optimize function for 'ModelData'.
--
-- This function modifies given 'ModelData'.
optimizeModelData :: MonadIO m => ModelData -> VariableProfileTable -> m ()
optimizeModelData (ModelData modelData) (VariableProfileTable table) = liftIO $
  withForeignPtr modelData $ \modelDataPtr ->
    withForeignPtr table $ \tablePtr ->
      runMenoh (Base.menoh_model_data_optimize modelDataPtr tablePtr)

-- ------------------------------------------------------------------------

-- | Builder for creation of 'VariableProfileTable'.
newtype VariableProfileTableBuilder
  = VariableProfileTableBuilder (ForeignPtr Base.MenohVariableProfileTableBuilder)

-- | Factory function for 'VariableProfileTableBuilder'.
makeVariableProfileTableBuilder :: MonadIO m => m VariableProfileTableBuilder
makeVariableProfileTableBuilder = liftIO $
  alloca $ \out -> do
    runMenoh (Base.menoh_make_variable_profile_table_builder out)
    raw <- peek out
    VariableProfileTableBuilder
      <$> newForeignPtr Base.menoh_delete_variable_profile_table_builder_funptr raw

-- | Add an input profile whose dims are given as a list.
--
-- Lists of length 2 are forwarded to 'addInputProfileDims2' and lists of
-- length 4 to 'addInputProfileDims4'; any other length throws
-- 'ErrorDimensionMismatch'.
addInputProfileDims :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> Dims -> m ()
addInputProfileDims builder name dtype dims
  | [num, size] <- dims =
      addInputProfileDims2 builder name dtype (num, size)
  | [num, channel, height, width] <- dims =
      addInputProfileDims4 builder name dtype (num, channel, height, width)
  | otherwise =
      liftIO . throwIO . ErrorDimensionMismatch $
        "Menoh.addInputProfileDims: cannot handle dims of length " ++ show (length dims)

-- | Add 2D input profile.
--
-- Input profile contains name, dtype and dims @(num, size)@.
-- This 2D input is conventional batched 1D inputs.
addInputProfileDims2
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int) -- ^ (num, size)
  -> m ()
addInputProfileDims2 (VariableProfileTableBuilder builder) name dtype (num, size) = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withCString name $ \cName ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_2
          builderPtr
          cName
          (fromIntegral (fromEnum dtype))
          (fromIntegral num)
          (fromIntegral size)

-- | Add 4D input profile
--
-- Input profile contains name, dtype and dims @(num, channel, height, width)@.
-- This 4D input is conventional batched image inputs. Image input is
-- 3D (channel, height, width).
addInputProfileDims4
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int, Int, Int) -- ^ (num, channel, height, width)
  -> m ()
addInputProfileDims4 (VariableProfileTableBuilder builder) name dtype (num, channel, height, width) = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withCString name $ \cName ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_4
          builderPtr
          cName
          (fromIntegral (fromEnum dtype))
          (fromIntegral num)
          (fromIntegral channel)
          (fromIntegral height)
          (fromIntegral width)

-- | Add output profile
--
-- Output profile contains name and dtype. Its 'Dims' are calculated automatically,
-- so that you don't need to specify explicitly.
addOutputProfile :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> m ()
addOutputProfile (VariableProfileTableBuilder builder) name dtype = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withCString name $ \cName ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_output_profile
          builderPtr
          cName
          (fromIntegral (fromEnum dtype))

-- | Factory function for 'VariableProfileTable'
buildVariableProfileTable
  :: MonadIO m
  => VariableProfileTableBuilder
  -> ModelData
  -> m VariableProfileTable
buildVariableProfileTable (VariableProfileTableBuilder builder) (ModelData modelData) = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withForeignPtr modelData $ \modelDataPtr ->
      alloca $ \out -> do
        runMenoh (Base.menoh_build_variable_profile_table builderPtr modelDataPtr out)
        raw <- peek out
        VariableProfileTable
          <$> newForeignPtr Base.menoh_delete_variable_profile_table_funptr raw

-- ------------------------------------------------------------------------

-- | @VariableProfileTable@ contains information of dtype and dims of variables.
--
-- Users can access to dtype and dims via 'vptGetDType' and 'vptGetDims'.
newtype VariableProfileTable
  = VariableProfileTable (ForeignPtr Base.MenohVariableProfileTable)

-- | Convenient function for constructing 'VariableProfileTable'.
--
-- It creates a fresh builder, registers every input and output profile in
-- list order, and then builds the table against the given model data.
-- If you need finer control, you can use 'VariableProfileTableBuilder'.
makeVariableProfileTable
  :: MonadIO m
  => [(String, DType, Dims)] -- ^ input names with dtypes and dims
  -> [(String, DType)]       -- ^ required output name list with dtypes
  -> ModelData               -- ^ model data
  -> m VariableProfileTable
makeVariableProfileTable inputs outputs model_data = liftIO $ do
  builder <- makeVariableProfileTableBuilder
  mapM_ (\(name, dtype, dims) -> addInputProfileDims builder name dtype dims) inputs
  mapM_ (\(name, dtype) -> addOutputProfile builder name dtype) outputs
  buildVariableProfileTable builder model_data

-- | Accessor function for 'VariableProfileTable'
--
-- Select variable name and get its 'DType'.
vptGetDType :: MonadIO m => VariableProfileTable -> String -> m DType
vptGetDType (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' ->
    withCString name $ \name' ->
      alloca $ \ret -> do
        -- Query the dtype code itself. Using the dims-size accessor here would
        -- report the number of dimensions of the variable as its 'DType'.
        runMenoh $ Base.menoh_variable_profile_table_get_dtype vpt' name' ret
        (toEnum . fromIntegral) <$> peek ret

-- | Accessor function for 'VariableProfileTable'
--
-- Select variable name and get its 'Dims'.
--
-- The number of dimensions is queried first, then each dimension is fetched
-- by index. Failures reported by the underlying calls are thrown as 'Error'.
vptGetDims :: MonadIO m => VariableProfileTable -> String -> m Dims
vptGetDims (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' ->
    withCString name $ \name' ->
      alloca $ \ret -> do
        runMenoh $ Base.menoh_variable_profile_table_get_dims_size vpt' name' ret
        size <- peek ret
        -- The same out-parameter cell is reused for every per-index query.
        forM [0 .. size - 1] $ \i -> do
          runMenoh $ Base.menoh_variable_profile_table_get_dims_at vpt' name' (fromIntegral i) ret
          fromIntegral <$> peek ret

-- ------------------------------------------------------------------------

-- | Helper for creating of 'Model'.
newtype ModelBuilder = ModelBuilder (ForeignPtr Base.MenohModelBuilder)

-- | Factory function for 'ModelBuilder'
makeModelBuilder :: MonadIO m => VariableProfileTable -> m ModelBuilder
makeModelBuilder (VariableProfileTable vpt) = liftIO $
  withForeignPtr vpt $ \vpt' ->
    alloca $ \ret -> do
      runMenoh $ Base.menoh_make_model_builder vpt' ret
      ModelBuilder <$> (newForeignPtr Base.menoh_delete_model_builder_funptr =<< peek ret)

-- | Attach a buffer which is allocated by users.
--
-- Users can attach an external buffer which they allocated to target variable.
--
-- Variables attached no external buffer are attached internal buffers allocated
-- automatically.
--
-- Users can get that internal buffer handle by calling 'unsafeGetBuffer' etc. later.
attachExternalBuffer :: MonadIO m => ModelBuilder -> String -> Ptr a -> m ()
attachExternalBuffer (ModelBuilder builder) name buf = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withCString name $ \cName ->
      runMenoh (Base.menoh_model_builder_attach_external_buffer builderPtr cName buf)

-- | Factory function for 'Model'.
--
-- The backend configuration passed to the library is the empty string.
buildModel
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> m Model
buildModel builder m backend = liftIO $
  withCString "" $ \emptyConfig ->
    buildModelWithConfigString builder m backend emptyConfig

-- | Similar to 'buildModel', but backend specific configuration can be supplied as JSON.
buildModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> a      -- ^ backend config
  -> m Model
buildModelWithConfig builder m backend backend_config = liftIO $
  BS.useAsCString encodedConfig $ \cConfig ->
    buildModelWithConfigString builder m backend cConfig
  where
    -- JSON encoding of the configuration, as a strict byte string.
    encodedConfig = BL.toStrict (J.encode backend_config)

-- | Shared implementation of 'buildModel' and 'buildModelWithConfig' that
-- takes the backend configuration as an already marshalled C string.
buildModelWithConfigString
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String  -- ^ backend name
  -> CString -- ^ backend config
  -> m Model
buildModelWithConfigString (ModelBuilder builder) (ModelData modelData) backend backend_config = liftIO $
  withForeignPtr builder $ \builderPtr ->
    withForeignPtr modelData $ \modelDataPtr ->
      withCString backend $ \cBackend ->
        alloca $ \out -> do
          runMenoh (Base.menoh_build_model builderPtr modelDataPtr cBackend backend_config out)
          raw <- peek out
          Model <$> newForeignPtr Base.menoh_delete_model_funptr raw

-- ------------------------------------------------------------------------

-- | ONNX model with input/output buffers
newtype Model = Model (ForeignPtr Base.MenohModel)

-- | Run model inference.
--
-- This function can't be called asynchronously.
run :: MonadIO m => Model -> m ()
run (Model model) = liftIO $
  withForeignPtr model (runMenoh . Base.menoh_model_run)

-- | Get 'DType' of target variable.
getDType :: MonadIO m => Model -> String -> m DType
getDType (Model model) name = liftIO $
  withForeignPtr model $ \modelPtr ->
    withCString name $ \cName ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dtype modelPtr cName out)
        toEnum . fromIntegral <$> peek out

-- | Get 'Dims' of target variable.
getDims :: MonadIO m => Model -> String -> m Dims
getDims (Model model) name = liftIO $
  withForeignPtr model $ \modelPtr ->
    withCString name $ \cName ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dims_size modelPtr cName out)
        rank <- peek out
        -- The same out-parameter cell is reused for every per-axis query.
        forM [0 .. rank - 1] $ \axis -> do
          runMenoh (Base.menoh_model_get_variable_dims_at modelPtr cName (fromIntegral axis) out)
          fromIntegral <$> peek out

-- | Get a buffer handle attached to target variable.
--
-- Users can get a buffer handle attached to target variable.
-- If that buffer is allocated by users and attached to the variable by calling
-- 'attachExternalBuffer', returned buffer handle is same to it.
--
-- This function is unsafe because it does not prevent the model to be GC'ed and
-- the returned pointer become dangling pointer.
--
-- See also 'withBuffer'.
unsafeGetBuffer :: MonadIO m => Model -> String -> m (Ptr a)
unsafeGetBuffer (Model model) name = liftIO $
  withForeignPtr model $ \modelPtr ->
    withCString name $ \cName ->
      alloca $ \out ->
        runMenoh (Base.menoh_model_get_variable_buffer_handle modelPtr cName out)
          >> peek out

-- | This function takes a function which is applied to the buffer associated to specified variable.
-- The resulting action is then executed. The buffer is kept alive at least during the whole action,
-- even if it is not used directly inside.
-- Note that it is not safe to return the pointer from the action and use it after the action completes.
--
-- See also 'unsafeGetBuffer'.
withBuffer
  :: forall m r a. (MonadIO m, MonadBaseControl IO m)
  => Model
  -> String
  -> (Ptr a -> m r)
  -> m r
withBuffer (Model model) name f =
  liftBaseOp (withForeignPtr model) $ \modelPtr ->
    (liftBaseOp (withCString name) :: (CString -> m r) -> m r) $ \cName ->
      liftBaseOp alloca $ \out -> do
        buf <- liftIO $
          runMenoh (Base.menoh_model_get_variable_buffer_handle modelPtr cName out)
            >> peek out
        f buf

-- | Throw 'ErrorInvalidDType' (message prefixed with the given caller name)
-- unless both dtypes are equal.
checkDType :: String -> DType -> DType -> IO ()
checkDType name dtype1 dtype2 =
  when (dtype1 /= dtype2) $
    throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"

-- | Like 'checkDType', and additionally throw 'ErrorDimensionMismatch' when
-- the element counts differ. The dtype check is performed first.
checkDTypeAndSize :: String -> (DType, Int) -> (DType, Int) -> IO ()
checkDTypeAndSize name (dtype1, n1) (dtype2, n2) = do
  when (dtype1 /= dtype2) $
    throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
  when (n1 /= n2) $
    throwIO $ ErrorDimensionMismatch $ name ++ ": dimension mismatch"

-- | Copy whole elements of 'VG.Vector' into a model's buffer
--
-- The element type and the vector length must match the dtype and the total
-- element count (product of dims) of the variable.
writeBufferFromVector
  :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m)
  => Model
  -> String
  -> v a
  -> m ()
writeBufferFromVector model name vec = liftIO $ withBuffer model name $ \buf -> do
  dtype <- getDType model name
  dims <- getDims model name
  let total = product dims
  checkDTypeAndSize
    "Menoh.writeBufferFromVector"
    (dtype, total)
    (dtypeOf (Proxy :: Proxy a), VG.length vec)
  mapM_ (\i -> pokeElemOff buf i (vec VG.! i)) [0 .. total - 1]

-- | Copy whole elements of @'VS.Vector' a@ into a model's buffer
--
-- The element type and the vector length must match the dtype and the total
-- element count (product of dims) of the variable.
writeBufferFromStorableVector
  :: forall a m. (HasDType a, MonadIO m)
  => Model
  -> String
  -> VS.Vector a
  -> m ()
writeBufferFromStorableVector model name vec = liftIO $ withBuffer model name $ \buf -> do
  dtype <- getDType model name
  dims <- getDims model name
  let total = product dims
  checkDTypeAndSize
    "Menoh.writeBufferFromStorableVector"
    (dtype, total)
    (dtypeOf (Proxy :: Proxy a), VG.length vec)
  VS.unsafeWith vec $ \src -> copyArray buf src total

-- | Read whole elements of a model's buffer and return them as a 'VG.Vector'.
--
-- The requested element type must match the dtype of the variable.
readBufferToVector
  :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m)
  => Model
  -> String
  -> m (v a)
readBufferToVector model name = liftIO $ withBuffer model name $ \buf -> do
  dtype <- getDType model name
  dims <- getDims model name
  checkDType "Menoh.readBufferToVector" dtype (dtypeOf (Proxy :: Proxy a))
  VG.generateM (product dims) (\i -> peekElemOff buf i)

-- | Read whole elements of a model's buffer and return them as a @'VS.Vector' a@.
--
-- The requested element type must match the dtype of the variable.
readBufferToStorableVector
  :: forall a m. (HasDType a, MonadIO m)
  => Model
  -> String
  -> m (VS.Vector a)
readBufferToStorableVector model name = liftIO $ withBuffer model name $ \buf -> do
  dtype <- getDType model name
  dims <- getDims model name
  checkDType "Menoh.readBufferToStorableVector" dtype (dtypeOf (Proxy :: Proxy a))
  let total = product dims
  -- Copy into a freshly allocated mutable vector, then freeze it in place.
  target <- VSM.new total
  VSM.unsafeWith target $ \dst -> copyArray dst buf total
  VS.unsafeFreeze target

-- | Convenient methods for constructing a 'Model'.
makeModel
  :: MonadIO m
  => VariableProfileTable -- ^ variable profile table
  -> ModelData            -- ^ model data
  -> String               -- ^ backend name
  -> m Model
makeModel vpt model_data backend_name = liftIO $
  makeModelBuilder vpt >>= \builder ->
    buildModel builder model_data backend_name

-- | Similar to 'makeModel' but backend-specific configuration can be supplied.
makeModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => VariableProfileTable -- ^ variable profile table
  -> ModelData            -- ^ model data
  -> String               -- ^ backend name
  -> a                    -- ^ backend config
  -> m Model
makeModelWithConfig vpt model_data backend_name backend_config = liftIO $
  makeModelBuilder vpt >>= \builder ->
    buildModelWithConfig builder model_data backend_name backend_config

-- ------------------------------------------------------------------------

-- | Menoh version which was supplied on compilation time via CPP macro.
version :: Version
version =
  makeVersion
    [ Base.menoh_major_version
    , Base.menoh_minor_version
    , Base.menoh_patch_version
    ]

-- | Version of this Haskell binding. (Not the version of /Menoh/ itself)
bindingVersion :: Version
bindingVersion = Paths_menoh.version