-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ

{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE QuantifiedConstraints #-}

-- | General type utilities.
module Util.Type
  ( type (==)
  , If
  , type (++)
  , IsElem
  , type (/)
  , type (//)
  , Guard
  , FailWhen
  , FailUnless
  , failUnlessEvi
  , failWhenEvi
  , AllUnique
  , RequireAllUnique
  , ReifyList (..)
  , PatternMatch
  , PatternMatchL
  , KnownList (..)
  , KList (..)
  , RSplit
  , rsplit
  , Some1 (..)
  , recordToSomeList

  , reifyTypeEquality

  , ConcatListOfTypesAssociativity
  , listOfTypesConcatAssociativityAxiom
  ) where

import Data.Constraint ((:-)(..), Dict(..))
import Data.Vinyl.Core (Rec (..))
import qualified Data.Vinyl.Functor as Vinyl
import Data.Vinyl.Recursive (recordToList, rmap)
import Data.Vinyl.TypeLevel (type (++))
import qualified Data.Kind as Kind
import Data.Type.Bool (type (&&), If, Not)
import Data.Type.Equality (type (==))
import GHC.TypeLits (ErrorMessage(..), Symbol, TypeError)
import Unsafe.Coerce (unsafeCoerce)

-- | Type-level membership test: reduces to @'True@ when @x@ occurs in the
-- list and to @'False@ once the list is exhausted without a match.
type family IsElem (x :: k) (xs :: [k]) :: Bool where
  IsElem _ '[] = 'False
  IsElem x (x ': _) = 'True
  IsElem x (_ ': rest) = IsElem x rest

-- | Remove all occurrences of the given element from a type-level list,
-- keeping the remaining elements in their original order.
type family (xs :: [k]) / (x :: k) where
  '[] / _ = '[]
  (x ': rest) / x = rest / x
  (y ': rest) / x = y ': (rest / x)

-- | Difference between two type-level lists: every element of the second
-- list is removed (in all its occurrences) from the first one.
type family (xs :: [k]) // (ys :: [k]) :: [k] where
  xs // '[] = xs
  xs // (y ': ys) = (xs / y) // ys

-- | Type-level guard: wraps the given type in @'Just@ when the condition is
-- @'True@ and yields @'Nothing@ when it is @'False@.
type family Guard (cond :: Bool) (a :: k) :: Maybe k where
  Guard 'True x = 'Just x
  Guard 'False _ = 'Nothing

-- | Constraint that is trivially satisfied when the condition is @'True@ and
-- raises the supplied custom type error when the condition is @'False@.
type family FailUnless (cond :: Bool) (msg :: ErrorMessage) :: Constraint where
  FailUnless 'True _ = ()
  FailUnless 'False err = TypeError err

-- | Dual of 'FailUnless': raise the supplied type error when the condition
-- is @'True@. Implemented by negating the condition.
type FailWhen (cond :: Bool) (msg :: ErrorMessage) = FailUnless (Not cond) msg

-- | A natural conclusion from the fact that the error has not occurred:
-- if @FailUnless cond msg@ holds, then @cond@ must be @'True@.
--
-- GHC does not derive this equality from the type family on its own, so the
-- evidence is fabricated: a dictionary for the trivially true
-- @'True ~ 'True@ is recast with 'unsafeCoerce' as one for @cond ~ 'True@.
-- This is meant to be sound because the only equation of 'FailUnless' that
-- does not produce a type error is the one for @'True@.
failUnlessEvi :: forall cond msg. FailUnless cond msg :- (cond ~ 'True)
failUnlessEvi = Sub $ unsafeCoerce $ Dict @('True ~ 'True)

-- | Counterpart of 'failUnlessEvi' for 'FailWhen': if @FailWhen cond msg@
-- holds, then @cond@ must be @'False@.
--
-- As GHC does not derive this equality on its own, a dictionary for the
-- trivially true @'False ~ 'False@ is recast with 'unsafeCoerce' as one for
-- @cond ~ 'False@. This is meant to be sound because 'FailWhen' only avoids
-- the type error when the negated condition is @'True@.
failWhenEvi :: forall cond msg. FailWhen cond msg :- (cond ~ 'False)
failWhenEvi = Sub $ unsafeCoerce $ Dict @('False ~ 'False)

-- | Whether a type-level list contains no repeated elements: each element
-- must be absent from the remainder of the list.
type family AllUnique (xs :: [k]) :: Bool where
  AllUnique '[] = 'True
  AllUnique (x ': rest) = Not (IsElem x rest) && AllUnique rest

-- | Require all elements of a type-level list to be distinct.
-- On the first duplicate found, a custom type error is raised that mentions
-- the given description, the offending element and the full list.
type RequireAllUnique desc l = RequireAllUnique' desc l l

-- | Worker for 'RequireAllUnique': walks the list while carrying the
-- original list along, so that it can be shown in the error message.
type family RequireAllUnique' (desc :: Symbol) (remaining :: [k]) (full :: [k]) :: Constraint where
  RequireAllUnique' _ '[] _ = ()
  RequireAllUnique' desc (x ': rest) full =
    If (IsElem x rest)
      ( TypeError
          ( 'Text "Duplicated " ':<>: 'Text desc ':<>: 'Text ":"
              ':$$: 'ShowType x
              ':$$: 'Text "Full list: " ':<>: 'ShowType full
          )
      )
      (RequireAllUnique' desc rest full)

-- | Make sure the given type is evaluated.
--
-- Both equations produce a trivially satisfied constraint, so this family
-- never rejects anything; the two right-hand sides differ only syntactically
-- (@((), ())@ versus @()@).
-- NOTE(review): presumably the point is that selecting an equation requires
-- the argument to be reduced far enough to compare it against 'Int' — confirm.
--
-- This type family fits only types of kind 'Kind.Type'.
type family PatternMatch (a :: Kind.Type) :: Constraint where
  PatternMatch Int = ((), ())
  PatternMatch _ = ()

-- | Variant of 'PatternMatch' for type-level lists of any element kind.
--
-- Both equations produce a trivially satisfied constraint; the first one
-- matches the empty list and the second one matches anything else.
-- NOTE(review): presumably used to force the list argument to be reduced far
-- enough to tell whether it is empty — confirm.
type family PatternMatchL (l :: [k]) :: Constraint where
  PatternMatchL '[] = ((), ())
  PatternMatchL _ = ()

-- | Bring a type-level list to term level, using the given function
-- to demote its individual elements.
--
-- The list @l@ does not occur in the method's type, so it has to be selected
-- with a type application, as the recursive call in the cons instance does.
class ReifyList (c :: k -> Constraint) (l :: [k]) where
  -- | Apply the demoting function to every element of @l@, in list order.
  -- The function may rely on the constraint @c@ holding for each element.
  reifyList :: (forall a. c a => Proxy a -> r) -> [r]

-- | The empty type-level list demotes to the empty list.
instance ReifyList c '[] where
  reifyList _ = []

-- | Demote the head element, then recurse on the tail.
instance (c x, ReifyList c xs) => ReifyList c (x ': xs) where
  reifyList reifyElem = reifyElem (Proxy @x) : reifyList @_ @c @xs reifyElem

-- | Reify type equality from boolean equality.
--
-- Given that @a == b@ reduces to @'True@, run a computation that may
-- assume @a ~ b@.
reifyTypeEquality :: forall a b x. (a == b) ~ 'True => (a ~ b => x) -> x
reifyTypeEquality x =
  -- GHC does not derive @a ~ b@ from @(a == b) ~ 'True@ on its own, so the
  -- evidence for the trivially true @a ~ a@ is recast as evidence for @a ~ b@.
  -- The @(a == b) ~ 'True@ premise is what is meant to justify this coercion.
  case unsafeCoerce @(Dict (a ~ a)) @(Dict (a ~ b)) Dict of
    Dict -> x

-- | Similar to @SingI []@, but does not require individual elements to be
-- instances of @SingI@ as well.
class KnownList l where
  -- | Term-level witness of the spine of the type-level list @l@.
  klist :: KList l

-- | The empty list is witnessed by 'KNil'.
instance KnownList '[] where
  klist = KNil

-- | A non-empty list is witnessed by 'KCons', provided its tail is known.
instance KnownList xs => KnownList (x ': xs) where
  klist = KCons Proxy Proxy

-- | An analogue of @SList@ for 'KnownList'.
--
-- Pattern matching on a value reveals whether the list is empty, and in the
-- 'KCons' case brings the 'KnownList' instance of the tail into scope.
data KList (l :: [k]) where
  KNil :: KList '[]
  KCons :: KnownList xs => Proxy x -> Proxy xs -> KList (x ': xs)

-- | Constraint required by 'rsplit'. Only the left part has to be a
-- 'KnownList'; the right part is accepted but not constrained.
type RSplit left right = KnownList left

-- | Split a record into two pieces.
--
-- The shape of the first piece is dictated by the type-level list @l@:
-- its spine is obtained via 'klist' and one field is peeled off the record
-- for every element of @l@. Whatever remains forms the second piece.
rsplit
  :: forall k (l :: [k]) (r :: [k]) f.
      (RSplit l r)
  => Rec f (l ++ r) -> (Rec f l, Rec f r)
rsplit = case klist @l of
  -- Nothing belongs to the left part: the whole record goes to the right.
  KNil -> (RNil, )
  -- Peel off the head field, split the tail, and put the head back in front
  -- of the left part.
  KCons{} -> \(x :& rest) ->
    let (x1, r1) = rsplit rest
    in (x :& x1, r1)

-- | A value of a type parametrized with /some/ type parameter; the
-- parameter itself is hidden existentially.
data Some1 (f :: k -> Kind.Type) where
  Some1 :: f a -> Some1 f

-- | Showing requires the wrapped type to be showable for every choice of
-- the hidden parameter.
deriving stock instance (forall a. Show (f a)) => Show (Some1 f)

-- | Turn a record into a plain list, hiding the type parameter of every
-- field behind 'Some1'.
--
-- Each field is first wrapped into a constant functor carrying a 'Some1',
-- which makes the record homogeneous, and the result is then collected with
-- 'recordToList'.
recordToSomeList :: Rec f l -> [Some1 f]
recordToSomeList = recordToList . rmap (Vinyl.Const . Some1)

-- | Statement that concatenation of three type-level lists does not depend
-- on how the concatenations are grouped.
type ConcatListOfTypesAssociativity xs ys zs = ((xs ++ ys) ++ zs) ~ (xs ++ (ys ++ zs))

-- | GHC cannot deduce this by itself because, in general, a type family
-- need not be associative; lacking this fact brings extra difficulties and
-- redundant constraints, especially with complex types.
-- The @(++)@ type family is associative, though, so this small hack
-- provides the evidence.
--
-- The dictionary is fabricated with 'unsafeCoerce' from the instance of the
-- statement for three empty lists, which GHC can verify on its own.
listOfTypesConcatAssociativityAxiom :: forall a b c . Dict (ConcatListOfTypesAssociativity a b c)
listOfTypesConcatAssociativityAxiom =
  unsafeCoerce $ Dict @(ConcatListOfTypesAssociativity '[] '[] '[])