{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module TypeChain.ChatModels.OpenAI (OpenAIChat(..), OpenAIChatModel(..), initOpenAIChat) where

import Control.Lens hiding ((.=))
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State

import Data.Aeson
import Data.ByteString.Lazy (ByteString)
import Data.Functor (($>))
import Data.Maybe (fromMaybe)

import GHC.Generics (Generic)

import TypeChain.ChatModels.Types

import Network.HTTP.Simple
import Network.HTTP.Conduit

import qualified Data.ByteString.Lazy as BS

-- | The OpenAI chat models supported by this module. The 'Show' instance
-- renders each constructor as the model identifier sent to the API.
data OpenAIChatModel
    = GPT35Turbo -- ^ @gpt-3.5-turbo@
    | GPT4       -- ^ @gpt-4@
    | GPT4Turbo  -- ^ @gpt-4-turbo-preview@
    deriving (Eq, Enum, Bounded)

-- | Renders a model as the identifier expected in the request's @model@ field.
instance Show OpenAIChatModel where
    show m = case m of
        GPT35Turbo -> "gpt-3.5-turbo"
        GPT4       -> "gpt-4"
        GPT4Turbo  -> "gpt-4-turbo-preview"

-- | Encodes a model as a JSON string holding its 'show' form (e.g. @"gpt-4"@).
instance ToJSON OpenAIChatModel where
    toJSON chosen = toJSON (show chosen)

-- | Settings and optional conversation memory for an OpenAI chat model.
data OpenAIChat = OpenAIChat
    { chatModel   :: OpenAIChatModel -- ^ Model named in the request body
    , messages    :: Maybe [Message] -- ^ @Nothing@ = Do not remember messages
    , temperature :: Float           -- ^ Temperature included in the request body
    , apiKey      :: ApiKey          -- ^ Sent as a bearer token in the request headers
    } deriving (Generic)

-- | Request body for the chat completions endpoint.
--
-- 'apiKey' is deliberately omitted because it travels in the Authorization
-- header. When memory is disabled (@messages = Nothing@), the @messages@ key
-- is encoded as JSON @null@.
instance ToJSON OpenAIChat where
    toJSON chat =
        object
            [ "model"       .= chatModel chat
            , "temperature" .= temperature chat
            , "messages"    .= messages chat
            ]

-- | One element of the @choices@ array in an OpenAI chat completion response.
-- Field names double as the JSON keys (generic decoding, default options).
data Choices = Choices
    { message       :: Message -- ^ JSON key @message@
    , finish_reason :: String  -- ^ JSON key @finish_reason@
    , index         :: Int     -- ^ JSON key @index@
    } deriving (Generic)

-- | Generic decoding with aeson's default options: keys match field names.
instance FromJSON Choices where
    parseJSON = genericParseJSON defaultOptions

-- | The subset of an OpenAI chat completion response that this module decodes.
data OpenAIResponse = OpenAIResponse
    { model   :: String    -- ^ Value of the response's @model@ key
    , choices :: [Choices] -- ^ Value of the response's @choices@ key
    } deriving (Generic)


-- | Requires a JSON object with @model@ and @choices@ keys; other keys are ignored.
instance FromJSON OpenAIResponse where
    parseJSON = withObject "OpenAIResponse" $ \obj ->
        OpenAIResponse
            <$> obj .: "model"
            <*> obj .: "choices"

-- | An OpenAI chat model with default settings.
--
-- Model: GPT-3.5-Turbo
--
-- Memorization: Enabled, starting from an empty history
--
-- Temperature: 0.7
--
-- ApiKey: the placeholder @MISSING-API-KEY@; set 'apiKey' to a real key
-- before making requests, since it is sent as the bearer token.
initOpenAIChat :: OpenAIChat
initOpenAIChat =
    OpenAIChat
        { chatModel   = GPT35Turbo
        , messages    = Just []          -- memory on, nothing remembered yet
        , temperature = 0.7
        , apiKey      = "MISSING-API-KEY"
        }

-- | Headers for an OpenAI request: a JSON content type plus bearer-token
-- authorization using the given key.
mkOpenAIChatHeaders :: ApiKey -> RequestHeaders
mkOpenAIChatHeaders key =
    [ ("Content-Type", "application/json")
    , ("Authorization", "Bearer " <> key)
    ]

-- | Build the POST request for OpenAI's chat completions endpoint.
--
-- The body is the JSON encoding of the given 'OpenAIChat' and the headers come
-- from 'mkOpenAIChatHeaders'. The only throwing step is 'parseRequest'
-- (via 'MonadThrow'), which fails only on a malformed URL.
mkGPT35TurboRequest :: MonadThrow m => OpenAIChat -> m Request
mkGPT35TurboRequest gpt =
    prepare <$> parseRequest "https://api.openai.com/v1/chat/completions"
  where
    prepare =
        setRequestMethod "POST"
            . setRequestHeaders (mkOpenAIChatHeaders (apiKey gpt))
            . setRequestBodyLBS (encode gpt)

instance ChatModel OpenAIChat where
    -- Send the conversation to OpenAI's chat completions endpoint and return
    -- the message of every entry in the response's @choices@.
    --
    -- The caller's messages are memorized before the request and the replies
    -- after it. Both steps are no-ops when memory is disabled
    -- (@messages = Nothing@).
    -- If the response body does not decode as an 'OpenAIResponse', a warning
    -- is printed to stdout and @[]@ is returned, with nothing new memorized.
    -- NOTE(review): 'parseRequest' requests do not throw on non-2xx statuses,
    -- so an API error body presumably ends up on this path as well.
    predicts model m = do
        let msgs = toMsgList m
        model `memorizes` msgs

        gpt <- use model
        -- With memory enabled, 'messages' already holds the history followed
        -- by @msgs@. With memory disabled it is 'Nothing', which would be
        -- serialised as @"messages": null@ and would omit the caller's
        -- messages entirely, so send @msgs@ in that case.
        let outgoing = gpt { messages = Just (fromMaybe msgs (messages gpt)) }
        req <- mkGPT35TurboRequest outgoing
        res <- httpLBS req

        case decode @OpenAIResponse (responseBody res) of
            Nothing -> liftIO $ putStrLn "WARNING: Failed to decode OpenAIResponse" $> []
            Just r  -> do
                let newMsgs = map message (choices r)
                model `memorizes` newMsgs
                pure newMsgs

instance RememberingChatModel OpenAIChat where
    -- Turn memory on (starting from an empty history) or off (dropping the
    -- history). State is only modified when the requested status differs
    -- from the current one.
    setMemoryEnabledFor model enable = do
        remembering <- uses model (maybe False (const True) . messages)
        if remembering == enable
            then return () -- already in the requested state
            else model %= \chat ->
                chat { messages = if enable then Just [] else Nothing }

    -- Clear the remembered history but keep memory enabled. A model with
    -- memory disabled is left untouched.
    forgetFor model = do
        history <- uses model messages
        case history of
            Nothing -> return ()
            Just _  -> model %= \chat -> chat { messages = Just [] }

    -- Append messages to the end of the history. This is a no-op when memory
    -- is disabled, because fmap over 'Nothing' leaves it 'Nothing'.
    memorizes model msgs =
        model %= \chat -> chat { messages = fmap (++ toMsgList msgs) (messages chat) }

    -- The remembered history, or @[]@ when memory is disabled.
    rememberFor model = uses model (fromMaybe [] . messages)