-- | Creation of declarations from a 'FoldFamily'
{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Origami.Internal.TH(mkFoldDecs,
    ctorNamesAreUnique,
    duplicateCtorNames,
    typeNamesAreUnique) where

import Control.Lens hiding (Fold)
import Data.Bitraversable(Bitraversable(..))
import Data.Char(toLower)
import Data.Foldable(Foldable(foldMap))
import Data.List(sort, sortBy)
import qualified Data.Map as M
import Data.Ord(comparing)
import Data.Origami.Internal.Fold(Fold(..), errFold, foldFoldFamily)
import Data.Origami.Internal.FoldFamily
import Data.Origami.Internal.THUtils
import Data.Origami.Internal.Trifunctor(Trifunctor(..), Tritraversable(..))
import Data.Sequence.Lens(seqOf)
import qualified Data.Set as S
import Data.Set.Lens(setOf)
import Data.Traversable(sequence)
import Language.Haskell.TH.Syntax
import Prelude hiding (sequence)

typeNames :: Traversal' FoldFamily Name
typeNames = dataTys . traverse . name

dataCases' :: Traversal' FoldFamily DataCase
dataCases' = dataTys . traverse . dataCases . traverse

-- | A 'Traversal' over the constructor 'Name's of the 'FoldFamily'
ctorNames :: Traversal' FoldFamily Name
ctorNames = dataCases' . name

-- | Returns @True@ iff the 'Name's of the datatypes in the
-- 'FoldFamily' are all unique.
typeNamesAreUnique :: FoldFamily -> Bool
typeNamesAreUnique = areUnique . seqOf typeNames

-- | Returns the set of constructor 'Name's in the 'FoldFamily' that
-- are repeated.
duplicateCtorNames :: FoldFamily -> S.Set String
duplicateCtorNames = duplicates . fmap upperName . seqOf ctorNames

-- | Returns @True@ iff the constructor 'Name's of the 'FoldFamily'
-- are all unique.
ctorNamesAreUnique :: FoldFamily -> Bool
ctorNamesAreUnique = areUnique . seqOf ctorNames

duplicates :: (Ord a, Foldable f) => f a -> S.Set a
duplicates = M.keysSet
                 . M.filter (> (1 :: Int))
                 . M.fromListWith (+)
                 . foldMap (\x -> [(x, 1)])

areUnique :: (Ord a, Foldable f) => f a -> Bool
areUnique = S.null . duplicates

lowerTHName :: Name -> Name
lowerTHName = mkName . dodgeNameClash . lowerName
    where
    lowerName :: Name -> String
    lowerName nm = toLower c : cs
        where
        (c : cs) = nameBase nm

dodgeNameClash :: String -> String
dodgeNameClash str = if str `elem` keywords
                         then str'
                         else str
    where
    str' = str ++ "'"

    -- | Alphanumeric keywords in Haskell
    keywords :: [String]
    keywords = ["as", "case", "class", "data", "default", "deriving",
                "do", "else", "family", "forall", "foreign hiding",
                "if", "import", "in", "infix", "infixl", "infixr",
                "instance", "let", "mdo", "module newtype", "of",
                "proc", "qualified", "rec", "then", "type", "where"]

upperTHName :: Name -> Name
upperTHName = mkName . upperName

thMkName :: Name -> Name
thMkName = mkName . thMkString

thMkString :: Name -> String
thMkString nm = "mk" ++ upperName nm

thFoldName :: Name -> Name
thFoldName = mkName . thFoldString

thFoldString :: Name -> String
thFoldString nm = "fold" ++ upperName nm

foldName :: Name
foldName = mkName "Fold"

typeNameList :: FoldFamily -> [Name]
typeNameList = S.toList . setOf typeNames

{-
ctorNameList :: FoldFamily -> [Name]
ctorNameList = S.toList . setOf ctorNames
-}

-- | Creates declarations for the
-- * @Fold@,
-- * @idFold@,
-- * @errFold@,
-- * @monadicFold@,
-- * and one @foldXxx@ function for each datatype @Xxx@ in the
-- 'FoldFamily'.
mkFoldDecs :: FoldFamily -> [Dec]
mkFoldDecs ff = mkFoldDec ff : mkIdFoldDecs ff ++ mkErrFoldDecs ff
                                               ++ mkMonadicFoldDecs ff
                                               ++ mkFoldFuncDecs ff

-- | Creates a declaration for the @Fold@.
mkFoldDec :: FoldFamily -> Dec
mkFoldDec ff = foldFoldFamily fold' ff
    where
    fold' :: Fold (Name, [Type]) Type [(Name, VarStrictType)] Dec Name
    fold' = Fold {
                mkFoldFamily = mkFoldFamily',
                mkDataTy = mkDataTy',
                mkDataCase = (,),
                mkTy = id,
                mkAtomic = ConT . upperTHName,
                mkNonatomic = VarT . lowerTHName,
                mkFunct = mkFunct',
                mkBifunct = mkBifunct',
                mkTrifunct = mkTrifunct'
            }

    mkFoldFamily' :: [[(Name, VarStrictType)]] -> Dec
    mkFoldFamily' dts = DataD [] foldName tvbs [con] []
        where
        tvbs :: [TyVarBndr]
        tvbs = map (PlainTV . lowerTHName) $ typeNameList ff

        con :: Con
        con = RecC foldName vsts

        vsts :: [VarStrictType]
        vsts = map snd $ sortBy (comparing fst) $ concat dts

    mkDataTy' :: Name -> [(Name, [Type])] -> [(Name, VarStrictType)]
    mkDataTy' ty dcs
        = [(ctor, (ctorNm, NotStrict, fldTy))
               | (ctor, fldTys) <- dcs,
                 let ctorNm = thMkName ctor,
                 let fldTy = funcTs (fldTys ++ [resTy])]
        where
        resTy = VarT $ lowerTHName ty

    mkFunct' :: Name -> Type -> Type
    mkFunct' nm
        | nm == ''[]            = AppT ListT
        | otherwise             = AppT (ConT nm)

    mkBifunct' :: Name -> Type -> Type -> Type
    mkBifunct' nm lhs rhs
        | nm == ''(,)     = appTs [TupleT 2, lhs, rhs]
        | otherwise       = appTs [ConT nm, lhs, rhs]

    mkTrifunct' :: Name -> Type -> Type -> Type -> Type
    mkTrifunct' nm l' m' r'
        | nm == ''(,,)   = appTs [TupleT 3, l', m', r']
        | otherwise      = appTs [ConT nm, l', m', r']

-- | Creates a declaration for the @idFold@.
mkIdFoldDecs :: FoldFamily -> [Dec]
mkIdFoldDecs ff = foldFoldFamily fold' ff
    where
    fold' :: Fold Name dataField [Name] [Dec] ty
    fold' = (errFold "mkIdFoldDecs.fold'"){
                mkFoldFamily = mkFoldFamily',
                mkDataTy = mkDataTy',
                mkDataCase = const
            }

    mkFoldFamily' :: [[Name]] -> [Dec]
    mkFoldFamily' dcs = [SigD nm ty, ValD pat bd []]
        where
        nm :: Name
        nm = mkName "idFold"

        ty :: Type
        ty = appTs $ map ConT $ foldName : map upperTHName (typeNameList ff)

        pat :: Pat
        pat = VarP nm

        bd :: Body
        bd = NormalB
                 $ RecConE foldName
                       [(thMkName ws, ConE $ upperTHName ws) | ws <- ctors ]

        ctors :: [Name]
        ctors = sort $ concat dcs

    mkDataTy' :: Name -> [Name] -> [Name]
    mkDataTy' _ = id

foldTy' :: FoldFamily -> Type
foldTy' ff = appTs $ ConT foldName : map (VarT. lowerTHName) tyNms
    where
    tyNms = typeNameList ff

-- | Creates a declaration for the @errFold@.
mkErrFoldDecs :: FoldFamily -> [Dec]
mkErrFoldDecs ff = foldFoldFamily fold' ff
    where
    fold' :: Fold (Name, FieldExp) dataField [(Name, FieldExp)] [Dec] ty
    fold' = (errFold "errFoldDecs"){
                mkFoldFamily = mkFoldFamily',
                mkDataTy = flip const,
                mkDataCase = mkDataCase'
            }

    mkFoldFamily' :: [[(Name, FieldExp)]] -> [Dec]
    mkFoldFamily' dts = [SigD nm ty, FunD nm [cl]]
        where
        cl :: Clause
        cl = Clause [VarP foldTagNm] bd [errDef]

        bd :: Body
        bd = mkSortedRecBody foldName dts

    mkDataCase' :: Name -> [dataField] -> (Name, FieldExp)
    mkDataCase' ctor _ = (ctor, (thMkName ctor, errExp))
        where
        errExp :: Exp
        errExp = AppE (VarE errNm) (LitE $ StringL $ thMkString ctor)

    foldTagNm :: Name
    foldTagNm = mkName "foldTag'"

    nm :: Name
    nm = mkName "errFold"

    ty :: Type
    ty = funcT (ConT ''String) (ForallT tvbs [] $ foldTy' ff)
        where
        tvbs :: [TyVarBndr]
        tvbs = map (PlainTV. lowerTHName) $ typeNameList ff

    errNm :: Name
    errNm = mkName "err"

    errDef :: Dec
    errDef = FunD errNm [cl']
        where
        fieldTagNm :: Name
        fieldTagNm = mkName "fieldTag"

        cl' :: Clause
        cl' = Clause [VarP fieldTagNm] bd' []

        bd' :: Body
        bd' = NormalB $ AppE (VarE $ mkName "error")
                             (ParensE $ AppE (VarE $ mkName "concat") $
                                              ListE [VarE foldTagNm,
                                                     LitE $ StringL ".",
                                                     VarE fieldTagNm])

-- | Creates a declaration for the @monadicFold@.
mkMonadicFoldDecs :: FoldFamily -> [Dec]
mkMonadicFoldDecs ff = foldFoldFamily fold' ff
    where
    fold' :: Fold (Name, FieldExp) Exp [(Name, FieldExp)] [Dec] ty
    fold' = Fold {
                mkFoldFamily = mkFoldFamily',
                mkDataTy = const id,
                mkDataCase = mkDataCase',
                mkAtomic = mkAtomic',
                mkNonatomic = mkNonatomic',
                mkFunct = mkFunct',
                mkBifunct = mkBifunct',
                mkTrifunct = mkTrifunct',
                mkTy = error "mkMonadicFoldDecs.mkTy"
            }

    nm :: Name
    nm = mkName "monadicFold"

    m :: Name -> Type
    m nm' = AppT (VarT mNm) (VarT nm')

    mNm :: Name
    mNm = mkName "m"

    monadicFoldTy :: Type
    monadicFoldTy = appTs (ConT foldName
                              : map (m . lowerTHName) (typeNameList ff))

    baseFoldName :: Name
    baseFoldName = mkName "baseFold"

    mkFoldFamily' :: [[(Name, FieldExp)]] -> [Dec]
    mkFoldFamily' dcs = [SigD nm ty',
                         FunD nm [cl]]
        where
        ty' :: Type
        ty' = ForallT tvbs [] ty

        tvbs :: [TyVarBndr]
        tvbs = map (PlainTV . lowerTHName) $ typeNameList ff

        ty :: Type
#if MIN_VERSION_template_haskell(2,10,0)
        ty = funcT baseFoldTy
                   (ForallT [PlainTV mNm]
                            [AppT (ConT ''Monad) (VarT mNm)]
                            monadicFoldTy)
#else
        ty = funcT baseFoldTy
                   (ForallT [PlainTV mNm]
                            [ClassP ''Monad [VarT mNm]]
                            monadicFoldTy)
#endif

        baseFoldTy :: Type
        baseFoldTy = foldTy' ff

        cl :: Clause
        cl = Clause [VarP baseFoldName] bd []

        bd :: Body
        bd = mkSortedRecBody foldName dcs

    mkDataCase' :: Name -> [Exp] -> (Name, FieldExp)
    mkDataCase' ctor dfs = (ctor, (thMkName ctor, LamE pats doE))
        where
        pats :: [Pat]
        pats = zipWith const varPs dfs

        doE :: Exp
        doE = DoE (bindStmts ++ [NoBindS resE])

        bindStmts :: [Stmt]
        bindStmts = [BindS p (AppE df e) | (p, df, e) <- zip3 varPs' dfs varEs]

        resE :: Exp
        resE = AppE (VarE 'return)
                    (appEs $ VarE (thMkName ctor) : VarE baseFoldName : exps')
            where
            exps' :: [Exp]
            exps' = zipWith const varEs' dfs

    mkAtomic' :: ty -> Exp
    mkAtomic' _ = VarE 'return

    mkNonatomic' :: ty -> Exp
    mkNonatomic' _  = VarE 'id

    mkFunct' :: Name -> Exp -> Exp
    mkFunct' _ f = ParensE $ comp (VarE 'sequence)
                                  (appEs [VarE 'fmap, f])

    mkBifunct' :: Name -> Exp -> Exp -> Exp
    mkBifunct' _ l r = ParensE $ comp (VarE 'bisequence)
                                      (appEs [VarE 'bimap, l, r])

    mkTrifunct' :: Name -> Exp -> Exp -> Exp -> Exp
    mkTrifunct' _ l' m' r' = ParensE $ comp (VarE 'trisequence)
                                            (appEs [VarE 'trimap, l', m', r'])

    -- | Composes two 'Exp's.
    comp :: Exp -> Exp -> Exp
    comp lhs rhs = InfixE (Just lhs) composeE (Just rhs)
        where
        composeE :: Exp
        composeE = VarE $ mkName "."

-- | Creates a fold function @foldXxx@ for each datatype @Xxx@ in the
-- 'FoldFamily'.
mkFoldFuncDecs :: FoldFamily -> [Dec]
mkFoldFuncDecs ff = foldFoldFamily fold' ff
    where
    fold' :: Fold Clause Exp [Dec] [Dec] Name
    fold' = Fold {
                mkFoldFamily = concat,
                mkDataTy = mkDataTy',
                mkDataCase = mkDataCase',
                mkAtomic = const (VarE 'id),
                mkNonatomic = mkNonatomic',
                mkFunct = mkFunct',
                mkBifunct = mkBifunct',
                mkTrifunct = mkTrifunct',
                mkTy = id
            }

    fName :: Name
    fName = mkName "f"

    fExp :: Exp
    fExp = VarE fName

    fPat :: Pat
    fPat = VarP fName

    mkDataTy' :: Name -> [Clause] -> [Dec]
    mkDataTy' nm dcs = [SigD foldNm ty, FunD foldNm dcs]
        where
        foldNm :: Name
        foldNm = thFoldName nm

        ty :: Type
        ty = ForallT tvbs [] ty'
            where
            tvbs :: [TyVarBndr]
            tvbs = map (PlainTV. lowerTHName) tyNms

            tyNms :: [Name]
            tyNms = typeNameList ff

            ty' :: Type
            ty' = funcTs [foldTy' ff,
                          ConT $ upperTHName nm,
                          VarT $ lowerTHName nm]

    mkDataCase' :: Name -> [Exp] -> Clause
    mkDataCase' ctor dfs = Clause [fPat, argPat] bd []
        where
        argPat :: Pat
        argPat = ConP (upperTHName ctor) $ zipWith const varPs dfs

        bd :: Body
        bd = NormalB $ appEs $ (VarE $ thMkName ctor)
                               : fExp
                               : zipWith AppE dfs varEs

    mkNonatomic' :: Name -> Exp
    mkNonatomic' ws = AppE (VarE $ thFoldName ws) fExp

    mkFunct' :: Name -> Exp -> Exp
    mkFunct' _ = AppE (VarE 'fmap)

    mkBifunct' :: Name -> Exp -> Exp -> Exp
    mkBifunct' _ e e' = appEs [VarE 'bimap, e, e']

    mkTrifunct' :: Name -> Exp -> Exp -> Exp -> Exp
    mkTrifunct' _ e e' e'' = appEs [VarE 'trimap, e, e', e'']

mkSortedRecBody :: Name -> [[(Name, FieldExp)]] -> Body
mkSortedRecBody nm taggedFieldExps = NormalB $ RecConE nm fieldExps
    where
    fieldExps :: [FieldExp]
    fieldExps = map snd $ sortBy (comparing fst) $ concat taggedFieldExps