{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TemplateHaskell #-}

module TypeChain.ChatModels.PromptTemplate (makeTemplate, user, assistant, system) where

import Data.List (nub)

import Language.Haskell.TH

import TypeChain.ChatModels.Types

-- | One lexical piece of a prompt template: either literal text that is
-- copied verbatim, or the name written inside a @{...}@ placeholder.
data TemplateToken
    = ConstString String  -- ^ literal text between placeholders
    | Var String          -- ^ placeholder name, without the surrounding braces
    deriving (Eq)

-- | A parsed message template: the expression that builds the message,
-- paired with the placeholder names that expression refers to.
type PromptTemplate =
    ( Q Exp   -- expression constructing the message
    , [Name]  -- variables the expression expects to be bound at the splice site
    )

-- | Template for a user message. The rendered template string is passed to
-- 'UserMessage'; each @{name}@ placeholder becomes a parameter of the lambda
-- that 'makeTemplate' builds.
user :: String -> Q PromptTemplate
user = flip toTemplate [| UserMessage |]

-- | Template for an assistant message. The rendered template string is
-- passed to 'AssistantMessage'; each @{name}@ placeholder becomes a parameter
-- of the lambda that 'makeTemplate' builds.
assistant :: String -> Q PromptTemplate
assistant = flip toTemplate [| AssistantMessage |]

-- | Template for a system message. The rendered template string is passed
-- to 'SystemMessage'; each @{name}@ placeholder becomes a parameter of the
-- lambda that 'makeTemplate' builds.
system :: String -> Q PromptTemplate
system = flip toTemplate [| SystemMessage |]

toTemplate :: String -> Q Exp -> Q PromptTemplate
toTemplate :: String -> Q Exp -> Q PromptTemplate
toTemplate String
xs Q Exp
f = do 
    let tokens :: [TemplateToken]
tokens = String -> [TemplateToken]
parseTemplateTokens String
xs
        names :: [Name]
names  = (String -> Name) -> [String] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map String -> Name
mkName ([String] -> [Name]) -> [String] -> [Name]
forall a b. (a -> b) -> a -> b
$ [String] -> [String]
forall a. Eq a => [a] -> [a]
nub ([String] -> [String]) -> [String] -> [String]
forall a b. (a -> b) -> a -> b
$ [TemplateToken] -> [String]
getVarTokens [TemplateToken]
tokens
        params :: [Q Pat]
params = (Name -> Q Pat) -> [Name] -> [Q Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Q Pat
forall (m :: * -> *). Quote m => Name -> m Pat
varP [Name]
names
        func :: Q Exp
func   = [Q Pat] -> Q Exp -> Q Exp
forall (m :: * -> *). Quote m => [m Pat] -> m Exp -> m Exp
lamE [Q Pat]
params (Q Exp -> Q Exp -> Q Exp
forall (m :: * -> *). Quote m => m Exp -> m Exp -> m Exp
appE Q Exp
f (Q Exp -> Q Exp) -> Q Exp -> Q Exp
forall a b. (a -> b) -> a -> b
$ [TemplateToken] -> Q Exp
tokensToExpr [TemplateToken]
tokens)
        expr :: Q Exp
expr   = (Q Exp -> Q Exp -> Q Exp) -> Q Exp -> [Q Exp] -> Q Exp
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Q Exp -> Q Exp -> Q Exp
forall (m :: * -> *). Quote m => m Exp -> m Exp -> m Exp
appE Q Exp
func ((Name -> Q Exp) -> [Name] -> [Q Exp]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Q Exp
forall (m :: * -> *). Quote m => Name -> m Exp
varE [Name]
names)

    PromptTemplate -> Q PromptTemplate
forall a. a -> Q a
forall (m :: * -> *) a. Monad m => a -> m a
return (Q Exp
expr, [Name]
names)

-- | Combine message templates into a single lambda expression.
--
-- The lambda takes one argument per distinct placeholder name across all
-- templates, in order of first appearance. Its body is the list of
-- messages, in the same order as the input templates.
makeTemplate :: [Q PromptTemplate] -> Q Exp
makeTemplate templates = do
    built <- sequence templates
    let messages = map fst built
        params   = nub (concatMap snd built)
    lamE (varP <$> params) (listE messages)

-- | Split a template string into literal text and @{name}@ placeholders.
--
-- Text outside braces becomes 'ConstString'. The characters between a @{@
-- and the next @}@ become a 'Var'. A stray @}@ with no preceding @{@ is kept
-- as literal text.
--
-- An opening @{@ with no matching @}@ is rejected with a descriptive error
-- naming the offending placeholder. Previously the partial 'tail' failed
-- with a generic "empty list" message while the splice was being built.
parseTemplateTokens :: String -> [TemplateToken]
parseTemplateTokens [] = []
parseTemplateTokens ('{' : xs) =
    case break (== '}') xs of
        (name, _ : rest) -> Var name : parseTemplateTokens rest
        (name, [])       -> error $
            "TypeChain.ChatModels.PromptTemplate: unterminated placeholder '{"
                ++ name ++ "' in prompt template (missing '}')"
parseTemplateTokens (x : xs) = ConstString (x : literal) : parseTemplateTokens rest
  where
    -- x is never '{' here (the previous equation handles that), so the
    -- literal run is x plus everything up to the next opening brace.
    (literal, rest) = break (== '{') xs

-- | Placeholder names appearing in a token list, in order, duplicates kept.
getVarTokens :: [TemplateToken] -> [String]
getVarTokens tokens = [name | Var name <- tokens]

-- | Build an expression that concatenates the template pieces into one
-- 'String' with right-nested @(++)@.
--
-- Literal pieces are lifted into the quotation. Each placeholder becomes a
-- reference to a variable of that name (via 'mkName'), and the empty token
-- list yields @""@.
tokensToExpr :: [TemplateToken] -> Q Exp
tokensToExpr = foldr step [| "" |]
  where
    step (ConstString lit) acc = [| lit ++ $acc |]
    step (Var name)        acc = [| (++) |] `appE` varE (mkName name) `appE` acc