{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Symantic.Compiling.Term where

import Data.Maybe (isJust)
import Data.Semigroup (Semigroup(..))
import qualified Data.Kind as K
import qualified Data.Set as Set
import qualified Data.Text as Text

import Language.Symantic.Grammar
import Language.Symantic.Interpreting
import Language.Symantic.Transforming.Trans
import Language.Symantic.Typing

-- * Type 'Term'
-- | A 'Term' gathers:
-- the 'Type' of its qualification @q@,
-- the 'Type' of its unqualified part @t@,
-- and its symantics ('TeSym') indexed by the qualified type @(q #> t)@,
-- which is also the index of the 'Term' itself.
data Term src ss ts vs (t::K.Type) where
 Term :: Type src vs       q
      -> Type src vs       t
      -> TeSym ss ts       (q #> t)
      -> Term src ss ts vs (q #> t)
instance Source src => Eq (Term src ss ts vs t) where
	-- Two 'Term's are equal when their qualifications are equal
	-- and their unqualified 'Type's are equal; the 'TeSym's are not compared.
	(==) (Term lq lty _) (Term rq rty _) = lq == rq && lty == rty
instance Source src => Show (Term src ss ts vs t) where
	-- Only the qualified 'Type' is shown, the 'TeSym' is not.
	showsPrec prec (Term qu ty _sym) = showsPrec prec (qu #> ty)

-- Source
type instance SourceOf (Term src ss ts vs t) = src
instance Source src => Sourced (Term src ss ts vs t) where
	-- The source is read from and written to the unqualified 'Type' only.
	sourceOf (Term _qu ty _sym) = sourceOf ty
	setSource (Term qu ty sym) newSrc = Term qu (setSource ty newSrc) sym

-- Const
instance ConstsOf (Term src ss ts vs t) where
	-- Union of the constants of the qualification
	-- and of the constants of the unqualified 'Type'.
	constsOf (Term qu ty _sym) = Set.union (constsOf qu) (constsOf ty)

-- Var
type instance VarsOf (Term src ss ts vs t) = vs
instance LenVars (Term src ss ts vs t) where
	-- The length of the variables is read from the unqualified 'Type'.
	lenVars (Term _qu ty _sym) = lenVars ty
instance AllocVars (Term src ss ts) where
	-- Variables are allocated in both the qualification and the unqualified 'Type',
	-- the 'TeSym' is kept as is.
	allocVarsL n (Term qu ty sym) = Term (allocVarsL n qu) (allocVarsL n ty) sym
	allocVarsR n (Term qu ty sym) = Term (allocVarsR n qu) (allocVarsR n ty) sym

-- Fam
instance ExpandFam (Term src ss ts vs t) where
	-- 'expandFam' is applied to both the qualification and the unqualified 'Type'.
	expandFam (Term qu ty sym) = Term (expandFam qu) (expandFam ty) sym

-- Type
instance SourceInj (TermT src ss ts vs) src => TypeOf (Term src ss ts vs) where
	-- The qualified 'Type' of the 'Term' (see 'typeOfTerm'),
	-- given to 'withSource' along with the 'Term' itself wrapped into a 'TermT'.
	typeOf te = withSource (typeOfTerm te) (TermT te)

-- | Return the qualified 'Type' @(q #> t)@ of the given 'Term',
-- built from its qualification and its unqualified 'Type'.
typeOfTerm :: Source src => Term src ss ts vs t -> Type src vs t
typeOfTerm (Term qu ty _sym) = qu #> ty

-- ** Type 'TermT'
-- | A 'Term' whose 'Type' index @t@ is existentially quantified.
data TermT src ss ts vs = forall t. TermT (Term src ss ts vs t)
instance Source src => Show (TermT src ss ts vs) where
	-- Shown as the wrapped 'Term'.
	showsPrec prec (TermT te) = showsPrec prec te

-- ** Type 'TermVT'
-- | A 'Term' whose 'Var's @vs@ and 'Type' index @t@ are existentially quantified.
data TermVT src ss ts = forall vs t. TermVT (Term src ss ts vs t)
instance Source src => Eq (TermVT src ss ts) where
	-- Both 'Term's go through 'appendVars',
	-- then their qualified 'Type's are compared with 'eqTypeKi'.
	(==) (TermVT lhs) (TermVT rhs) =
		case appendVars lhs rhs of
		 (Term lq lty _, Term rq rty _) ->
			isJust (eqTypeKi (lq #> lty) (rq #> rty))
instance Source src => Show (TermVT src ss ts) where
	-- Shown as the wrapped 'Term'.
	showsPrec prec (TermVT te) = showsPrec prec te
type instance SourceOf (TermVT src ss ts) = src
instance Source src => Sourced (TermVT src ss ts) where
	-- Both methods are forwarded to the wrapped 'Term'.
	sourceOf (TermVT te) = sourceOf te
	setSource (TermVT te) newSrc = TermVT (setSource te newSrc)

-- | Lift a 'TermVT' over the empty 'CtxTe' to a 'TermVT' over any 'CtxTe':
-- the 'CtxTe' given is ignored, and 'CtxTeZ' is used instead.
liftTermVT :: TermVT src ss '[] -> TermVT src ss ts
liftTermVT (TermVT (Term qu ty (TeSym sym))) =
	TermVT (Term qu ty (TeSym (\_ctx -> sym CtxTeZ)))

-- ** Type 'TermAVT'
-- | Like 'TermVT', but 'CtxTe'-free:
-- the wrapped 'Term' is polymorphic in its @ts@.
data TermAVT src ss = forall vs t. TermAVT (forall ts. Term src ss ts vs t)
type instance SourceOf (TermAVT src ss) = src
instance Source src => Sourced (TermAVT src ss) where
	-- Both methods are forwarded to the wrapped 'Term'.
	sourceOf (TermAVT te) = sourceOf te
	setSource (TermAVT te) newSrc = TermAVT (setSource te newSrc)
instance Source src => Eq (TermAVT src ss) where
	-- Both 'Term's go through 'appendVars',
	-- then their qualified 'Type's are compared with 'eqTypeKi'.
	(==) (TermAVT lhs) (TermAVT rhs) =
		case appendVars lhs rhs of
		 (Term lq lty _, Term rq rty _) ->
			isJust (eqTypeKi (lq #> lty) (rq #> rty))
instance Source src => Show (TermAVT src ss) where
	-- Shown as the wrapped 'Term'.
	showsPrec prec (TermAVT te) = showsPrec prec te

-- * Type 'TeSym'
-- | Symantic of a 'Term'.
--
-- It wraps a function polymorphic in the interpreter @term@ which,
-- given the 'Syms' of @ss@ on @term@, 'Sym_Lambda' on @term@,
-- the qualification 'QualOf'@ t@, and a 'CtxTe' of @term@s over @ts@,
-- returns a @term@ at the unqualified type 'UnQualOf'@ t@.
newtype TeSym ss ts (t::K.Type)
 = TeSym
 ( forall term.
   Syms ss term =>
   Sym_Lambda term =>
   QualOf t =>
   CtxTe term ts -> term (UnQualOf t)
 )

-- | Build a 'TeSym' from a 'CtxTe'-free @term@
-- which only requires the /symantic/ @s@ (through 'Sym'@ s@):
-- the 'CtxTe' is ignored, and 'symInj' injects @s@ into @ss@.
teSym ::
 forall s ss ts t.
 SymInj ss s =>
 (forall term. Sym s term => Sym_Lambda term => QualOf t => term (UnQualOf t)) ->
 TeSym ss ts t
teSym te = symInj @s (TeSym (\_ctx -> te))

-- ** Type family 'QualOf'
-- | Qualification:
-- the constraint @q@ of a type of the form @(q #> t)@,
-- or the empty constraint for any other type.
type family QualOf (t::K.Type) :: Constraint where
	QualOf (q #> t) = q -- (q # QualOf t)
	QualOf t = (()::Constraint)

-- ** Type family 'UnQualOf'
-- | Unqualification:
-- the type @t@ of a type of the form @(q #> t)@,
-- or the type itself for any other type.
-- Only the outermost @(#>)@ is removed.
type family UnQualOf (t::K.Type) :: K.Type where
	UnQualOf (q #> t) = t -- UnQualOf t
	UnQualOf t = t

-- | Split the given 'Type' into its 'K.Constraint' part and its 'K.Type' part.
--
-- When the given 'Type' is not an application of @(#>)@,
-- the 'K.Constraint' part is built with 'noConstraintLen'
-- and the 'K.Type' part is the given 'Type' itself.
unQualTy ::
 Source src =>
 Type src vs (t::K.Type) ->
 ( TypeK src vs K.Constraint
 , TypeK src vs K.Type )
unQualTy (TyApp _ (TyApp _ op qu) ty)
 | Just HRefl <- proj_ConstKiTy @(K (#>)) @(#>) op
 = (TypeK qu, TypeK ty)
unQualTy ty = (TypeK (noConstraintLen (lenVars ty)), TypeK ty)

-- | Remove the 'K.Constraint's from the given 'Type':
-- every application of @(#>)@ is replaced by its unqualified part,
-- recursively through the type applications.
unQualsTy :: Source src => Type src vs (t::kt) -> TypeK src vs kt
unQualsTy (TyApp _ (TyApp _ op _qu) ty)
 | Just HRefl <- proj_ConstKiTy @(K (#>)) @(#>) op
 = unQualsTy ty
unQualsTy (TyApp src fun arg) =
	-- Any other application: unqualify both sides, then rebuild it.
	case (unQualsTy fun, unQualsTy arg) of
	 (TypeK fun', TypeK arg') -> TypeK (TyApp src fun' arg')
unQualsTy ty = TypeK ty

-- * Type 'CtxTe'
-- | GADT for an /interpreting context/:
-- accumulating at each /lambda abstraction/
-- the @term@ of the introduced variable.
--
-- 'CtxTeZ' is the empty context;
-- 'CtxTeS' adds one @term@ in front of a context,
-- and records its type at the head of the type-level list.
data CtxTe (term::K.Type -> K.Type) (hs::[K.Type]) where
	CtxTeZ :: CtxTe term '[]
	CtxTeS :: term t
	       -> CtxTe term ts
	       -> CtxTe term (t ': ts)
infixr 5 `CtxTeS`

-- ** Type 'TermDef'
-- | Convenient type alias for defining a 'Term'
-- using the /symantic/ @s@:
-- polymorphic in @src@, in @ts@,
-- and in any @ss@ into which @s@ can be injected ('SymInj').
type TermDef s vs t = forall src ss ts. Source src => SymInj ss s => Term src ss ts vs t

-- ** Type family 'Sym'
-- | Open type family mapping a /symantic/ @s@
-- to the constraint it puts on an interpreter @term@.
type family Sym (s::k) :: {-term-}(K.Type -> K.Type) -> Constraint

-- ** Type family 'Syms'
-- | Conjunction of the 'Sym'@ s@ constraints on @term@,
-- for each @('Proxy' s)@ of the list @ss@.
type family Syms (ss::[K.Type]) (term::K.Type -> K.Type) :: Constraint where
	Syms '[] term = ()
	Syms (Proxy s ': ss) term = (Sym s term, Syms ss term)

-- ** Type 'SymInj'
-- | Convenient type synonym wrapping 'SymInjP'
-- applied on the correct 'Index':
-- the one of @('Proxy' s)@ within @ss@.
type SymInj ss s = SymInjP (Index ss (Proxy s)) ss s

-- | Inject the /symantic/ @s@ into the list of /symantics/ @ss@:
-- turn a 'TeSym' only using @s@ into the same 'TeSym' using @ss@,
-- by calling 'symInjP' at the 'Index' of @('Proxy' s)@ within @ss@.
symInj ::
 forall s ss ts t.
 SymInj ss s =>
 TeSym '[Proxy s] ts t ->
 TeSym ss ts t
symInj te = symInjP @(Index ss (Proxy s)) te

-- *** Class 'SymInjP'
-- | Injection of the /symantic/ @s@ into the list of /symantics/ @ss@,
-- by induction on the position @p@ ('Zero' or 'Succ') of @('Proxy' s)@ within @ss@.
class SymInjP p ss s where
	symInjP :: TeSym '[Proxy s] ts t -> TeSym ss ts t
-- Base case: @('Proxy' s)@ is the head of the list,
-- the wrapped function is rewrapped at the larger list.
instance SymInjP Zero (Proxy s ': ss) (s::k) where
	symInjP (TeSym te) = TeSym te
-- Inductive case: the head @('Proxy' not_s)@ is skipped,
-- @s@ is injected into the tail @ss@ at position @p@,
-- then the wrapped function is rewrapped at the whole list.
instance SymInjP p ss s => SymInjP (Succ p) (Proxy not_s ': ss) s where
	-- The pattern signature brings @ts@ and @t@ into scope
	-- for the type annotation of the recursive call.
	symInjP (te::TeSym '[Proxy s] ts t) =
		case symInjP @p te :: TeSym ss ts t of
		 TeSym te' -> TeSym te'

-- * Class 'Sym_Lambda'
-- | Symantics for function application, lambda abstraction,
-- let-binding and qualification.
--
-- The defaults of 'apply', 'app', 'lam' and 'qual'
-- lift the same method of the underlying @('UnT' term)@ through 'Trans';
-- the defaults of 'let_' and 'lam1' are defined
-- with the other methods of @term@ itself.
class Sym_Lambda term where
	-- | /Function application/.
	apply :: term ((a -> b) -> a -> b)
	default apply :: Sym_Lambda (UnT term) => Trans term => term ((a -> b) -> a -> b)
	apply = trans apply
	
	-- | /Lambda application/:
	-- apply a @term@ of a function to a @term@ of its argument.
	app :: term (a -> b) -> (term a -> term b); infixr 0 `app`
	default app :: Sym_Lambda (UnT term) => Trans term => term (arg -> res) -> term arg -> term res
	app = trans2 app
	
	-- | /Lambda abstraction/.
	lam :: (term a -> term b) -> term (a -> b)
	default lam :: Sym_Lambda (UnT term) => Trans term => (term arg -> term res) -> term (arg -> res)
	-- The function is wrapped with 'trans' and 'unTrans'
	-- to work on the underlying @('UnT' term)@.
	lam f = trans $ lam (unTrans . f . trans)
	
	-- | Convenient 'lam' and 'app' wrapper.
	let_ :: term var -> (term var -> term res) -> term res
	let_ x f = lam f `app` x
	
	-- | /Lambda abstraction/ beta-reducible without duplication
	-- (i.e. whose variable is used once at most),
	-- mainly useful in compiled 'Term's
	-- whose symantics are not a single 'term'
	-- but a function between 'term's,
	-- which happens because those are more usable when used as an embedded DSL.
	lam1 :: (term a -> term b) -> term (a -> b)
	default lam1 :: Sym_Lambda (UnT term) => Trans term => (term a -> term b) -> term (a -> b)
	-- NOTE(review): this default is the 'lam' of @term@ itself,
	-- not the 'lam1' of @('UnT' term)@ lifted through 'Trans'.
	lam1 = lam
	
	-- | /Qualification/.
	--
	-- Workaround used in 'readTermWithCtx'.
	qual :: proxy q -> term t -> term (q #> t)
	default qual :: Sym_Lambda (UnT term) => Trans term => proxy q -> term t -> term (q #> t)
	qual q = trans1 (qual q)

-- | Like 'lam1', but abstracting two arguments.
lam2 :: Sym_Lambda term => (term a -> term b -> term c) -> term (a -> b -> c)
lam2 f = lam1 (\a -> lam1 (f a))

-- | Like 'lam1', but abstracting three arguments.
lam3 :: Sym_Lambda term => (term a -> term b -> term c -> term d) -> term (a -> b -> c -> d)
lam3 f = lam1 (\a -> lam2 (f a))

-- | Like 'lam1', but abstracting four arguments.
lam4 :: Sym_Lambda term => (term a -> term b -> term c -> term d -> term e) -> term (a -> b -> c -> d -> e)
lam4 f = lam1 (\a -> lam3 (f a))

-- Interpreting
instance Sym_Lambda Eval where
	apply = Eval ($)
	app fun arg = fun <*> arg
	-- The host function is unwrapped from and rewrapped into 'Eval'.
	lam f = Eval (\x -> unEval (f (Eval x)))
	lam1 = lam
	-- NOTE: like flip ($), no 'lam' nor 'app' is involved.
	let_ x f = f x
	qual _q (Eval t) = Eval (Qual t)
instance Sym_Lambda View where
	apply = View (\_po _v -> "($)")
	-- Application is viewed as juxtaposition, at precedence 10.
	app (View fun) (View arg) = View render
		where
		op = infixN 10
		render po v =
			parenInfix po op $
			fun (op, SideL) v <> " " <> arg (op, SideR) v
	-- The bound variable is named after the current depth @v@,
	-- and the body is viewed at the next depth.
	lam f = View $ \po v ->
		let var = "x" <> Text.pack (show v) in
		let body = f (View (\_po _v -> var)) in
		parenInfix po op $
		"\\" <> var <> " -> " <> unView body (op, SideL) (succ v)
		where op = infixN 1
	lam1 = lam
	qual _q (View t) = View t -- TODO: maybe print q
	-- Both the bound expression and the body are viewed at the next depth.
	let_ x f = View $ \po v ->
		let var = "x" <> Text.pack (show v) in
		let bound = unView x (infixN 0, SideL) (succ v) in
		let body = unView (f (View (\_po _v -> var))) (op, SideL) (succ v) in
		parenInfix po op $
		"let" <> " " <> var <> " = " <> bound <> " in " <> body
		where op = infixN 1
instance (Sym_Lambda r1, Sym_Lambda r2) => Sym_Lambda (Dup r1 r2) where
	apply = dup0 @Sym_Lambda apply
	app   = dup2 @Sym_Lambda app
	-- 'lam' cannot be defined from @lam f@ at type @('Dup' r1 r2)@:
	-- that is this very definition applied to the same argument,
	-- which never produces any component.
	-- Instead, each component is abstracted with the 'lam' of its own interpreter:
	-- the bound variable is given to @f@ within a 'Dup'
	-- whose other component is a placeholder.
	--
	-- NOTE(review): this assumes that 'Dup' is lazy in its components,
	-- and that @f@ computes each component of its result
	-- only from the same component of its argument
	-- (as a function built from symantic methods would) — to be confirmed.
	lam f = lam f1 `Dup` lam f2
		where
		f1 x1 = dup_1 (f (x1 `Dup` error "Sym_Lambda (Dup r1 r2): lam: second component needed by the first one"))
		f2 x2 = dup_2 (f (error "Sym_Lambda (Dup r1 r2): lam: first component needed by the second one" `Dup` x2))
	lam1 = lam
	qual q = dup1 @Sym_Lambda (qual q)