{-# LANGUAGE LambdaCase   #-}
{-# LANGUAGE ViewPatterns #-}

module Language.Haskell.TH.TypeInterpreter.Expression
    ( TypeAtom (..)
    , TypeExp (..)
    , substitute
    , substituteAll
    , reduce
    , match
    , familyExp )
where

import Control.Monad

import qualified Data.Map as Map

import Language.Haskell.TH (Name)

-- | Type atom: a leaf of a 'TypeExp' that has no sub-expressions.
--
-- Two atoms are considered equal by 'match' only if they are equal under the derived 'Eq'.
data TypeAtom
    = Integer Integer
      -- ^ Integer literal, shown with 'Integer''s own 'Show'
    | String String
      -- ^ String literal, shown quoted via 'String''s own 'Show'
    | Name Name
      -- ^ Reference to a named entity, shown via 'Name''s 'Show'
    | PromotedName Name
      -- ^ Promoted name; rendered like 'Name' but with a leading tick (@'@)
    deriving Eq

instance Show TypeAtom where
    -- Every atom is rendered through the 'Show' instance of its payload.
    show atom =
        case atom of
            Integer value     -> show value
            String text       -> show text
            Name name         -> show name
            -- Promoted names are rendered with a leading tick.
            PromotedName name -> '\'' : show name

-- | Type expression
--
-- Expressions are evaluated with 'reduce' and compared against patterns with 'match'.
-- Because 'Family' wraps an opaque Haskell function, values of this type cannot be
-- inspected or compared in full.
data TypeExp
    = Atom TypeAtom
      -- ^ Leaf expression
    | Apply TypeExp TypeExp
      -- ^ Application of the first expression to the second
    | Variable Name
      -- ^ Variable; replaced by 'substitute' and bound by 'match'
    | Synonym Name TypeExp
      -- ^ Single-parameter abstraction, rendered as @λname. body@. When 'reduce' sees it
      -- applied, it substitutes the argument for the parameter in the body.
    | Family (TypeExp -> TypeExp)
      -- ^ Opaque function; when applied, 'reduce' calls it on the (reduced) argument

instance Show TypeExp where
    -- Compound forms are parenthesised at precedence 10 and above; operands of an
    -- application are rendered at precedence 10 so nested compounds get parentheses.
    showsPrec prec expr =
        case expr of
            Atom atom        -> showsPrec prec atom
            Variable var     -> showString (show var)
            Apply fun arg    -> wrapped (showsPrec 10 fun . showChar ' ' . showsPrec 10 arg)
            Synonym var body -> wrapped (showChar 'λ' . shows var . showString ". " . shows body)
            -- Family bodies are functions and cannot be printed.
            Family _         -> wrapped (showString "λ?. ?")
      where
        wrapped = showParen (prec >= 10)

-- | @substitute name typ exp@ replaces all free occurrences of the variable @name@ in
-- @exp@ with @typ@.
--
-- A 'Synonym' whose parameter is @name@ itself shadows the outer variable. Such a synonym
-- is returned unchanged, because occurrences inside its body refer to its own parameter.
-- 'Family' values are opaque functions and are not traversed.
--
-- NOTE(review): this is not capture-avoiding. If @typ@ mentions a variable that is bound
-- by a 'Synonym' inside @exp@, that occurrence gets captured. Callers presumably rely on
-- names being unique; confirm.
substitute :: Name -> TypeExp -> TypeExp -> TypeExp
substitute name typ =
    subst
    where
        subst = \case
            Variable varName
                | varName == name -> typ

            Apply fun param ->
                Apply (subst fun) (subst param)

            syn@(Synonym subName body)
                -- The synonym rebinds @name@: keep both the binder and its body intact.
                | subName == name -> syn
                | otherwise       -> Synonym subName (subst body)

            t -> t

-- | Just like 'substitute', but for several variables at once.
--
-- All bindings are applied simultaneously. A replacement expression taken from the map is
-- never rewritten again by another binding, so the result does not depend on the order of
-- the map's keys. A 'Synonym' that rebinds one of the names shadows that binding inside
-- its body. 'Family' values are opaque functions and are not traversed.
--
-- NOTE(review): like 'substitute', this is not capture-avoiding.
substituteAll :: Map.Map Name TypeExp -> TypeExp -> TypeExp
substituteAll bindings =
    go bindings
    where
        go env = \case
            Variable varName
                | Just typ <- Map.lookup varName env -> typ

            Apply fun param ->
                Apply (go env fun) (go env param)

            Synonym subName body ->
                -- The synonym's parameter shadows any outer binding of the same name.
                Synonym subName (go (Map.delete subName env) body)

            t -> t

-- | Reduce a type expression as far as possible.
--
-- Both sides of an application are reduced first. If the function side then turns out to
-- be a 'Synonym', its argument is substituted into the body. If it is a 'Family', the
-- family function is called on the argument. In both cases the result is reduced again.
-- Synonym bodies are reduced in place, and family results are reduced once the family is
-- applied. Atoms and variables are already in normal form.
reduce :: TypeExp -> TypeExp
reduce expr =
    case expr of
        Apply fun arg ->
            let fun' = reduce fun
                arg' = reduce arg
            in case fun' of
                Synonym var body -> reduce (substitute var arg' body)
                Family impl      -> reduce (impl arg')
                _                -> Apply fun' arg'

        Synonym var body -> Synonym var (reduce body)

        Family impl -> Family (reduce . impl)

        other -> other

-- | @match pattern input@ matches @input@ against the given @pattern@.
--
-- Both expressions are reduced first. On success, the result maps every variable of
-- @pattern@ to the sub-expression of @input@ it matched.
--
-- A variable that occurs more than once in @pattern@ (a non-linear pattern such as
-- @F a a@) only matches if all of its occurrences bind structurally equal expressions.
-- 'Family' values cannot be compared, so a repeated variable bound to a family never
-- matches.
match :: TypeExp -> TypeExp -> Maybe (Map.Map Name TypeExp)
match pattern input =
    match' (reduce pattern, reduce input)
    where
        match' = \case
            (Variable n, v        ) -> Just (Map.singleton n v)
            (Apply f x , Apply g y) -> do
                funBindings <- match' (f, g)
                argBindings <- match' (x, y)
                merge funBindings argBindings
            (Atom l    , Atom r   ) -> Map.empty <$ guard (l == r)
            _                       -> Nothing

        -- Combine the bindings of two sub-matches. A plain left-biased union would
        -- silently drop a conflicting second binding of the same variable.
        merge :: Map.Map Name TypeExp -> Map.Map Name TypeExp -> Maybe (Map.Map Name TypeExp)
        merge l r = Map.union l r <$ guard (and (Map.intersectionWith sameExp l r))

        -- Structural equality. Synonyms are compared by binder name, without
        -- alpha-equivalence; families are opaque and never equal.
        sameExp :: TypeExp -> TypeExp -> Bool
        sameExp a b =
            case (a, b) of
                (Atom x     , Atom y     ) -> x == y
                (Variable x , Variable y ) -> x == y
                (Apply f x  , Apply g y  ) -> sameExp f g && sameExp x y
                (Synonym n x, Synonym m y) -> n == m && sameExp x y
                _                          -> False

-- | @familyExp n impl@ builds a type family expression that takes @n@ arguments, one
-- 'Family' layer per argument, and finally calls @impl@ with all collected arguments in
-- application order. For @n <= 0@, @impl []@ is returned directly.
familyExp :: Int -> ([TypeExp] -> TypeExp) -> TypeExp
familyExp arity impl =
    go arity []
    where
        -- Arguments are accumulated newest-first, hence the final 'reverse'.
        go remaining collected
            | remaining <= 0 = impl (reverse collected)
            | otherwise      = Family (\ arg -> go (remaining - 1) (arg : collected))