{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleInstances #-}

module TypeChain.ChatModels.Types ( TypeChain
                                  , TypeChainT
                                  , ApiKey 
                                  , Role(..)
                                  , Message(..), role, content
                                  , pattern UserMessage 
                                  , pattern AssistantMessage 
                                  , pattern SystemMessage
                                  , MsgList(..)
                                  , ChatModel(..) 
                                  , RememberingChatModel(..)
                                  , module Control.Monad.State
                                  ) where 
import Control.Lens hiding ((.=))
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State

import Data.Aeson
import Data.ByteString (ByteString)
import Data.Kind (Constraint)

import GHC.Generics (Generic)

-- | The default monad for chat-model computations: the state (one model, or
-- a structure holding several models) threaded through 'IO'.
type TypeChain s = StateT s IO

-- | Monad-transformer form of 'TypeChain'; this is 'StateT' itself, so it
-- can be applied to any base monad.
type TypeChainT = StateT

-- | An API key, stored as raw bytes.
type ApiKey = ByteString

-- | Way of distinguishing who said what in a conversation.
--
-- Derives 'Eq', 'Ord', 'Enum' and 'Bounded' so that roles can be compared
-- (e.g. to filter a conversation by speaker) and enumerated with
-- @[minBound .. maxBound]@.
data Role = User | Assistant | System
    deriving (Eq, Ord, Enum, Bounded)

-- | A single chat message: who said it and what was said.
-- The lenses @role@ and @content@ are generated for the two fields.
data Message = Message
    { _role    :: Role    -- ^ Who the message is attributed to
    , _content :: String  -- ^ The text of the message
    } deriving (Generic)

-- | Renders a 'Role' as its lowercase name (the same strings the JSON
-- instances for 'Role' use).
instance Show Role where
    show User      = "user"
    show Assistant = "assistant"
    show System    = "system"

-- | Renders a message as its role, a colon and a space, then its content,
-- e.g. @user: hello@.
instance Show Message where
    show (Message speaker body) = concat [show speaker, ": ", body]

-- Generate the @role@ and @content@ lenses from the underscore-prefixed
-- fields of 'Message'. Declarations after this splice may use them.
makeLenses ''Message

-- | Builds, or matches on, a 'Message' whose role is 'User'.
pattern UserMessage :: String -> Message
pattern UserMessage txt = Message User txt

-- | Builds, or matches on, a 'Message' whose role is 'Assistant'.
pattern AssistantMessage :: String -> Message
pattern AssistantMessage txt = Message Assistant txt

-- | Builds, or matches on, a 'Message' whose role is 'System'.
pattern SystemMessage :: String -> Message
pattern SystemMessage s = Message System s

-- Every 'Message' carries exactly one of the three roles, so the three role
-- patterns together are exhaustive. Declaring that here stops the
-- pattern-match checker from reporting incomplete patterns for functions
-- that match on all three.
{-# COMPLETE UserMessage, AssistantMessage, SystemMessage #-}

-- | Encodes a 'Role' as a JSON string: @"user"@, @"assistant"@ or
-- @"system"@.
instance ToJSON Role where
    toJSON User      = String "user"
    toJSON Assistant = String "assistant"
    toJSON System    = String "system"

-- | Parses a JSON string (@"user"@, @"assistant"@ or @"system"@) into a
-- 'Role'. Non-string JSON is rejected by 'withText'. Any other string fails
-- with a message that names the offending value, to make bad payloads easier
-- to diagnose.
instance FromJSON Role where
    parseJSON = withText "Role" $ \case
        "user"      -> pure User
        "assistant" -> pure Assistant
        "system"    -> pure System
        other       -> fail ("Invalid role: " ++ show other)

-- | Encodes a 'Message' as a JSON object with a @"role"@ key and a
-- @"content"@ key.
instance ToJSON Message where
    toJSON msg = object
        [ "role"    .= _role msg
        , "content" .= _content msg
        ]

-- | Parses a JSON object with a @"role"@ key and a @"content"@ key into a
-- 'Message'. Both keys are required. @"role"@ is parsed first, so a missing
-- or invalid role is the error reported when both fields are bad.
--
-- The fields are parsed applicatively rather than bound to local names.
-- Locals named @role@ / @content@ would shadow the lenses generated by
-- @makeLenses ''Message@.
instance FromJSON Message where
    parseJSON = withObject "Message" $ \o ->
        Message <$> (o .: "role") <*> (o .: "content")

-- | Values that can be turned into a list of 'Message's. This lets the
-- @ChatModel@ functions accept several input shapes (a single message, a
-- list of messages, or a plain 'String').
--
-- NOTE: With @OverloadedStrings@ enabled, a string literal passed where a
-- 'MsgList' is expected needs a type annotation (e.g. @("hi" :: String)@) to
-- select the 'String' instance.
class MsgList t where

    -- | Convert the value into a list of messages.
    toMsgList :: t -> [Message]

-- | A plain 'String' becomes a single user message.
instance MsgList String where
    toMsgList txt = [UserMessage txt]

-- | A single message becomes a one-element list.
instance MsgList Message where
    toMsgList msg = [msg]

-- | A list of messages is passed through unchanged.
instance MsgList [Message] where
    toMsgList msgs = msgs

-- | A class for chat models.
--
-- To stay compatible with as many kinds of LLMs as possible, prediction is
-- constrained only to 'MonadIO' (plus 'MonadThrow'). An instance is
-- therefore free to make an API call, run a local model, or perform any
-- other action that needs IO.
--
-- Computations with a @ChatModel@ are expected to run in a 'StateT' monad
-- (see 'TypeChain' and 'TypeChainT' for concrete types). This lets the model
-- be updated with new messages and lets the output messages be logged.
--
-- Functions that work where several models are available (e.g. 'predicts'
-- and @addMsgsTo@) take a lens. The lens lets them read and modify the
-- chosen model without knowing the concrete state type.
--
-- Example: when working with two models, use @(model1, model2)@ as the state
-- type and pass the @_1@ or @_2@ lens to 'predicts' / @addMsgsTo@ to select
-- the model each call applies to.
class ChatModel a where

    -- | Predict with the current and only model (the whole state is the
    -- model). This should prompt the model, either via an API or locally,
    -- and return the response.
    --
    -- NOTE: A model that can remember previous messages should implement
    -- 'RememberingChatModel' and manage that memory automatically inside
    -- 'predict'.
    --
    -- The default delegates to 'predicts' with the identity lens.
    predict
        :: (MonadIO m, MonadThrow m, MsgList msg)
        => msg
        -> TypeChainT a m [Message]
    predict msgs = predicts id msgs

    -- | Predict with a specific model, selected by a lens into the state.
    -- This should prompt the model (either via an API or locally), log the
    -- input messages, log the output messages, and return the output
    -- messages.
    --
    -- NOTE: A model that can remember previous messages should implement
    -- 'RememberingChatModel' and manage that memory automatically inside
    -- 'predicts'.
    predicts
        :: (MonadIO m, MonadThrow m, MsgList msg)
        => Lens' s a
        -> msg
        -> TypeChainT s m [Message]

-- | Chat models that can remember previous messages.
--
-- Every operation comes in two forms:
--
-- * one that treats the whole state as the model;
-- * one that picks the model out of the state through a lens.
--
-- The default for each single-model form applies the lens form to 'id'.
class ChatModel a => RememberingChatModel a where

    -- | Enable or disable memory for the current and only model.
    setMemoryEnabled :: Monad m => Bool -> TypeChainT a m ()
    setMemoryEnabled enabled = setMemoryEnabledFor id enabled

    -- | Enable or disable memory for a specific model, selected by a lens.
    setMemoryEnabledFor :: Monad m => Lens' s a -> Bool -> TypeChainT s m ()

    -- | Remove all remembered messages for the current and only model.
    -- This does not affect the model's ability to remember future messages.
    forget :: Monad m => TypeChainT a m ()
    forget = forgetFor id

    -- | Remove all remembered messages for a specific model, selected by a
    -- lens. This does not affect the model's ability to remember future
    -- messages.
    forgetFor :: Monad m => Lens' s a -> TypeChainT s m ()

    -- | Remember a list of messages for the current and only model.
    -- This does not affect the model's ability to remember future messages
    -- and should respect the current memory setting.
    memorize :: Monad m => [Message] -> TypeChainT a m ()
    memorize msgs = memorizes id msgs

    -- | Remember a list of messages for a specific model, selected by a lens.
    -- This does not affect the model's ability to remember future messages
    -- and should respect the current memory setting.
    memorizes :: Monad m => Lens' s a -> [Message] -> TypeChainT s m ()

    -- | Retrieve all remembered messages for the current and only model.
    -- This neither forgets any messages nor affects the model's ability to
    -- remember future messages.
    remember :: Monad m => TypeChainT a m [Message]
    remember = rememberFor id

    -- | Retrieve all remembered messages for a specific model, selected by a
    -- lens. This neither forgets any messages nor affects the model's
    -- ability to remember future messages.
    rememberFor :: Monad m => Lens' s a -> TypeChainT s m [Message]