{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedStrings #-}

module Data.Ollama.Embeddings
  ( -- * Embedding API
    embedding
  , embeddingOps
  ) where

import Data.Aeson
import Data.Ollama.Common.Utils as CU
import Data.Text (Text)
import Data.Text qualified as T
import GHC.Generics
import Network.HTTP.Client

-- TODO: Add Options parameter
-- | Request payload posted by 'embeddingOps' to the @/api/embed@ endpoint.
--
-- Converted to JSON by the 'ToJSON' instance defined in this module.
data EmbeddingOps = EmbeddingOps
  { model :: Text
  -- ^ Name of the model to run.
  , input :: Text
  -- ^ Text to be embedded.
  , truncate :: Maybe Bool
  -- ^ Optional truncation flag; 'Nothing' leaves it unset.
  , keepAlive :: Maybe Text
  -- ^ Optional keep-alive setting; 'Nothing' leaves it unset.
  }
  deriving (Eq, Show)

-- | Reply decoded from Ollama's @/api/embed@ endpoint.
data EmbeddingResp = EmbeddingResp
  { model :: Text
  -- ^ Name of the model that produced the embeddings (JSON key @model@).
  , embedding' :: [[Float]]
  -- ^ Embedding vectors (JSON key @embeddings@).
  }
  deriving (Show, Eq, Generic)

-- A generically derived instance would use the Haskell field name as the
-- JSON key (i.e. look for @"embedding'"@), which the server never sends, so
-- every decode failed. Map each field to its actual wire name instead.
instance FromJSON EmbeddingResp where
  parseJSON = withObject "EmbeddingResp" $ \obj ->
    EmbeddingResp
      <$> obj .: "model"
      <*> obj .: "embeddings"

-- | Encodes the request as the JSON object posted to @/api/embed@.
--
-- All four keys are always present. An optional field set to 'Nothing' is
-- written as JSON @null@.
instance ToJSON EmbeddingOps where
  toJSON ops =
    case ops of
      EmbeddingOps modelName inputText mTruncate mKeepAlive ->
        object
          [ "model" .= modelName
          , "input" .= inputText
          , "truncate" .= mTruncate
          , "keep_alive" .= mKeepAlive
          ]

-- TODO: Add Options parameter

-- | Embedding API: call @/api/embed@ with every supported option.
--
-- Builds the URL from the host of 'defaultOllama', sends an 'EmbeddingOps'
-- body as a JSON @POST@, and decodes the reply.
--
-- Returns 'Nothing' when the response body does not decode as an
-- 'EmbeddingResp' (for example, an error payload without a @model@ key).
-- Exceptions raised by 'parseRequest' or 'httpLbs' are not caught.
embeddingOps ::
  -- | Model
  Text ->
  -- | Input
  Text ->
  -- | Truncate
  Maybe Bool ->
  -- | Keep Alive
  Maybe Text ->
  IO (Maybe EmbeddingResp)
embeddingOps modelName input mTruncate mKeepAlive = do
  let url = CU.host defaultOllama
  -- NOTE(review): a fresh Manager is created on every call; consider sharing
  -- one across requests if this is called frequently.
  manager <- newManager defaultManagerSettings
  initialRequest <- parseRequest $ T.unpack (url <> "/api/embed")
  let reqBody =
        EmbeddingOps
          { model = modelName
          , input = input
          , truncate = mTruncate
          , keepAlive = mKeepAlive
          }
      request =
        initialRequest
          { method = "POST"
          , -- The body is JSON; label it so the server need not guess.
            requestHeaders =
              ("Content-Type", "application/json") : requestHeaders initialRequest
          , requestBody = RequestBodyLBS $ encode reqBody
          }
  resp <- httpLbs request manager
  -- 'decode' already yields 'Maybe'; no need to unwrap and re-wrap it.
  pure (decode (responseBody resp))

-- Higher-level binding that takes only the essential parameters.

-- | Embedding API: embed @input@ with the model @modelName@.
--
-- Shorthand for 'embeddingOps' with the optional @truncate@ and
-- @keep_alive@ settings both left as 'Nothing'.
embedding ::
  -- | Model
  Text ->
  -- | Input
  Text ->
  IO (Maybe EmbeddingResp)
embedding modelName input = embeddingOps modelName input noTruncate noKeepAlive
  where
    noTruncate :: Maybe Bool
    noTruncate = Nothing
    noKeepAlive :: Maybe Text
    noKeepAlive = Nothing