{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

module TypeChain.ChatModels.PromptTemplate (ToPrompt(..), makeTemplate, user, assistant, system) where

import Data.Char (toLower)
import Data.List (nub)

import Language.Haskell.TH

import TypeChain.ChatModels.Types

-- | A single piece of a parsed prompt template string.
data TemplateToken
    = ConstString String  -- ^ Literal text, copied into the prompt verbatim.
    | Var String          -- ^ The name written between @{@ and @}@.
    deriving Eq

-- | One compiled message of a template, as produced by 'user', 'assistant'
-- and 'system' and consumed by 'makeTemplate'.
type PromptTemplate =
    ( Q Exp   -- message constructor applied to the rendered body expression
    , [Name]  -- distinct template variable names the body refers to
    )

-- | Conversion from a generated template record to the chat messages it
-- describes.
--
-- 'makeTemplate' generates an instance of this class for every record type
-- it creates. Write an instance by hand only when you need to assemble a
-- prompt manually.
class ToPrompt a where
    -- | The list of messages that make up the prompt.
    toPrompt :: a -> [Message]

-- | Compile a template string into a prompt template for 'makeTemplate'.
--
-- The resulting message is built with 'UserMessage'. Each @{name}@ in the
-- string becomes a variable of the template.
user :: String -> Q PromptTemplate
user = (`toTemplate` [| UserMessage |])

-- | Compile a template string into a prompt template for 'makeTemplate'.
--
-- The resulting message is built with 'AssistantMessage'. Each @{name}@ in
-- the string becomes a variable of the template.
assistant :: String -> Q PromptTemplate
assistant = (`toTemplate` [| AssistantMessage |])

-- | Compile a template string into a prompt template for 'makeTemplate'.
--
-- The resulting message is built with 'SystemMessage'. Each @{name}@ in the
-- string becomes a variable of the template.
system :: String -> Q PromptTemplate
system = (`toTemplate` [| SystemMessage |])

-- | Shared implementation of 'user', 'assistant' and 'system'.
--
-- Parses the template string into tokens and builds a 'String' expression
-- from them. In that expression, every variable is read from a value bound
-- to the name @template@. The message constructor expression is then applied
-- to the result. The variable names are returned alongside it, without
-- duplicates, in order of first appearance.
toTemplate :: String -> Q Exp -> Q PromptTemplate
toTemplate xs f = pure (f `appE` body, vars)
  where
    tokens = parseTemplateTokens xs
    body   = tokensToExpr (mkName "template") tokens
    vars   = mkName <$> nub (getVarTokens tokens)

-- | Split a template string into literal text and @{variable}@ references.
--
-- Text outside braces becomes 'ConstString'. The text between a @{@ and the
-- next @}@ becomes a 'Var'. Nested braces are not supported: the first @}@
-- after a @{@ ends the variable. There is no escape for a literal @{@. A @}@
-- that does not close a variable is kept as literal text.
--
-- This runs at compile time (inside Template Haskell). Malformed templates
-- are therefore reported through 'error' with a descriptive message, which
-- surfaces as a compile error. The alternative would be crashing inside a
-- partial function such as 'tail'.
parseTemplateTokens :: String -> [TemplateToken]
parseTemplateTokens [] = []
parseTemplateTokens ('{' : xs) =
    case break (== '}') xs of
        -- No closing brace anywhere in the rest of the string.
        (_, [])         -> error ("parseTemplateTokens: unclosed '{' in prompt template near \"{"
                                  ++ take 20 xs ++ "\"")
        -- "{}" would otherwise produce a variable with an empty name.
        ([], _)         -> error "parseTemplateTokens: empty variable name \"{}\" in prompt template"
        -- Drop the closing '}' and keep parsing after it.
        (var, _ : rest) -> Var var : parseTemplateTokens rest
parseTemplateTokens (x : xs) = ConstString (x : literal) : parseTemplateTokens rest
  where
    (literal, rest) = break (== '{') xs

-- | Collect the names of all 'Var' tokens, in order and including repeats.
-- Literal text tokens are skipped.
getVarTokens :: [TemplateToken] -> [String]
getVarTokens tokens = [name | Var name <- tokens]

-- | Build a Template Haskell expression that concatenates the tokens into a
-- single 'String'.
--
-- Literal text is lifted into the expression unchanged. A variable @v@
-- becomes the application @v param@, i.e. the function named @v@ (intended
-- to be the generated record field) applied to the value bound to @param@.
-- An empty token list yields @""@.
tokensToExpr :: Name -> [TemplateToken] -> Q Exp
tokensToExpr param = foldr step [| "" |]
  where
    step (ConstString txt) rest = [| txt ++ $rest |]
    step (Var field) rest =
        appE (appE [| (++) |] (varE (mkName field) `appE` varE param)) rest


-- | Given a type name and a list of message templates, generate three things:
-- a record type, a 'ToPrompt' instance for it, and a constructor function.
--
-- Example: @makeTemplate "Translate" [system "translate {a} to {b}.", user "{text}"]@
--
-- This generates a record named @Translate@ with strict 'String' fields @a@,
-- @b@ and @text@. It also generates a function
-- @mkTranslate :: String -> String -> String -> [Message]@ that builds the
-- prompt in one step. Alternatively, construct the record yourself and pass
-- it to 'toPrompt' to be more explicit.
--
-- Fields are ordered by first appearance across all templates. A variable
-- used in several messages yields a single field.
--
-- The generated 'toPrompt' binds its argument to the name @template@. This is
-- the same name the message bodies built by 'toTemplate' read their fields
-- from. A template variable @{template}@ would therefore be shadowed by that
-- argument, so it is rejected with a compile-time error.
--
-- See the example in the repository's README.md for what the generated code
-- looks like.
makeTemplate :: String -> [Q PromptTemplate] -> Q [Dec]
makeTemplate name xs = do
    let typeName  = mkName name
        funcName  = mkName ("mk" ++ name)
        -- Argument name of the generated 'toPrompt'. The message expressions
        -- built by 'toTemplate' refer to it via mkName "template".
        paramName = mkName "template"

    -- Flatten the per-message variable lists and drop duplicates. Keeping
    -- first occurrences makes the field order follow the template text.
    (exps, concat -> nub -> names) <- unzip <$> sequence xs

    -- A field named like the toPrompt argument would be shadowed inside the
    -- generated instance (yielding "template template"), so reject it early
    -- with a clear message instead of a confusing type error.
    if paramName `elem` names
        then fail ("makeTemplate " ++ show name
                   ++ ": the template variable {template} is reserved, please rename it")
        else pure ()

    exps'   <- sequence exps
    varbang <- bang sourceNoUnpack sourceStrict

    let recordFields :: [VarBangType]
        recordFields      = map (, varbang, ConT ''String) names
        promptFunc        = FunD 'toPrompt
                                [Clause [VarP paramName] (NormalB $ ListE exps') []]
        filledConstructor = foldl AppE (ConE typeName) (map VarE names)
        mkFuncClause      = Clause (map VarP names)
                                   (NormalB $ AppE (VarE 'toPrompt) filledConstructor)
                                   []

    return [ DataD [] typeName [] Nothing [RecC typeName recordFields] []
           , InstanceD Nothing [] (AppT (ConT ''ToPrompt) (ConT typeName)) [promptFunc]
           , FunD funcName [mkFuncClause]
           ]