{-# LANGUAGE TemplateHaskellQuotes, DerivingVia #-}

-- Helpers for TemplateHaskell instance generators

module AST.TH.Internal.Utils
    ( -- Internals for use in TH for sub-classes
      TypeInfo(..), TypeContents(..), CtrTypePattern(..), NodeWitnesses(..)
    , makeTypeInfo, makeNodeOf
    , parts, toTuple, matchType
    , applicativeStyle, unapply, getVar, makeConstructorVars
    , consPat, simplifyContext
    ) where

import           AST.Class.Nodes
import           AST.Knot (Knot(..), GetKnot, type (#))
import qualified Control.Lens as Lens
import           Control.Lens.Operators
import           Control.Monad.Trans.Class (MonadTrans(..))
import           Control.Monad.Trans.State (StateT(..), evalStateT, execStateT, gets, modify)
import           Data.Foldable (traverse_)
import           Data.List (nub)
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Maybe (fromMaybe)
import           Data.Set (Set)
import qualified Data.Set as Set
import           Generic.Data (Generically(..))
import           GHC.Generics (Generic)
import           Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D

import           Prelude.Compat

-- | Summary of a reified data type, as assembled by 'makeTypeInfo'.
data TypeInfo = TypeInfo
    { tiName :: Name
        -- ^ Name of the data type that was reified.
    , tiInstance :: Type
        -- ^ The type constructor applied to all of its type variables
        -- except the last one (see 'parts').
    , tiVar :: Name
        -- ^ Name of the type's last type variable (its knot variable).
    , tiContents :: TypeContents
        -- ^ The types found by walking the type's constructor fields.
    , tiCons :: [D.ConstructorInfo]
        -- ^ The type's constructors, as reported by 'D.reifyDatatype'.
    } deriving Show

-- | The types discovered while walking the fields of a data type.
--
-- The 'Semigroup' and 'Monoid' instances are derived via 'Generically',
-- so two values are combined field by field.
data TypeContents = TypeContents
    { tcChildren :: Set Type
        -- ^ Types that appear as nodes, i.e. tied with the knot variable.
    , tcEmbeds :: Set Type
        -- ^ Types that are applied to the knot variable but whose own
        -- contents could not be expanded here (type variables, names that
        -- are not data types, nested 'KNodesConstraint's).
    , tcOthers :: Set Type
        -- ^ Constraints from 'KNodesConstraint' instances that fit neither
        -- of the other two categories.
    } deriving (Show, Generic)
    deriving (Semigroup, Monoid) via Generically TypeContents

-- | Reify the named data type and gather everything the instance generators
-- need to know about it.
--
-- The type's contents are computed by running 'childrenTypes' on the type
-- applied to its own knot variable, starting from an empty set of visited
-- types.
makeTypeInfo :: Name -> Q TypeInfo
makeTypeInfo name =
    do
        info <- D.reifyDatatype name
        (inst, knotVar) <- parts info
        let self = inst `AppT` VarT knotVar
        evalStateT (childrenTypes knotVar self) mempty
            <&> \contents ->
                TypeInfo
                { tiName = name
                , tiInstance = inst
                , tiVar = knotVar
                , tiContents = contents
                , tiCons = D.datatypeCons info
                }

-- | Split a reified data type into its instance head and its knot variable.
--
-- The result pairs the type constructor applied to every type variable but
-- the last with the name of that last variable. Fails in 'Q' when the type
-- has no type variables, or when the last one has a kind annotation that is
-- not the 'Knot' constructor.
parts :: D.DatatypeInfo -> Q (Type, Name)
parts info =
    case reverse (D.datatypeVars info) of
    [] -> fail "expected type constructor which requires arguments"
    knotArg : precedingRev ->
        case knotVarName knotArg of
        Just var -> pure (instanceHead (reverse precedingRev), var)
        Nothing -> fail "expected last argument to be a knot variable"
    where
        -- The type constructor applied to the given type variables.
        instanceHead vars =
            foldl AppT (ConT (D.datatypeName info)) (map (VarT . D.tvName) vars)
        -- Accept an unannotated variable, or one explicitly kinded as 'Knot'.
        knotVarName (KindedTV var (ConT knot)) | knot == ''Knot = Just var
        knotVarName (PlainTV var) = Just var
        knotVarName _ = Nothing

-- | Collect the contents of a single (field) type, relative to the knot
-- variable @var@.
--
-- The state is the set of types already visited. A type seen before
-- contributes nothing, which also stops recursive types from looping.
childrenTypes ::
    Name -> Type -> StateT (Set Type) Q TypeContents
childrenTypes var typ =
    do
        visited <- gets (Set.member typ)
        if visited
            then pure mempty
            else
                do
                    modify (Set.insert typ)
                    collect (matchType var typ)
    where
        collect pat =
            case pat of
            NodeFofX ast -> pure mempty { tcChildren = Set.singleton ast }
            XofF ast -> embedded (unapply ast)
            Tof _ inner -> collect inner
            Other{} -> pure mempty
        -- A type applied to the knot variable: expand named types, record
        -- type-variable heads as embeds, and ignore anything else.
        embedded (ConT name, as) = childrenTypesFromTypeName name as
        embedded (x@VarT{}, as) =
            pure mempty { tcEmbeds = Set.singleton (foldl AppT x as) }
        embedded _ = pure mempty

-- | Decompose a type application into its head and its arguments, in
-- left-to-right order.
--
-- Kind signatures met while descending the application spine are dropped.
unapply :: Type -> (Type, [Type])
unapply typ =
    collect typ []
    where
        collect (SigT inner _) args = collect inner args
        collect (AppT fun arg) args = collect fun (arg : args)
        collect hd args = (hd, args)

-- | Classify a field type relative to the knot variable @var@.
--
-- The alternatives are tried in order; an alternative whose guard fails
-- falls through to the next one.
matchType :: Name -> Type -> CtrTypePattern
matchType var typ =
    case typ of
    -- @GetKnot var ('Knot ast)@
    AppT (AppT (ConT runKnot) (VarT k)) (AppT (PromotedT knot) ast)
        | runKnot == ''GetKnot && knot == 'Knot && k == var -> NodeFofX ast
    -- @var # ast@
    AppT (AppT (ConT tie) (VarT k)) ast
        | tie == ''(#) && k == var -> NodeFofX ast
    -- @ast var@
    AppT ast (VarT knot)
        | knot == var && ast /= ConT ''GetKnot -> XofF ast
    AppT outer inner ->
        -- TODO: check if applied over a functor-kinded type.
        case matchType var inner of
        Other{} -> Other typ
        pat -> Tof outer pat
    _ -> Other typ

-- | The shape of a constructor field's type relative to the knot variable,
-- as classified by 'matchType'.
data CtrTypePattern
    = NodeFofX Type
        -- ^ A node: the knot variable tied with the given type
        -- (@k # ast@ or @GetKnot k ('Knot ast)@).
    | XofF Type
        -- ^ The given type applied to the knot variable (@ast k@).
    | Tof Type CtrTypePattern
        -- ^ The given type applied to an argument that itself matches the
        -- nested pattern.
    | Other Type
        -- ^ Any type that matches none of the patterns above.
    deriving Show

-- | The name of a type variable, looking through any kind signatures
-- wrapped around it. 'Nothing' for every other kind of type.
getVar :: Type -> Maybe Name
getVar typ =
    case typ of
    VarT name -> Just name
    SigT inner _ -> getVar inner
    _ -> Nothing

-- | Collect the contents of the type @name@ applied to @args@.
--
-- The applied type is first looked up as a 'KNodesConstraint' instance:
--
-- * With no instance, the name is reified as a data type and each of its
--   constructor fields is traversed with 'childrenTypes', after substituting
--   @args@ for the type's variables. If the name cannot be reified as a data
--   type, the applied type is recorded in 'tcEmbeds'.
--
-- * With exactly one instance of the expected shape, whose type arguments
--   are all plain type variables, the instance's right-hand side is
--   interpreted by 'childrenTypesFromChildrenConstraint'.
--
-- Any other result fails the enclosing 'Q' computation.
childrenTypesFromTypeName ::
    Name -> [Type] -> StateT (Set Type) Q TypeContents
childrenTypesFromTypeName name args =
    reifyInstances ''KNodesConstraint [typ, VarT constraintVar] & lift
    >>=
    \case
    [] ->
        -- 'recover' turns a failed reification into 'Nothing' instead of
        -- aborting the whole splice.
        D.reifyDatatype name
        <&> Just
        & recover (pure Nothing)
        & lift
        >>=
        \case
        Just info ->
            do
                (_, var) <- parts info & lift
                D.datatypeCons info >>= D.constructorFields
                    <&> D.applySubstitution substs
                    & traverse (childrenTypes var)
                    <&> mconcat
            where
                -- Instantiate the type's own variables with the actual
                -- arguments.
                substs =
                    zip (D.datatypeVars info) args
                    <&> Lens._1 %~ D.tvName
                    & Map.fromList
        Nothing ->
            -- Not a datatype, so an embedded type family
            pure mempty { tcEmbeds = Set.singleton typ }
    [TySynInstD ccI (TySynEqn [typI, VarT cI] x)]
        | ccI == ''KNodesConstraint ->
            case unapply typI of
            (ConT n1, argsI) | n1 == name ->
                case traverse getVar argsI of
                Nothing ->
                    -- Failures are reported with 'fail' (as in 'parts') so
                    -- they surface as regular Template Haskell errors.
                    fail ("TODO: Support KNodesConstraint of flexible instances " <> show typ)
                Just argNames ->
                    childrenTypesFromChildrenConstraint cI (D.applySubstitution substs x)
                    where
                        -- Instantiate the instance's variables with the
                        -- actual arguments.
                        substs = zip argNames args & Map.fromList
            _ -> fail ("reifyInstances returned an instance for the wrong type: " <> show (name, typI))
    xs -> fail ("Malformed KNodesConstraint instance: " <> show xs)
    where
        typ = foldl AppT (ConT name) args

-- | The type-variable name used for the constraint parameter: it is passed as
-- the second argument when reifying 'KNodesConstraint' instances, and it
-- replaces an instance's own constraint variable in the constraints stored
-- in 'tcOthers'.
constraintVar :: Name
constraintVar = mkName "constraint"

-- | Interpret the right-hand side of a 'KNodesConstraint' instance.
--
-- @c0@ is the name of the constraint variable bound by the instance; the
-- second argument is the constraint expression to interpret:
--
-- * @c0 x@ records @x@ in 'tcChildren'.
--
-- * @KNodesConstraint x c0@ records @x@ in 'tcEmbeds'.
--
-- * A tuple of constraints is interpreted component by component and the
--   results are combined.
--
-- * Anything else is recorded in 'tcOthers', with @c0@ replaced by
--   'constraintVar'.
--
-- Fails in 'Q' when a type variable other than @c0@ is applied to an
-- argument.
childrenTypesFromChildrenConstraint ::
    Name -> Type -> StateT (Set Type) Q TypeContents
childrenTypesFromChildrenConstraint c0 c@(AppT (VarT c1) x)
    | c0 == c1 = pure mempty { tcChildren = Set.singleton x }
    | otherwise =
        -- Reported with 'fail' (as in 'parts') so that it surfaces as a
        -- regular Template Haskell error.
        fail ("TODO: Unsupported KNodesConstraint " <> show c)
childrenTypesFromChildrenConstraint c0 constraints =
    case unapply constraints of
    (ConT cc1, [x, VarT c1])
        | cc1 == ''KNodesConstraint && c0 == c1 ->
            pure mempty { tcEmbeds = Set.singleton x }
    (TupleT{}, xs) ->
        traverse (childrenTypesFromChildrenConstraint c0) xs <&> mconcat
    _ -> pure mempty { tcOthers = Set.singleton (D.applySubstitution subst constraints) }
    where
        -- Rename the instance's constraint variable to the canonical one.
        subst = mempty & Lens.at c0 ?~ VarT constraintVar

-- | Build the tuple type whose components are the given types, in order.
toTuple :: Foldable t => t Type -> Type
toTuple xs =
    foldl AppT tupleCon xs
    where
        tupleCon = TupleT (length xs)

-- | Build the expression @pure f \<*\> x0 \<*\> x1 ...@ from a function
-- expression and its argument expressions.
applicativeStyle :: Exp -> [Exp] -> Exp
applicativeStyle f args =
    foldl step (VarE 'pure `AppE` f) args
    where
        step acc arg = InfixE (Just acc) (VarE '(<*>)) (Just arg)

-- | Pair each field type of a constructor with a variable name of the form
-- @_\<prefix\>\<index\>@, with indices counting up from 0.
makeConstructorVars :: String -> D.ConstructorInfo -> [(Type, Name)]
makeConstructorVars prefix cons =
    zip (D.constructorFields cons) names
    where
        names = [mkName ('_' : prefix <> show i) | i <- [0 :: Int ..]]

-- | A pattern matching the given constructor, binding each of its fields to
-- the corresponding variable name (the paired types are ignored).
consPat :: D.ConstructorInfo -> [(Type, Name)] -> Pat
consPat cons vars =
    ConP (D.constructorName cons) [VarP var | (_, var) <- vars]

-- | Simplify a context by expanding predicates through their instances.
--
-- A predicate headed by a class name for which 'reifyInstances' returns
-- exactly one instance is replaced by that instance's context (after
-- unifying the instance head with the predicate), recursively. Every other
-- predicate is kept as is.
--
-- The state is a pair: the class applications already looked up, and the
-- predicates kept so far. The result is the latter, as a list.
simplifyContext :: [Pred] -> CxtQ
simplifyContext preds =
    Set.toList . snd
    <$> execStateT (goPreds preds) (mempty :: Set (Name, [Type]), mempty :: Set Pred)
    where
        goPreds ps = traverse_ (go . unapply) ps
        go (c, [VarT v]) =
            -- Work-around reifyInstances returning instances for type variables
            -- by not checking.
            keep c [VarT v]
        go (ConT c, xs) =
            do
                seen <- Lens.use (Lens._1 . Lens.contains key)
                if seen
                    then pure () -- already checked
                    else
                        do
                            Lens._1 . Lens.contains key .= True
                            found <- lift (reifyInstances c xs)
                            case found of
                                [InstanceD _ context other _] ->
                                    do
                                        subst <- lift (D.unifyTypes [foldl AppT (ConT c) xs, other])
                                        goPreds (D.applySubstitution subst context)
                                _ -> keep (ConT c) xs
            where
                key = (c, xs)
        go (c, xs) = keep c xs
        -- Record a predicate in the resulting context.
        keep c xs = Lens._2 . Lens.contains (foldl AppT c xs) .= True

-- | Lookups and constructor names for the witnesses generated by
-- 'makeNodeOf'.
data NodeWitnesses = NodeWitnesses
    { nodeWit :: Type -> Exp
        -- ^ The witness constructor expression for a node type. For a type
        -- with no witness this is a string-literal expression describing
        -- the failed lookup.
    , embedWit :: Type -> Exp
        -- ^ Like 'nodeWit', but for embedded types.
    , nodeWitCtrs :: [Name]
        -- ^ Names of the generated node witness constructors.
    , embedWitCtrs :: [Name]
        -- ^ Names of the generated embed witness constructors.
    }

-- | Build the witness GADT constructors for a type, along with the lookups
-- from node and embed types to the expressions constructing their witnesses.
--
-- Node constructors are named @W_\<type\>_\<node\>@ and embed constructors
-- @E_\<type\>_\<embed\>@; the node constructors come first in the result.
makeNodeOf :: TypeInfo -> ([Con], NodeWitnesses)
makeNodeOf info =
    ( nodeCons <> embedCons
    , NodeWitnesses
        { nodeWit = witnessLookup nodes
        , embedWit = witnessLookup embeds
        , nodeWitCtrs = map snd nodes
        , embedWitCtrs = map snd embeds
        }
    )
    where
        nodeCons = [GadtC [ctr] [] (nodeGadtType t) | (t, ctr) <- nodes]
        embedCons = [GadtC [ctr] [] (embedGadtType t) | (t, ctr) <- embeds]
        witnessLookup entries =
            getWit (Map.fromList [(t, ConE ctr) | (t, ctr) <- entries])
        -- Keep only the text after the last dot of a shown name.
        unqualified str = reverse (takeWhile (/= '.') (reverse str))
        typeBaseName = unqualified (show (tiName info))
        nodePrefix = "W_" <> typeBaseName <> "_"
        embedPrefix = "E_" <> typeBaseName <> "_"
        fieldPats =
            map (matchType (tiVar info)) (concatMap D.constructorFields (tiCons info))
        -- The part of a witness constructor's name derived from its type.
        typeSuffix typ =
            case typ of
            ConT x -> unqualified (show x)
            AppT x y -> typeSuffix x <> "_" <> typeSuffix y
            VarT x -> takeWhile (/= '_') (show x)
            SigT x _ -> typeSuffix x
            x -> error ("TODO: Witness name generator is partial! Need to support " <> show x)
        nodes =
            [ (t, mkName (nodePrefix <> typeSuffix t))
            | t <- nub (concatMap nodeTypes fieldPats)
            ]
        nodeTypes pat =
            case pat of
            NodeFofX t -> [t]
            Tof _ inner -> nodeTypes inner
            _ -> []
        embeds =
            [ (t, mkName (embedPrefix <> typeSuffix t))
            | t <- nub (concatMap embedTypes fieldPats)
            ]
        embedTypes pat =
            case pat of
            XofF t -> [t]
            Tof _ inner -> embedTypes inner
            _ -> []
        -- @KWitness knot node@
        witnessOf knot node = AppT (AppT (ConT ''KWitness) knot) node
        nodeGadtType t = witnessOf (tiInstance info) t
        -- An embed witness wraps a witness of the embedded type.
        embedGadtType t =
            AppT
            (AppT ArrowT (witnessOf t (VarT nodeVar)))
            (witnessOf (tiInstance info) (VarT nodeVar))
        nodeVar = mkName "node"
        getWit :: Map Type Exp -> Type -> Exp
        getWit m k =
            fromMaybe
            (LitE (StringL ("Cant find witness for " <> show k <> " in " <> show m)))
            (Map.lookup k m)