{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module IntelliMonad.Tools.Utils where

import Control.Monad (forM)
import Control.Monad.IO.Class
import Data.Aeson (encode)
import qualified Data.Aeson as A
import qualified Data.ByteString as BS
import Data.Maybe (catMaybes)
import Data.Proxy
import Data.Text (Text)
import qualified Data.Text.Encoding as T
import Data.Time
import IntelliMonad.Types
import qualified OpenAI.Types as API

-- | Apply every tool's 'toolAdd' transformation to a chat-completion request.
--
-- Tools are applied in list order: the head of the list is applied first,
-- then the rest are applied to the resulting request. An empty list leaves
-- the request unchanged.
addTools :: [ToolProxy] -> API.CreateChatCompletionRequest -> API.CreateChatCompletionRequest
addTools [] req = req
addTools (ToolProxy (_ :: Proxy a) : rest) req =
  -- The pattern signature names the tool type hidden inside the proxy, so it
  -- can be passed to the ambiguously-typed 'toolAdd' by type application.
  addTools rest (toolAdd @a req)

-- | Attempt to answer a single tool call with tool @t@.
--
-- Arguments, in order: the session name (recorded on the returned content),
-- the tool-call id, the requested function name, and the call arguments as
-- JSON text.
--
-- * If the function name is not @'toolFunctionName' \@t@, the result is
--   'Nothing'. This lets '<||>' / 'mergeToolCall' offer the call to the next
--   tool.
-- * If the name matches, the call is always answered. The answer is a
--   'ToolReturn' content attributed to the 'Tool' user, echoing the call id
--   and function name, and stamped with the current time. Its payload is:
--
--     * the JSON-encoded 'Output' of 'toolExec', when the arguments decode
--       (as UTF-8 JSON) into @t@;
--     * a JSON object @{"error": <decoder message>}@, when they do not.
toolExec' ::
  forall t p m.
  (PersistentBackend p, MonadIO m, MonadFail m, Tool t, A.FromJSON t, A.ToJSON (Output t)) =>
  Text ->
  Text ->
  Text ->
  Text ->
  Prompt m (Maybe Content)
toolExec' sessionName id' name' args'
  | name' /= toolFunctionName @t = pure Nothing
  | otherwise =
      case A.eitherDecode (BS.fromStrict (T.encodeUtf8 args')) :: Either String t of
        -- Reply even when the arguments are malformed. Returning 'Nothing'
        -- here would make '<||>' fall through to other tools, and
        -- 'tryToolExec' would then drop the call, leaving this call id
        -- without any reply. NOTE(review): presumably the model/API expects
        -- one reply per call id.
        Left err -> reply (encode (A.object [("error", A.toJSON err)]))
        Right input -> do
          output <- toolExec @t @p @m input
          reply (encode output)
  where
    -- Wrap an encoded JSON payload as this call's 'ToolReturn' content.
    reply payload = do
      time <- liftIO getCurrentTime
      let body = T.decodeUtf8Lenient (BS.toStrict payload)
      pure (Just (Content Tool (ToolReturn id' name' body) sessionName time))

-- | Combine two tool handlers into one.
--
-- The combined handler runs @tool0@ first. If @tool0@ returns 'Just', that
-- result is returned and @tool1@ is not run. If @tool0@ returns 'Nothing',
-- @tool1@ is run with the same four arguments and its result is returned.
(<||>) ::
  forall m.
  (MonadIO m, MonadFail m) =>
  (Text -> Text -> Text -> Text -> Prompt m (Maybe Content)) ->
  (Text -> Text -> Text -> Text -> Prompt m (Maybe Content)) ->
  Text ->
  Text ->
  Text ->
  Text ->
  Prompt m (Maybe Content)
(<||>) tool0 tool1 sessionName id' name' args' =
  tool0 sessionName id' name' args'
    >>= maybe (tool1 sessionName id' name' args') (pure . Just)

-- | Build a single tool-call handler from a list of tools.
--
-- The call (session name, call id, function name, JSON arguments) is offered
-- to each tool in list order through 'toolExec'', chained with '<||>'. The
-- first tool that returns 'Just' wins. An empty list, or a call that no tool
-- answers, yields 'Nothing'.
mergeToolCall :: forall p m. (PersistentBackend p, MonadIO m, MonadFail m) => [ToolProxy] -> Text -> Text -> Text -> Text -> Prompt m (Maybe Content)
mergeToolCall [] _ _ _ _ = pure Nothing
mergeToolCall (ToolProxy (_ :: Proxy a) : rest) sessionName id' name' args' =
  -- The pattern signature brings the tool type hidden in the proxy into
  -- scope so it can be passed to 'toolExec'' by type application.
  (toolExec' @a @p <||> mergeToolCall @p rest) sessionName id' name' args'

-- | 'True' when at least one content item carries a 'ToolCall' message.
hasToolCall :: Contents -> Bool
hasToolCall = any $ \case
  Content _ ToolCall {} _ _ -> True
  _ -> False

-- | Keep only the content items whose message is a 'ToolCall', preserving
-- their original order.
filterToolCall :: Contents -> Contents
filterToolCall = filter $ \case
  Content _ ToolCall {} _ _ -> True
  _ -> False

-- | Execute every tool call found in @contents@ against @tools@.
--
-- Each 'ToolCall' is dispatched through 'mergeToolCall', in the order the
-- calls appear in @contents@. The non-'Nothing' results are collected in that
-- order. Calls that no tool answers are dropped from the result.
tryToolExec :: forall p m. (PersistentBackend p, MonadIO m, MonadFail m) => [ToolProxy] -> Text -> Contents -> Prompt m Contents
tryToolExec tools sessionName contents =
  catMaybes
    <$> forM calls (\(id', name', args') -> mergeToolCall @p tools sessionName id' name' args')
  where
    -- (call id, function name, arguments) of every 'ToolCall' item. Elements
    -- that do not match the pattern are skipped by the comprehension, so this
    -- is total, unlike a refutable lambda pattern.
    calls = [(callId, fn, args) | Content _ (ToolCall callId fn args) _ _ <- contents]

-- | Find the first content item that is a 'ToolCall' addressed to the given
-- tool, i.e. one whose function name equals that tool's 'toolFunctionName'.
-- Returns 'Nothing' if there is no such item. Items with any other kind of
-- message are skipped.
findToolCall :: ToolProxy -> Contents -> Maybe Content
findToolCall (ToolProxy (_ :: Proxy a)) contents =
  let isCallToTool = \case
        Content _ (ToolCall _ name' _) _ _ -> name' == toolFunctionName @a
        _ -> False
   in case filter isCallToTool contents of
        hit : _ -> Just hit
        [] -> Nothing