{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleInstances #-}
module TypeChain.ChatModels.Types ( TypeChain
, TypeChainT
, ApiKey
, Role(..)
, Message(..), role, content
, pattern UserMessage
, pattern AssistantMessage
, pattern SystemMessage
, MsgList(..)
, ChatModel(..)
, RememberingChatModel(..)
, module Control.Monad.State
) where
import Control.Lens hiding ((.=))
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State
import Data.Aeson
import Data.ByteString (ByteString)
import Data.Kind (Constraint)
import GHC.Generics (Generic)
-- | A model-carrying state computation that runs directly in 'IO'.
type TypeChain model = TypeChainT model IO

-- | 'StateT' under a chain-specific name: the general form of 'TypeChain'
-- with an arbitrary base monad.
type TypeChainT = StateT

-- | An API key, kept as raw bytes.
type ApiKey = ByteString
-- | Author of a chat 'Message'. The 'Show' and JSON forms of each
-- constructor are its lowercase name.
data Role
  = User       -- ^ rendered as @"user"@
  | Assistant  -- ^ rendered as @"assistant"@
  | System     -- ^ rendered as @"system"@
-- | A single chat message: who wrote it and its text.
--
-- The fields are underscore-prefixed for 'makeLenses', which generates the
-- exported lenses 'role' and 'content' from them.
data Message = Message
  { _role    :: Role    -- ^ author of the message
  , _content :: String  -- ^ message text
  }
  deriving (Generic)
-- | Renders a 'Role' as its lowercase name (the same strings the 'ToJSON'
-- instance emits). Note that this is not 'Read'-compatible.
instance Show Role where
  show :: Role -> String
  show = \case
    User      -> "user"
    Assistant -> "assistant"
    System    -> "system"
-- | Renders a message as @role: content@, e.g. @user: hello@.
instance Show Message where
  show :: Message -> String
  show (Message r c) = show r ++ ": " ++ c
-- Generate the lenses 'role' and 'content' from the '_role' / '_content'
-- fields of 'Message'.
$(makeLenses ''Message)
-- | Bidirectional pattern for a 'Message' whose role is 'User'.
pattern UserMessage :: String -> Message
pattern UserMessage s = Message User s
-- | Bidirectional pattern for a 'Message' whose role is 'Assistant'.
pattern AssistantMessage :: String -> Message
pattern AssistantMessage s = Message Assistant s
-- | Bidirectional pattern for a 'Message' whose role is 'System'.
pattern SystemMessage :: String -> Message
pattern SystemMessage s = Message System s

-- The three role synonyms together cover every 'Role', so a match on all of
-- them is exhaustive; tell the pattern-match checker so it does not warn.
{-# COMPLETE UserMessage, AssistantMessage, SystemMessage #-}
-- | Encodes a 'Role' as a JSON string holding its lowercase name.
instance ToJSON Role where
  toJSON :: Role -> Value
  toJSON = \case
    User      -> String "user"
    Assistant -> String "assistant"
    System    -> String "system"
-- | Decodes a 'Role' from a JSON string. Only the exact (case-sensitive)
-- strings @"user"@, @"assistant"@ and @"system"@ are accepted; anything else
-- fails the parse, and non-string values are rejected by 'withText'.
instance FromJSON Role where
  parseJSON :: Value -> Parser Role
  parseJSON = withText "Role" $ \case
    "user"      -> pure User
    "assistant" -> pure Assistant
    "system"    -> pure System
    -- Echo the rejected value so decoding errors are actionable.
    other       -> fail $ "Invalid role: " <> show other
                       <> " (expected \"user\", \"assistant\" or \"system\")"
-- | Encodes a 'Message' as @{"role": …, "content": …}@.
instance ToJSON Message where
  toJSON :: Message -> Value
  toJSON (Message r c) = object ["role" .= r, "content" .= c]

  -- Stream the same two fields directly, instead of the default
  -- 'toEncoding', which first builds the intermediate 'Value'.
  toEncoding :: Message -> Encoding
  toEncoding (Message r c) = pairs ("role" .= r <> "content" .= c)
-- | Decodes a 'Message' from an object with required @"role"@ and
-- @"content"@ keys.
--
-- Written applicatively so no local names shadow the generated 'role' /
-- 'content' lenses.
instance FromJSON Message where
  parseJSON :: Value -> Parser Message
  parseJSON = withObject "Message" $ \o ->
    Message <$> o .: "role" <*> o .: "content"
-- | Values that can be turned into a list of chat 'Message's. This lets the
-- model functions accept a bare 'String', a single 'Message', or a list of
-- messages.
class MsgList a where
  -- | Converts the value into a list of messages.
  toMsgList :: a -> [Message]
-- | A bare string is treated as a single message from the 'User'.
instance MsgList String where
  toMsgList :: String -> [Message]
  toMsgList s = [UserMessage s]
-- | A single message becomes a one-element list.
instance MsgList Message where
  toMsgList :: Message -> [Message]
  toMsgList msg = [msg]
-- | A message list is already in the target form and is returned unchanged.
instance MsgList [Message] where
  toMsgList :: [Message] -> [Message]
  toMsgList = id
-- | A chat model that lives in the state of a 'TypeChainT' computation.
--
-- Instances only need to implement 'predicts'; 'predict' defaults to it.
class ChatModel a where
  -- | Runs the model on @msg@ when the model is the entire state.
  -- Defaults to 'predicts' with the identity lens.
  predict
    :: (MonadIO m, MonadThrow m, MsgList msg)
    => msg
    -> TypeChainT a m [Message]
  predict = predicts id

  -- | Runs the model found at the given lens inside a larger state @s@.
  -- The meaning of the returned messages is instance-defined (presumably the
  -- model's reply — TODO confirm against instances).
  predicts
    :: (MonadIO m, MonadThrow m, MsgList msg)
    => Lens' s a
    -> msg
    -> TypeChainT s m [Message]
-- | A 'ChatModel' that can keep a memory of messages.
--
-- Every operation comes in two forms:
--
-- * one acting on a model that is the whole state;
-- * a lens-taking variant for a model nested inside a larger state @s@.
--
-- The first form defaults to the second applied to 'id'. Instances therefore
-- only need 'setMemoryEnabledFor', 'forgetFor', 'memorizes' and
-- 'rememberFor'.
class ChatModel a => RememberingChatModel a where
  -- | Presumably turns memory on ('True') or off ('False'); exact semantics
  -- are instance-defined.
  setMemoryEnabled :: Monad m => Bool -> TypeChainT a m ()
  setMemoryEnabled = setMemoryEnabledFor id

  -- | 'setMemoryEnabled' for the model at the given lens.
  setMemoryEnabledFor :: Monad m => Lens' s a -> Bool -> TypeChainT s m ()

  -- | Presumably clears the memory; exact semantics are instance-defined.
  forget :: Monad m => TypeChainT a m ()
  forget = forgetFor id

  -- | 'forget' for the model at the given lens.
  forgetFor :: Monad m => Lens' s a -> TypeChainT s m ()

  -- | Presumably stores the given messages in the memory (instance-defined).
  memorize :: Monad m => [Message] -> TypeChainT a m ()
  memorize = memorizes id

  -- | 'memorize' for the model at the given lens.
  memorizes :: Monad m => Lens' s a -> [Message] -> TypeChainT s m ()

  -- | Presumably returns the remembered messages (instance-defined).
  remember :: Monad m => TypeChainT a m [Message]
  remember = rememberFor id

  -- | 'remember' for the model at the given lens.
  rememberFor :: Monad m => Lens' s a -> TypeChainT s m [Message]