{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module TypeChain.ChatModels.PromptTemplate (ToPrompt(..), makeTemplate, user, assistant, system) where
import Data.Char (toLower)
import Data.List (nub)
import Language.Haskell.TH
import TypeChain.ChatModels.Types
-- | One lexical piece of a template string.
data TemplateToken
    = ConstString String  -- ^ Literal text, copied into the rendered output unchanged.
    | Var String          -- ^ Placeholder name written between @{@ and @}@.
    deriving (Eq)

-- | A message-building expression paired with the placeholder names it uses.
--
-- As built by 'toTemplate', the expression refers to a free variable named
-- @template@, which the declaration that splices it in must bind.
type PromptTemplate = (Q Exp, [Name])
-- | Values that can be rendered as a list of chat 'Message's.
--
-- 'makeTemplate' generates an instance of this class for every template
-- record type it declares.
class ToPrompt a where
    -- | Render a value as a list of chat messages.
    toPrompt :: a -> [Message]
-- | A template whose rendered text is wrapped in 'UserMessage'.
--
-- Placeholders are written as @{name}@, e.g. @user "Summarise {topic}"@.
user :: String -> Q PromptTemplate
user source = toTemplate source [| UserMessage |]
-- | A template whose rendered text is wrapped in 'AssistantMessage'.
--
-- Placeholders are written as @{name}@.
assistant :: String -> Q PromptTemplate
assistant source = toTemplate source [| AssistantMessage |]
-- | A template whose rendered text is wrapped in 'SystemMessage'.
--
-- Placeholders are written as @{name}@.
system :: String -> Q PromptTemplate
system source = toTemplate source [| SystemMessage |]
-- | Compile a template string into a message expression and its placeholders.
--
-- The expression is @f@ applied to a string-building expression over a free
-- variable named @template@. Each placeholder @{x}@ is rendered as @x template@.
-- The returned names are the distinct placeholders in first-occurrence order.
-- These become the record fields declared by 'makeTemplate'.
toTemplate :: String -> Q Exp -> Q PromptTemplate
toTemplate xs f = pure (f `appE` rendered, placeholders)
  where
    tokens       = parseTemplateTokens xs
    -- "template" must match the parameter name of the generated toPrompt clause.
    rendered     = tokensToExpr (mkName "template") tokens
    placeholders = map mkName (nub (getVarTokens tokens))
-- | Split a template string into literal text and @{variable}@ placeholders.
--
-- Text outside braces becomes 'ConstString'. The text between a @{@ and the
-- next @}@ becomes 'Var'. A @}@ with no preceding @{@ is kept as literal text.
--
-- Malformed templates are reported with an explicit 'error'. Since this runs
-- while a Template Haskell splice is being expanded, the message shows up at
-- the splice site. Two cases are rejected:
--
-- * an unterminated @{@ (this used to hit the partial 'tail' and fail with an
--   uninformative message);
-- * an empty placeholder @{}@ (this would otherwise produce an invalid empty
--   identifier).
parseTemplateTokens :: String -> [TemplateToken]
parseTemplateTokens [] = []
parseTemplateTokens ('{' : xs) =
    case break (== '}') xs of
        (_, []) ->
            error ("TypeChain.ChatModels.PromptTemplate: unterminated '{' in template at: {" ++ xs)
        ([], _) ->
            error "TypeChain.ChatModels.PromptTemplate: empty placeholder '{}' in template"
        (varName, _closingBrace : rest) ->
            Var varName : parseTemplateTokens rest
parseTemplateTokens (x : xs) = ConstString (x : literal) : parseTemplateTokens rest
  where
    -- x is known not to be '{' here, so the literal runs up to the next '{'.
    (literal, rest) = break (== '{') xs
-- | Placeholder names in the order they appear, with duplicates kept.
getVarTokens :: [TemplateToken] -> [String]
getVarTokens tokens = [ varName | Var varName <- tokens ]
-- | Build an expression that concatenates the tokens into one 'String'.
--
-- Literal text is lifted into the expression. Each @Var x@ becomes @x v@,
-- where @v@ is the supplied variable name, so @x@ is expected to be a function
-- from that value to 'String' (e.g. a record selector). The result is
-- right-nested and ends in @""@: @a ++ (b ++ (... ++ ""))@.
tokensToExpr :: Name -> [TemplateToken] -> Q Exp
tokensToExpr templateVar = foldr prepend [| "" |]
  where
    prepend :: TemplateToken -> Q Exp -> Q Exp
    prepend (ConstString literal) rest = [| literal ++ $rest |]
    prepend (Var field) rest =
        [| (++) |] `appE` (varE (mkName field) `appE` varE templateVar) `appE` rest
-- | Generate the declarations for a named prompt template.
--
-- For @makeTemplate "Greeting" [user "Hi {name}"]@ this produces:
--
-- * a record type @Greeting@ with one strict 'String' field per distinct
--   placeholder across all templates, in first-occurrence order;
-- * a 'ToPrompt' instance whose 'toPrompt' builds a list with one message
--   per template, in the order the templates were given;
-- * a function @mkGreeting@ that takes the field values positionally and
--   returns @toPrompt@ of the constructed record.
makeTemplate :: String -> [Q PromptTemplate] -> Q [Dec]
makeTemplate name xs = do
    templates <- sequence xs
    let (messageExprs, fieldGroups) = unzip templates
        fields = nub (concat fieldGroups)
    renderedMessages <- sequence messageExprs
    strictField <- bang sourceNoUnpack sourceStrict
    let tyName = mkName name
        -- data <Name> = <Name> { f1 :: !String, f2 :: !String, ... }
        dataDec =
            DataD [] tyName [] Nothing
                [RecC tyName [ (field, strictField, ConT ''String) | field <- fields ]]
                []
        -- instance ToPrompt <Name> where toPrompt template = [msg1, msg2, ...]
        -- The parameter must be called "template": toTemplate's expressions
        -- apply each field selector to a variable of that name.
        instanceDec =
            InstanceD Nothing [] (ConT ''ToPrompt `AppT` ConT tyName)
                [FunD 'toPrompt [Clause [VarP (mkName "template")] (NormalB (ListE renderedMessages)) []]]
        -- mk<Name> f1 f2 ... = toPrompt (<Name> f1 f2 ...)
        constructorCall = foldl AppE (ConE tyName) (map VarE fields)
        mkDec =
            FunD (mkName ("mk" ++ name))
                [Clause (map VarP fields) (NormalB (VarE 'toPrompt `AppE` constructorCall)) []]
    pure [dataDec, instanceDec, mkDec]