{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{-|
Module      :  Data.Aeson.Schema.Utils.Sum
Maintainer  :  Brandon Chinn <brandonchinn178@gmail.com>
Stability   :  experimental
Portability :  portable

The 'SumType' data type that represents a sum type consisting of types
specified in a type-level list.
-}
module Data.Aeson.Schema.Utils.Sum (
  SumType (..),
  fromSumType,
) where

import Control.Applicative ((<|>))
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Kind (Constraint, Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (ErrorMessage (..), Nat, TypeError, type (-))

-- | A value of exactly one of the types listed in @types@.
--
--  The constructors record which position of the list the value belongs to:
--  'Here' carries a value of the first type in the list, and every 'There'
--  skips past one type. When decoding JSON, the first type in the list whose
--  parser succeeds is the branch that gets used.
--
--  Example:
--
--  @
--  data Owl = Owl
--  data Cat = Cat
--  data Toad = Toad
--  type Animal = SumType '[Owl, Cat, Toad]
--
--  Here Owl                         :: Animal
--  There (Here Cat)                 :: Animal
--  There (There (Here Toad))        :: Animal
--
--  {\- Fails at compile-time
--  Here True                        :: Animal
--  Here Cat                         :: Animal
--  There (Here Owl)                 :: Animal
--  There (There (There (Here Owl))) :: Animal
--  -\}
--  @
data SumType (types :: [Type]) where
  -- The value has the first type of the list.
  Here :: forall t ts. t -> SumType (t ': ts)
  -- The value has one of the types that follow the first.
  There :: forall t ts. SumType ts -> SumType (t ': ts)

-- For a non-empty list, 'Show' is derived from the 'Here' / 'There'
-- constructors and the stored value's own 'Show' instance.
deriving instance (Show x, Show (SumType xs)) => Show (SumType (x ': xs))

-- SumType '[] has no constructors, so the empty case is total and no value
-- (other than bottom) can ever reach it.
instance Show (SumType '[]) where
  show = \case {}

-- For a non-empty list, equality is derived: two values are equal when they
-- use the same branch and the stored values are equal.
deriving instance (Eq x, Eq (SumType xs)) => Eq (SumType (x ': xs))

-- SumType '[] has no constructors, so any two values are vacuously equal.
instance Eq (SumType '[]) where
  _ == _ = True

-- For a non-empty list, ordering is derived. 'Here' is declared before
-- 'There', so every 'Here' value sorts before every 'There' value.
deriving instance (Ord x, Ord (SumType xs)) => Ord (SumType (x ': xs))

-- SumType '[] has no constructors, so all values compare as equal
-- (consistent with the 'Eq' instance).
instance Ord (SumType '[]) where
  compare _ _ = EQ

-- Try the first type in the list. If its parser fails, fall back to parsing
-- the same value as one of the remaining types, so the first type that
-- parses successfully wins.
instance (FromJSON x, FromJSON (SumType xs)) => FromJSON (SumType (x ': xs)) where
  parseJSON v = (Here <$> parseJSON v) <|> (There <$> parseJSON v)

-- End of the chain: every type in the list has already been tried and
-- failed, so parsing fails.
instance FromJSON (SumType '[]) where
  parseJSON _ = fail "Could not parse sum type"

-- Encode only the stored value; the branch position is not written to JSON.
instance (ToJSON x, ToJSON (SumType xs)) => ToJSON (SumType (x ': xs)) where
  toJSON = \case
    Here value -> toJSON value
    There rest -> toJSON rest

-- SumType '[] has no constructors, so the empty case is total.
instance ToJSON (SumType '[]) where
  toJSON = \case {}

{- Extracting sum type branches -}

-- | Extracts the value stored at index @n@ of a 'SumType', if the value lives
--  in that branch. The equality constraint ties @x@ to the @n@-th type of the
--  list. 'fromSumType' wraps this with an additional range check that gives a
--  readable compile-time error.
class FromSumType (n :: Nat) (types :: [Type]) (x :: Type) where
  fromSumType' ::
    ('Just x ~ GetIndex n types) =>
    proxy n ->
    SumType types ->
    Maybe x

-- Base case: index 0 refers to the first type of the list, so only a 'Here'
-- value is returned. Marked OVERLAPPING so it takes priority over the general
-- instance whenever the index is the literal 0.
instance {-# OVERLAPPING #-} FromSumType 0 (x ': xs) x where
  fromSumType' _ = \case
    Here value -> Just value
    There _ -> Nothing

-- Recursive case: for index n, a 'Here' value belongs to the wrong branch,
-- while a 'There' value is looked up at index (n - 1) in the tail of the
-- list. @n@ from the instance head is in scope here via ScopedTypeVariables.
instance
  {-# OVERLAPPABLE #-}
  ( FromSumType (n - 1) xs x
  , 'Just x ~ GetIndex (n - 1) xs
  ) =>
  FromSumType n (_x ': xs) x
  where
  fromSumType' _ = \case
    Here _ -> Nothing
    There rest -> fromSumType' (Proxy @(n - 1)) rest

-- | Extract a value from a 'SumType'.
--
--  Returns 'Just' the value when it is stored in branch @n@, and 'Nothing'
--  when it is stored in any other branch. An index past the end of the list
--  is rejected at compile time with a custom type error (see 'IsInRange').
--
--  Example:
--
--  @
--  type Animal = SumType '[Owl, Cat, Toad]
--  let someAnimal = ... :: Animal
--
--  fromSumType (Proxy :: Proxy 0) someAnimal :: Maybe Owl
--  fromSumType (Proxy :: Proxy 1) someAnimal :: Maybe Cat
--  fromSumType (Proxy :: Proxy 2) someAnimal :: Maybe Toad
--
--  -- Compile-time error
--  -- fromSumType (Proxy :: Proxy 3) someAnimal
--  @
fromSumType ::
  ( IsInRange n types
  , 'Just result ~ GetIndex n types
  , FromSumType n types result
  ) =>
  proxy n
  -> SumType types
  -> Maybe result
fromSumType = fromSumType'

{- Helpers -}

-- | Constraint that holds when @idx@ is a valid index into @ts@. Otherwise it
--  reduces to a 'TypeError' that names both the index and the full list.
--  The error is built here, where the whole list is still known, and threaded
--  through the recursive helper.
type family IsInRange (n :: Nat) (xs :: [Type]) :: Constraint where
  IsInRange idx ts =
    IsInRange'
      ( TypeError
          ( 'Text "Index "
              ':<>: 'ShowType idx
              ':<>: 'Text " does not exist in list: "
              ':<>: 'ShowType ts
          )
      )
      idx
      ts

-- | Walks down the list while decrementing the index. Hitting the empty list
--  yields the supplied error constraint; reaching index 0 on a non-empty list
--  succeeds. This is a closed family, so the equation order matters.
type family IsInRange' typeErr (n :: Nat) (xs :: [Type]) :: Constraint where
  IsInRange' err _ '[] = err
  IsInRange' _ 0 (_ ': _) = ()
  IsInRange' err i (_ ': rest) = IsInRange' err (i - 1) rest

-- | The type at position @n@ of the list, wrapped in 'Just', or 'Nothing'
--  when the list runs out first. This is a closed family, so the equation
--  order matters.
type family GetIndex (n :: Nat) (types :: [Type]) :: Maybe Type where
  GetIndex 0 (t ': ts) = 'Just t
  GetIndex _ '[] = 'Nothing
  GetIndex i (_ ': ts) = GetIndex (i - 1) ts