{-# OPTIONS_GHC -Wall #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ScopedTypeVariables #-} ----------------------------------------------------------------------------- -- | -- Module : Menoh -- Copyright : Copyright (c) 2018 Preferred Networks, Inc. -- License : MIT (see the file LICENSE) -- -- Maintainer : Masahiro Sakai -- Stability : experimental -- Portability : non-portable -- -- Haskell binding for /Menoh/ DNN inference library. -- -- = Basic usage -- -- 1. Load computation graph from ONNX file using 'makeModelDataFromONNXFile'. -- 2. Specify input variable type/dimentions (in particular batch size) and -- which output variables you want to retrieve. This can be done by -- constructing 'VariableProfileTable' using 'makeVariableProfileTable'. -- 3. Optimize 'ModelData' with respect to your 'VariableProfileTable' by using -- 'optimizeModelData'. -- 4. Construct a 'Model' using 'makeModel' or 'makeModelWithConfig'. -- If you want to use custom buffers instead of internally allocated ones, -- You need to use more low level 'ModelBuilder'. -- 5. Load input data. This can be done conveniently using 'writeBuffer'. -- There are also more low-level API such as 'unsafeGetBuffer' and 'withBuffer'. -- 6. Run inference using 'run'. -- 7. Retrieve the result data. This can be done conveniently using 'readBuffer'. -- -- = Note on thread safety -- -- TL;DR: If you want to use Menoh from multiple haskell threads, you need to -- use /threaded/ RTS by supplying @-threaded@ option to GHC. -- -- Menoh uses thread local storage (TLS) for storing error information, and -- the only way to use TLS safely is to use in /bound/ threads -- (see "Control.Concurrent#boundthreads"). -- -- * In /threaded RTS/ (i.e. 'rtsSupportsBoundThreads' is True), this module -- runs computation in bound threads by using 'runInBoundThread'. 
(If the -- calling thread is not bound, 'runInBoundThread' create a bound thread -- temporarily and run the computation inside it). -- -- * In /non-threaded RTS/, this module /does not/ use 'runInBoundThread' and -- is therefore unsafe to use from multiple haskell threads. Using non-threaded -- RTS is allowed for the sake of convenience (e.g. running in GHCi) despite -- its unsafety. -- ----------------------------------------------------------------------------- #include "MachDeps.h" #include #define MIN_VERSION_libmenoh(major,minor,patch) (\ (major) < MENOH_MAJOR_VERSION || \ (major) == MENOH_MAJOR_VERSION && (minor) < MENOH_MINOR_VERSION || \ (major) == MENOH_MAJOR_VERSION && (minor) == MENOH_MINOR_VERSION && (patch) <= MENOH_PATCH_VERSION) module Menoh ( -- * Basic data types Dims , DType (..) , Error (..) -- * ModelData type , ModelData (..) , makeModelDataFromONNXFile , makeModelDataFromONNX , makeModelDataFromONNXByteString , optimizeModelData -- ** Manual model data construction API , makeModelData , addParameterFromPtr , addNewNode , addInputNameToCurrentNode , addOutputNameToCurrentNode , AttributeType (..) , addAttribute -- * VariableProfileTable , VariableProfileTable (..) , makeVariableProfileTable , vptGetDType , vptGetDims -- * Model type , Model (..) , makeModel , makeModelWithConfig , run , getDType , getDims -- ** Accessors for buffers , ToBuffer (..) , FromBuffer (..) , writeBuffer , readBuffer -- ** Low-level accessors for buffers , unsafeGetBuffer , withBuffer -- ** Deprecated accessors for buffers , HasDType (..) , writeBufferFromVector , writeBufferFromStorableVector , readBufferToVector , readBufferToStorableVector -- * Misc , version , bindingVersion -- * Low-level API -- ** Builder for 'VariableProfileTable' , VariableProfileTableBuilder (..) , makeVariableProfileTableBuilder , addInputProfileDims2 , addInputProfileDims4 , addOutputName , addOutputProfile , AddOutput (..) 
, buildVariableProfileTable -- ** Builder for 'Model' , ModelBuilder (..) , makeModelBuilder , attachExternalBuffer , buildModel , buildModelWithConfig ) where import Control.Applicative import Control.Concurrent import Control.Monad import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp) import Control.Monad.IO.Class import Control.Exception import qualified Data.Aeson as J import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import Data.Proxy import Data.Typeable import qualified Data.Vector as V import qualified Data.Vector.Generic as VG import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import qualified Data.Vector.Unboxed as VU import Data.IntMap (IntMap) import qualified Data.IntMap as IntMap import Data.Version import Foreign import Foreign.C import qualified Menoh.Base as Base import qualified Paths_menoh -- ------------------------------------------------------------------------ -- | Functions in this module can throw this exception type. 
data Error
  = ErrorStdError String
  | ErrorUnknownError String
  | ErrorInvalidFilename String
  | ErrorONNXParseError String
  | ErrorInvalidDType String
  | ErrorInvalidAttributeType String
  | ErrorUnsupportedOperatorAttribute String
  | ErrorDimensionMismatch String
  | ErrorVariableNotFound String
  | ErrorIndexOutOfRange String
  | ErrorJSONParseError String
  | ErrorInvalidBackendName String
  | ErrorUnsupportedOperator String
  | ErrorFailedToConfigureOperator String
  | ErrorBackendError String
  | ErrorSameNamedVariableAlreadyExist String
  | UnsupportedInputDims String
  | SameNamedParameterAlreadyExist String
  | SameNamedAttributeAlreadyExist String
  | InvalidBackendConfigError String
  | InputNotFoundError String
  | OutputNotFoundError String
  deriving (Eq, Ord, Show, Read, Typeable)

-- Every constructor carries the message text of the failure. The mapping from
-- Menoh C error codes to constructors is the table inside 'runMenoh'.
instance Exception Error

-- | Run a raw Menoh C API call, turning a non-success error code into an
-- 'Error' exception that carries the library's last error message.
--
-- The call and the subsequent error message lookup both run inside
-- 'runInBoundThread'', because Menoh keeps its error information in thread
-- local storage. Error codes without a dedicated constructor are reported as
-- 'ErrorUnknownError' with the numeric code appended.
runMenoh :: IO Base.MenohErrorCode -> IO ()
runMenoh m = runInBoundThread' $ do
    e <- m
    unless (e == Base.menohErrorCodeSuccess) $ do
      s <- peekCString =<< Base.menoh_get_last_error_message
      case IntMap.lookup (fromIntegral e) table of
        Just ex -> throwIO $ ex s
        Nothing -> throwIO $ ErrorUnknownError $ s ++ "(error code: " ++ show (fromIntegral e :: Int) ++ ")"
  where
    -- Error code -> constructor of the corresponding 'Error'.
    table :: IntMap (String -> Error)
    table = IntMap.fromList $ map (\(k,v) -> (fromIntegral k, v)) $
      [ (Base.menohErrorCodeStdError                      , ErrorStdError)
      , (Base.menohErrorCodeUnknownError                  , ErrorUnknownError)
      , (Base.menohErrorCodeInvalidFilename               , ErrorInvalidFilename)
      , (Base.menohErrorCodeOnnxParseError                , ErrorONNXParseError)
      , (Base.menohErrorCodeInvalidDtype                  , ErrorInvalidDType)
      , (Base.menohErrorCodeInvalidAttributeType          , ErrorInvalidAttributeType)
      , (Base.menohErrorCodeUnsupportedOperatorAttribute  , ErrorUnsupportedOperatorAttribute)
      , (Base.menohErrorCodeDimensionMismatch             , ErrorDimensionMismatch)
      , (Base.menohErrorCodeVariableNotFound              , ErrorVariableNotFound)
      , (Base.menohErrorCodeIndexOutOfRange               , ErrorIndexOutOfRange)
      , (Base.menohErrorCodeJsonParseError                , ErrorJSONParseError)
      , (Base.menohErrorCodeInvalidBackendName            , ErrorInvalidBackendName)
      , (Base.menohErrorCodeUnsupportedOperator           , ErrorUnsupportedOperator)
      , (Base.menohErrorCodeFailedToConfigureOperator     , ErrorFailedToConfigureOperator)
      , (Base.menohErrorCodeBackendError                  , ErrorBackendError)
      , (Base.menohErrorCodeSameNamedVariableAlreadyExist , ErrorSameNamedVariableAlreadyExist)
      , (Base.menohErrorCodeUnsupportedInputDims          , UnsupportedInputDims)
      , (Base.menohErrorCodeSameNamedParameterAlreadyExist, SameNamedParameterAlreadyExist)
      , (Base.menohErrorCodeSameNamedAttributeAlreadyExist, SameNamedAttributeAlreadyExist)
      , (Base.menohErrorCodeInvalidBackendConfigError     , InvalidBackendConfigError)
      , (Base.menohErrorCodeInputNotFoundError            , InputNotFoundError)
      , (Base.menohErrorCodeOutputNotFoundError           , OutputNotFoundError)
      ]

-- | 'runInBoundThread' when the RTS supports bound threads. Otherwise
-- (non-threaded RTS, e.g. GHCi) the action is run directly, which is not
-- safe to use from multiple Haskell threads.
runInBoundThread' :: IO a -> IO a
runInBoundThread' action
  | rtsSupportsBoundThreads = runInBoundThread action
  | otherwise               = action

-- ------------------------------------------------------------------------

-- | Data type of array elements
data DType
  = DTypeFloat                    -- ^ single precision floating point number
  | DTypeUnknown !Base.MenohDType -- ^ types that this binding is unaware of
  deriving (Eq, Ord, Show, Read)

-- | Conversion from and to the raw integer dtype codes of the C API.
-- Every code other than 'Base.menohDtypeFloat' maps to 'DTypeUnknown'.
instance Enum DType where
  toEnum x
    | x == fromIntegral Base.menohDtypeFloat = DTypeFloat
    | otherwise = DTypeUnknown (fromIntegral x)
  fromEnum DTypeFloat = fromIntegral Base.menohDtypeFloat
  fromEnum (DTypeUnknown i) = fromIntegral i

-- | Size in bytes of one element of the given 'DType'.
--
-- For 'DTypeUnknown' this throws 'ErrorInvalidDType' (via 'throw', because the
-- function is pure). Callers can then catch it as a Menoh 'Error' instead of
-- getting an untyped 'ErrorCall'.
dtypeSize :: DType -> Int
dtypeSize DTypeFloat = sizeOf (undefined :: CFloat)
dtypeSize (DTypeUnknown _) = throw $ ErrorInvalidDType "Menoh.dtypeSize: unknown DType"

{-# DEPRECATED HasDType "use FromBuffer/ToBuffer instead" #-}

-- | Haskell types that have an associated 'DType' type code.
class Storable a => HasDType a where
  -- | 'DType' code for the element type @a@. The 'Proxy' argument is only a
  -- type witness and is never inspected.
  dtypeOf :: Proxy a -> DType

instance HasDType CFloat where
  dtypeOf = const DTypeFloat

#if SIZEOF_HSFLOAT == SIZEOF_FLOAT
-- Haskell Float only gets an instance when it has the same size as C float.
instance HasDType Float where
  dtypeOf = const DTypeFloat
#endif

-- ------------------------------------------------------------------------

-- | Dimensions of an array
type Dims = [Int]

-- ------------------------------------------------------------------------

-- | @ModelData@ contains model parameters and computation graph structure.
newtype ModelData = ModelData (ForeignPtr Base.MenohModelData)

-- Read the model data handle that a Menoh constructor wrote into @out@, and
-- wrap it so that 'Base.menoh_delete_model_data_funptr' releases it once it
-- is garbage collected.
peekModelData :: Ptr (Ptr Base.MenohModelData) -> IO ModelData
peekModelData out = do
  raw <- peek out
  ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr raw

{-# DEPRECATED makeModelDataFromONNX "use makeModelDataFromONNXFile instead" #-}
-- | Load an ONNX file and make 'ModelData'.
-- Deprecated alias of 'makeModelDataFromONNXFile'.
makeModelDataFromONNX :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNX path = makeModelDataFromONNXFile path

-- | Load an ONNX file from the given path and make 'ModelData'.
makeModelDataFromONNXFile :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNXFile path = liftIO $
  withCString path $ \cpath ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_data_from_onnx cpath out)
      peekModelData out

-- | Make 'ModelData' from an ONNX model held in memory as a strict
-- 'BS.ByteString'.
makeModelDataFromONNXByteString :: MonadIO m => BS.ByteString -> m ModelData
makeModelDataFromONNXByteString b = liftIO $
  BS.useAsCStringLen b $ \(bytes, byteCount) ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_data_from_onnx_data_on_memory bytes (fromIntegral byteCount) out)
      peekModelData out

-- | Optimize function for 'ModelData'.
--
-- This function modifies the given 'ModelData'.
optimizeModelData :: MonadIO m => ModelData -> VariableProfileTable -> m ()
optimizeModelData (ModelData md) (VariableProfileTable vpt) =
  liftIO $ withForeignPtr md $ \mdPtr ->
    withForeignPtr vpt $ \vptPtr ->
      runMenoh (Base.menoh_model_data_optimize mdPtr vptPtr)

-- | Make an empty 'ModelData'.
makeModelData :: MonadIO m => m ModelData
makeModelData = liftIO $ alloca $ \out -> do
  runMenoh (Base.menoh_make_model_data out)
  fmap ModelData . newForeignPtr Base.menoh_delete_model_data_funptr =<< peek out

-- | Add a new parameter in 'ModelData'.
--
-- This API is tentative and will be changed in the future.
--
-- A parameter name must not be added twice. Menoh rejects a duplicate and the
-- failure is raised as an 'Error' exception.
addParameterFromPtr :: MonadIO m => ModelData -> String -> DType -> Dims -> Ptr a -> m ()
addParameterFromPtr (ModelData md) name dtype dims p =
  liftIO $ withForeignPtr md $ \mdPtr ->
    withCString name $ \cname ->
      withArrayLen (map fromIntegral dims) $ \rank cdims ->
        runMenoh $
          Base.menoh_model_data_add_parameter mdPtr cname dtypeCode (fromIntegral rank) cdims p
  where
    -- Raw C dtype code obtained through the 'Enum' instance of 'DType'.
    dtypeCode = fromIntegral (fromEnum dtype)

-- Shared plumbing for the node-editing calls below: hand the raw model data
-- handle and a C string copy of @name@ to the given Menoh call and check its
-- error code.
withModelDataAndName :: ModelData -> String -> (Ptr Base.MenohModelData -> CString -> IO Base.MenohErrorCode) -> IO ()
withModelDataAndName (ModelData md) name call =
  withForeignPtr md $ \mdPtr ->
    withCString name $ \cname ->
      runMenoh (call mdPtr cname)

-- | Add a new node to 'ModelData'.
addNewNode :: MonadIO m => ModelData -> String -> m ()
addNewNode md name =
  liftIO $ withModelDataAndName md name Base.menoh_model_data_add_new_node

-- | Add a new input name to the most recently added node in 'ModelData'.
addInputNameToCurrentNode :: MonadIO m => ModelData -> String -> m ()
addInputNameToCurrentNode md name =
  liftIO $ withModelDataAndName md name Base.menoh_model_data_add_input_name_to_current_node

-- | Add a new output name to the most recently added node in 'ModelData'.
addOutputNameToCurrentNode :: MonadIO m => ModelData -> String -> m ()
addOutputNameToCurrentNode md name =
  liftIO $ withModelDataAndName md name Base.menoh_model_data_add_output_name_to_current_node

-- | A class of types that can be added to nodes using 'addAttribute'.
class AttributeType value where
  -- | Attach @value@ as an attribute named by the C string to the current
  -- node of the raw model data handle.
  basicAddAttribute :: Ptr Base.MenohModelData -> CString -> value -> IO ()

-- Scalar attributes: convert to the C type expected by the Menoh setter.
instance AttributeType Int where
  basicAddAttribute md cname =
    runMenoh . Base.menoh_model_data_add_attribute_int_to_current_node md cname . fromIntegral

instance AttributeType Float where
  basicAddAttribute md cname =
    runMenoh . Base.menoh_model_data_add_attribute_float_to_current_node md cname . realToFrac

-- List attributes: marshal into a temporary C array and pass its length.
instance AttributeType [Int] where
  basicAddAttribute md cname xs =
    withArrayLen (map fromIntegral xs) $ \len arr ->
      runMenoh $ Base.menoh_model_data_add_attribute_ints_to_current_node md cname (fromIntegral len) arr

instance AttributeType [Float] where
  basicAddAttribute md cname xs =
    withArrayLen (map realToFrac xs) $ \len arr ->
      runMenoh $ Base.menoh_model_data_add_attribute_floats_to_current_node md cname (fromIntegral len) arr

-- | Add a new attribute to the most recently added node of the 'ModelData'.
addAttribute :: (AttributeType value, MonadIO m) => ModelData -> String -> value -> m ()
addAttribute (ModelData fptr) attrName attrValue =
  liftIO $ withForeignPtr fptr $ \mdPtr ->
    withCString attrName $ \cname ->
      basicAddAttribute mdPtr cname attrValue

-- ------------------------------------------------------------------------

-- | Builder for creation of 'VariableProfileTable'.
newtype VariableProfileTableBuilder
  = VariableProfileTableBuilder (ForeignPtr Base.MenohVariableProfileTableBuilder)

-- | Factory function for 'VariableProfileTableBuilder'.
makeVariableProfileTableBuilder :: MonadIO m => m VariableProfileTableBuilder
makeVariableProfileTableBuilder = liftIO $ alloca $ \out -> do
  runMenoh (Base.menoh_make_variable_profile_table_builder out)
  raw <- peek out
  VariableProfileTableBuilder <$> newForeignPtr Base.menoh_delete_variable_profile_table_builder_funptr raw

-- | Add an input profile with an arbitrary number of dimensions.
--
-- The profile consists of the variable name, its 'DType' and its 'Dims'.
--
-- NOTE(review): the deprecation messages of 'addInputProfileDims2' and
-- 'addInputProfileDims4' point users here, but this function is missing from
-- the module export list. Consider exporting it.
addInputProfileDims :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> Dims -> m ()
addInputProfileDims (VariableProfileTableBuilder builder) name dtype dims =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      withArrayLen (map fromIntegral dims) $ \rank cdims ->
        runMenoh $
          Base.menoh_variable_profile_table_builder_add_input_profile
            builderPtr cname (fromIntegral (fromEnum dtype)) (fromIntegral rank) cdims

-- | Add a 2D input profile.
--
-- The profile consists of name, dtype and dims @(num, size)@, i.e. a
-- conventional batch of 1D inputs.
{-# DEPRECATED addInputProfileDims2 "use addInputProfileDims instead" #-}
addInputProfileDims2
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int) -- ^ (num, size)
  -> m ()
addInputProfileDims2 (VariableProfileTableBuilder builder) name dtype (num, size) =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_2
          builderPtr cname (fromIntegral (fromEnum dtype)) (fromIntegral num) (fromIntegral size)

-- | Add a 4D input profile.
--
-- The profile consists of name, dtype and dims
-- @(num, channel, height, width)@. This is a conventional batch of image
-- inputs, where each image is 3D (channel, height, width).
{-# DEPRECATED addInputProfileDims4 "use addInputProfileDims instead" #-}
addInputProfileDims4
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int, Int, Int) -- ^ (num, channel, height, width)
  -> m ()
addInputProfileDims4 (VariableProfileTableBuilder builder) name dtype (num, channel, height, width) =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_4
          builderPtr cname (fromIntegral (fromEnum dtype))
          (fromIntegral num) (fromIntegral channel) (fromIntegral height) (fromIntegral width)

-- | Add an output name.
--
-- Only the name is given: the output's 'Dims' and 'DType' are calculated
-- automatically, so they do not need to be specified explicitly.
addOutputName :: MonadIO m => VariableProfileTableBuilder -> String -> m ()
addOutputName (VariableProfileTableBuilder builder) name =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name (runMenoh . Base.menoh_variable_profile_table_builder_add_output_name builderPtr)

{-# DEPRECATED addOutputProfile "use addOutputName instead" #-}
-- | Add an output profile.
--
-- The profile consists of name and dtype. Its 'Dims' are calculated
-- automatically, so they do not need to be specified explicitly.
addOutputProfile :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> m ()
addOutputProfile (VariableProfileTableBuilder builder) name dtype =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh (Base.menoh_variable_profile_table_builder_add_output_profile builderPtr cname (fromIntegral (fromEnum dtype)))

-- | Type class abstracting over the ways an output can be specified
-- (see 'addOutputName' and 'addOutputProfile').
class AddOutput a where
  -- | Register the output described by the value with the builder.
  addOutput :: VariableProfileTableBuilder -> a -> IO ()

-- A bare output name.
instance AddOutput String where
  addOutput builder outName = addOutputName builder outName

-- A name paired with a 'DType'. The 'DType' is ignored because output types
-- are calculated automatically (see 'addOutputName').
instance AddOutput (String, DType) where
  addOutput builder = addOutputName builder . fst

-- | Factory function for 'VariableProfileTable'.
buildVariableProfileTable :: MonadIO m => VariableProfileTableBuilder -> ModelData -> m VariableProfileTable
buildVariableProfileTable (VariableProfileTableBuilder builder) (ModelData md) =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withForeignPtr md $ \mdPtr ->
      alloca $ \out -> do
        runMenoh (Base.menoh_build_variable_profile_table builderPtr mdPtr out)
        raw <- peek out
        VariableProfileTable <$> newForeignPtr Base.menoh_delete_variable_profile_table_funptr raw

-- ------------------------------------------------------------------------

-- | @VariableProfileTable@ contains the dtype and dims of variables.
--
-- Users can access them via 'vptGetDType' and 'vptGetDims'.
newtype VariableProfileTable = VariableProfileTable (ForeignPtr Base.MenohVariableProfileTable)

-- | Convenient function for constructing 'VariableProfileTable'.
--
-- If you need finer control, you can use 'VariableProfileTableBuilder'.
makeVariableProfileTable
  :: (AddOutput a, MonadIO m)
  => [(String, DType, Dims)] -- ^ input names with dtypes and dims
  -> [a]                     -- ^ required output information ('String' or @('String', 'DType')@)
  -> ModelData               -- ^ model data
  -> m VariableProfileTable
makeVariableProfileTable inputs outputs modelData = liftIO $ runInBoundThread' $ do
  -- Inputs are registered first, then outputs, then the table is built.
  builder <- makeVariableProfileTableBuilder
  mapM_ (\(name, dtype, dims) -> addInputProfileDims builder name dtype dims) inputs
  mapM_ (addOutput builder) outputs
  buildVariableProfileTable builder modelData

-- | Accessor function for 'VariableProfileTable'.
--
-- Look up the 'DType' of the named variable.
vptGetDType :: MonadIO m => VariableProfileTable -> String -> m DType
vptGetDType (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- Ask for the dtype itself. The dims-size getter would return the
    -- variable's rank, which must not be reinterpreted as a dtype code.
    runMenoh $ Base.menoh_variable_profile_table_get_dtype vpt' name' ret
    -- Convert the raw C code through the 'Enum' instance of 'DType'.
    (toEnum . fromIntegral) <$> peek ret

-- | Accessor function for 'VariableProfileTable'.
--
-- Look up the 'Dims' of the named variable.
vptGetDims :: MonadIO m => VariableProfileTable -> String -> m Dims
vptGetDims (VariableProfileTable vpt) name = liftIO $ runInBoundThread' $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- First query the number of dimensions, then each extent in turn,
    -- reusing the same output cell.
    runMenoh $ Base.menoh_variable_profile_table_get_dims_size vpt' name' ret
    size <- peek ret
    forM [0..size-1] $ \i -> do
      runMenoh $ Base.menoh_variable_profile_table_get_dims_at vpt' name' (fromIntegral i) ret
      fromIntegral <$> peek ret

-- ------------------------------------------------------------------------

-- | Helper for creating a 'Model'.
newtype ModelBuilder = ModelBuilder (ForeignPtr Base.MenohModelBuilder)

-- | Factory function for 'ModelBuilder'.
makeModelBuilder :: MonadIO m => VariableProfileTable -> m ModelBuilder
makeModelBuilder (VariableProfileTable vpt) = liftIO $
  withForeignPtr vpt $ \vpt' -> alloca $ \ret -> do
    runMenoh $ Base.menoh_make_model_builder vpt' ret
    liftM ModelBuilder $ newForeignPtr Base.menoh_delete_model_builder_funptr =<< peek ret

-- | Attach a buffer allocated by the user to the target variable.
--
-- Variables without an attached external buffer get internal buffers that
-- are allocated automatically.
--
-- Users can later get that internal buffer handle by calling
-- 'unsafeGetBuffer' etc.
attachExternalBuffer :: MonadIO m => ModelBuilder -> String -> Ptr a -> m ()
attachExternalBuffer (ModelBuilder builder) name buf =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh (Base.menoh_model_builder_attach_external_buffer builderPtr cname buf)

-- | Factory function for 'Model'.
buildModel
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> m Model
buildModel builder md backend = liftIO $
  -- An empty string is passed as the backend config.
  withCString "" $ \emptyConfig ->
    buildModelWithConfigString builder md backend emptyConfig

-- | Similar to 'buildModel', but backend-specific configuration can be
-- supplied as JSON.
buildModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> a      -- ^ backend config
  -> m Model
buildModelWithConfig builder md backend backend_config = liftIO $
  BS.useAsCString configJson $ \cconfig ->
    buildModelWithConfigString builder md backend cconfig
  where
    -- The config value serialized to JSON as a strict ByteString.
    configJson = BL.toStrict (J.encode backend_config)

-- Shared implementation of 'buildModel' and 'buildModelWithConfig'. The
-- backend config is already a C string.
buildModelWithConfigString
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String  -- ^ backend name
  -> CString -- ^ backend config
  -> m Model
buildModelWithConfigString (ModelBuilder builder) (ModelData md) backend backend_config =
  liftIO $ withForeignPtr builder $ \builderPtr ->
    withForeignPtr md $ \mdPtr ->
      withCString backend $ \cbackend ->
        alloca $ \out -> do
          runMenoh (Base.menoh_build_model builderPtr mdPtr cbackend backend_config out)
          raw <- peek out
          Model <$> newForeignPtr Base.menoh_delete_model_funptr raw

-- ------------------------------------------------------------------------

-- | ONNX model with input/output buffers
newtype Model = Model (ForeignPtr Base.MenohModel)

-- | Run model inference.
--
-- This function can not be called asynchronously.
run :: MonadIO m => Model -> m ()
run (Model model) = liftIO $ withForeignPtr model (runMenoh . Base.menoh_model_run)

-- | Get the 'DType' of the named variable of a 'Model'.
getDType :: MonadIO m => Model -> String -> m DType
getDType (Model m) name = liftIO $ do
  withForeignPtr m $ \m' -> withCString name $ \name' -> alloca $ \ret -> do
    runMenoh $ Base.menoh_model_get_variable_dtype m' name' ret
    -- Convert the raw C dtype code through the 'Enum' instance of 'DType'.
    liftM (toEnum . fromIntegral) $ peek ret

-- | Get the 'Dims' of the named variable of a 'Model'.
getDims :: MonadIO m => Model -> String -> m Dims
getDims (Model m) name = liftIO $ runInBoundThread' $ do
  -- Wrapping the whole loop in one 'runInBoundThread'' means the nested
  -- 'runMenoh' calls already run on a bound thread instead of each one
  -- possibly creating its own.
  withForeignPtr m $ \m' -> withCString name $ \name' -> alloca $ \ret -> do
    -- First query the number of dimensions, then each extent in turn,
    -- reusing the same output cell.
    runMenoh $ Base.menoh_model_get_variable_dims_size m' name' ret
    size <- peek ret
    forM [0..size-1] $ \i -> do
      runMenoh $ Base.menoh_model_get_variable_dims_at m' name' (fromIntegral i) ret
      fromIntegral <$> peek ret

-- ------------------------------------------------------------------------
-- Accessing buffers

-- | Get the buffer handle attached to the target variable.
--
-- If that buffer was allocated by the user and attached to the variable by
-- calling 'attachExternalBuffer', the returned handle is that same buffer.
--
-- This function is unsafe because it does not prevent the model from being
-- GC'ed, in which case the returned pointer becomes a dangling pointer.
--
-- See also 'withBuffer'.
unsafeGetBuffer :: MonadIO m => Model -> String -> m (Ptr a)
unsafeGetBuffer (Model m) name = liftIO $ do
  withForeignPtr m $ \m' -> withCString name $ \name' -> alloca $ \ret -> do
    runMenoh $ Base.menoh_model_get_variable_buffer_handle m' name' ret
    peek ret

-- | Apply the given function to the buffer associated with the specified
-- variable and execute the resulting action. The model (and hence the
-- buffer) is kept alive at least during the whole action, even if it is not
-- used directly inside.
-- Note that it is not safe to return the pointer from the action and use it
-- after the action completes.
--
-- See also 'unsafeGetBuffer'.
withBuffer :: forall m r a. (MonadIO m, MonadBaseControl IO m) => Model -> String -> (Ptr a -> m r) -> m r
withBuffer (Model m) name f =
  -- 'liftBaseOp' lifts the IO bracket-style helpers into the caller's monad.
  liftBaseOp (withForeignPtr m) $ \m' ->
  -- The annotation pins the lifted 'withCString' to the signature's scoped
  -- type variables @m@ and @r@.
  (liftBaseOp (withCString name) :: (CString -> m r) -> m r) $ \name' ->
  liftBaseOp alloca $ \ret -> do
    p <- liftIO $ do
      runMenoh $ Base.menoh_model_get_variable_buffer_handle m' name' ret
      peek ret
    f p

-- | Type that can be written to menoh's buffer.
class ToBuffer a where
  -- | Basic method for implementing @ToBuffer@ class.
  -- Normal users should use 'writeBuffer' instead.
  --
  -- The arguments are the variable's 'DType', its 'Dims', the raw buffer
  -- pointer and the value to write.
  basicWriteBuffer :: DType -> Dims -> Ptr () -> a -> IO ()

-- | Type that can be read from menoh's buffer.
class FromBuffer a where
  -- | Basic method for implementing @FromBuffer@ class.
  -- Normal users should use 'readBuffer' instead.
  --
  -- The arguments are the variable's 'DType', its 'Dims' and the raw buffer
  -- pointer.
  basicReadBuffer :: DType -> Dims -> Ptr () -> IO a

-- | Read values from the given model's buffer.
--
-- The variable's 'DType' and 'Dims' are queried from the model while the
-- buffer is held by 'withBuffer', and are passed to 'basicReadBuffer'.
readBuffer :: (FromBuffer a, MonadIO m) => Model -> String -> m a
readBuffer model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBuffer dtype dims p

-- | Write values to the given model's buffer.
--
-- The variable's 'DType' and 'Dims' are queried from the model while the
-- buffer is held by 'withBuffer', and are passed to 'basicWriteBuffer'.
writeBuffer :: (ToBuffer a, MonadIO m) => Model -> String -> a -> m ()
writeBuffer model name a = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBuffer dtype dims p a

-- | Default implementation of 'basicWriteBuffer' for 'VG.Vector' class
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- The first 'DType' is the element type of the vector and the second is the
-- buffer's. Both the dtypes and the element count (product of dims) must
-- match, otherwise 'ErrorInvalidDType' or 'ErrorDimensionMismatch' is thrown.
basicWriteBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> v a -> IO ()
basicWriteBufferGenericVectorStorable dtype0 dtype dims p vec = do
  let n = product dims
      p' = castPtr p
  checkDTypeAndSize "Menoh.basicWriteBufferGenericVectorStorable" (dtype, n) (dtype0, VG.length vec)
  -- Element-wise copy, since a generic vector has no contiguous storage to
  -- copy from.
  forM_ [0..n-1] $ \i -> do
    pokeElemOff p' i (vec VG.! i)

-- | Default implementation of 'basicReadBuffer' for 'VG.Vector' class
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- The first 'DType' is the element type of the vector and the second is the
-- buffer's. They must match, otherwise 'ErrorInvalidDType' is thrown.
basicReadBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (v a)
basicReadBufferGenericVectorStorable dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferGenericVectorStorable" dtype dtype0
  let n = product dims
      p' = castPtr p
  VG.generateM n $ peekElemOff p'

-- | Default implementation of 'basicWriteBuffer' for 'VS.Vector' class
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- Same checks as 'basicWriteBufferGenericVectorStorable', but the data is
-- copied in one block from the vector's storage.
basicWriteBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> VS.Vector a -> IO ()
basicWriteBufferStorableVector dtype0 dtype dims p vec = do
  let n = product dims
  checkDTypeAndSize "Menoh.basicWriteBufferStorableVector" (dtype, n) (dtype0, VG.length vec)
  VS.unsafeWith vec $ \src -> do
    copyArray (castPtr p) src n

-- | Default implementation of 'basicReadBuffer' for 'VS.Vector' class
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- The buffer contents are copied in one block into a freshly allocated
-- mutable vector, which is then frozen without a further copy.
basicReadBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (VS.Vector a)
basicReadBufferStorableVector dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferStorableVector" dtype dtype0
  let n = product dims
  vec <- VSM.new n
  VSM.unsafeWith vec $ \dst -> copyArray dst (castPtr p) n
  VS.unsafeFreeze vec

-- Float vectors are marshalled through Float's 'Storable' instance and
-- tagged as 'DTypeFloat'.
instance ToBuffer (V.Vector Float) where
  basicWriteBuffer = basicWriteBufferGenericVectorStorable DTypeFloat

instance FromBuffer (V.Vector Float) where
  basicReadBuffer = basicReadBufferGenericVectorStorable DTypeFloat

instance ToBuffer (VU.Vector Float) where
  basicWriteBuffer = basicWriteBufferGenericVectorStorable DTypeFloat

instance FromBuffer (VU.Vector Float) where
  basicReadBuffer = basicReadBufferGenericVectorStorable DTypeFloat

instance ToBuffer (VS.Vector Float) where
  basicWriteBuffer = basicWriteBufferStorableVector DTypeFloat

instance FromBuffer (VS.Vector Float) where
  basicReadBuffer = basicReadBufferStorableVector DTypeFloat

-- Nested lists: the first dimension must equal the list length. Each element
-- is written with the remaining dims at consecutive byte offsets.
instance ToBuffer a => ToBuffer [a] where
  basicWriteBuffer _dtype [] _p _xs = throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: empty dims"
  basicWriteBuffer dtype (dim : dims) p xs = do
    unless (dim == length xs) $ do
      throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: dimension mismatch"
    -- Byte stride between consecutive elements of the first dimension.
    let s = product dims * dtypeSize dtype
    forM_ (zip [0,s..] xs) $ \(offset,x) -> do
      basicWriteBuffer dtype dims (p `plusPtr` offset) x

-- Nested lists: read @dim@ elements of the remaining dims at consecutive
-- byte offsets.
instance FromBuffer a => FromBuffer [a] where
  basicReadBuffer _dtype [] _p = throwIO $ ErrorDimensionMismatch $ "FromBuffer{[a]}.basicReadBuffer: empty dims"
  basicReadBuffer dtype (dim : dims) p = do
    -- Byte stride between consecutive elements of the first dimension.
    let s = product dims * dtypeSize dtype
    forM [0..dim-1] $ \i -> do
      basicReadBuffer dtype dims (p `plusPtr` (i*s))

-- Throw 'ErrorInvalidDType' (prefixed with the caller name) when the two
-- dtypes differ.
checkDType :: String -> DType -> DType -> IO ()
checkDType name dtype1 dtype2
  | dtype1 /= dtype2 = throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
  | otherwise = return ()

-- Like 'checkDType', and additionally throw 'ErrorDimensionMismatch' when the
-- element counts differ. The dtype check takes precedence.
checkDTypeAndSize :: String -> (DType,Int) -> (DType,Int) -> IO ()
checkDTypeAndSize name (dtype1,n1) (dtype2,n2)
  | dtype1 /= dtype2 = throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
  | n1 /= n2 = throwIO $ ErrorDimensionMismatch $ name ++ ": dimension mismatch"
  | otherwise = return ()

{-# DEPRECATED writeBufferFromVector, writeBufferFromStorableVector "Use ToBuffer class and writeBuffer instead" #-}

-- | Copy all elements of a 'VG.Vector' into a model's buffer.
writeBufferFromVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> v a -> m ()
writeBufferFromVector model name vec = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBufferGenericVectorStorable (dtypeOf (Proxy :: Proxy a)) dtype dims p vec

-- | Copy all elements of a @'VS.Vector' a@ into a model's buffer.
writeBufferFromStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> VS.Vector a -> m ()
writeBufferFromStorableVector model name vec = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBufferStorableVector (dtypeOf (Proxy :: Proxy a)) dtype dims p vec

{-# DEPRECATED readBufferToVector, readBufferToStorableVector "Use FromBuffer class and readBuffer instead" #-}

-- | Read all elements of a model's buffer and return them as a 'VG.Vector'.
readBufferToVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> m (v a)
readBufferToVector model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBufferGenericVectorStorable (dtypeOf (Proxy :: Proxy a)) dtype dims p

-- | Read all elements of a model's buffer and return them as a 'VS.Vector'.
readBufferToStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> m (VS.Vector a)
readBufferToStorableVector model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBufferStorableVector (dtypeOf (Proxy :: Proxy a)) dtype dims p

-- ------------------------------------------------------------------------

-- | Convenience function for constructing a 'Model': builds a
-- 'ModelBuilder' from the table and then builds the model.
makeModel
  :: MonadIO m
  => VariableProfileTable -- ^ variable profile table
  -> ModelData -- ^ model data
  -> String -- ^ backend name
  -> m Model
makeModel vpt model_data backend_name = liftIO $ do
  b <- makeModelBuilder vpt
  buildModel b model_data backend_name

-- | Similar to 'makeModel', but backend-specific configuration can be
-- supplied.
makeModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => VariableProfileTable -- ^ variable profile table
  -> ModelData -- ^ model data
  -> String -- ^ backend name
  -> a -- ^ backend config
  -> m Model
makeModelWithConfig vpt model_data backend_name backend_config = liftIO $ do
  b <- makeModelBuilder vpt
  buildModelWithConfig b model_data backend_name backend_config

-- ------------------------------------------------------------------------

-- | Menoh version which was supplied at compilation time via CPP macro.
version :: Version
#if MIN_VERSION_base(4,8,0)
version = makeVersion [Base.menoh_major_version, Base.menoh_minor_version, Base.menoh_patch_version]
#else
version = Version [Base.menoh_major_version, Base.menoh_minor_version, Base.menoh_patch_version] []
#endif

-- | Version of this Haskell binding.
-- (Not the version of /Menoh/ itself)
bindingVersion :: Version
bindingVersion = Paths_menoh.version