{-|
Module      :  Data.Aeson.Schema.Utils.Sum
Maintainer  :  Brandon Chinn <brandon@leapyear.io>
Stability   :  experimental
Portability :  portable

Defines 'SumType', a data type representing a sum over the types
listed in a type-level list.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Data.Aeson.Schema.Utils.Sum
  ( SumType(..)
  , fromSumType
  ) where

import Control.Applicative ((<|>))
import Data.Aeson (FromJSON(..), ToJSON(..))
import Data.Kind (Constraint, Type)
import Data.Proxy (Proxy(..))
import GHC.TypeLits (type (-), ErrorMessage(..), Nat, TypeError)

-- | A sum type over the types in a type-level list.
--
-- When decoded from JSON, the result is the first branch (in list order)
-- whose type successfully parses the JSON value.
--
-- Example:
--
-- @
-- data Owl = Owl
-- data Cat = Cat
-- data Toad = Toad
-- type Animal = SumType '[Owl, Cat, Toad]
--
-- Here Owl                         :: Animal
-- There (Here Cat)                 :: Animal
-- There (There (Here Toad))        :: Animal
--
-- {- Fails at compile-time
-- Here True                        :: Animal
-- Here Cat                         :: Animal
-- There (Here Owl)                 :: Animal
-- There (There (There (Here Owl))) :: Animal
-- -}
-- @
data SumType :: [Type] -> Type where
  -- | The value has the first type of the list.
  Here
    :: forall x xs.
       x
    -> SumType (x ': xs)
  -- | The value has one of the remaining types; skip past the head.
  There
    :: forall x xs.
       SumType xs
    -> SumType (x ': xs)

-- Stock-derived: renders the constructor path, e.g. @There (Here True)@.
deriving instance (Show x, Show (SumType xs)) => Show (SumType (x ': xs))
-- Base case: 'SumType' @'[]@ has no constructors, so the empty case is total.
instance Show (SumType '[]) where
  show = \case {}

-- Stock-derived: equal only when both values sit in the same branch with
-- equal payloads.
deriving instance (Eq x, Eq (SumType xs)) => Eq (SumType (x ': xs))
-- Base case: 'SumType' @'[]@ has no constructors, so any two values are
-- considered equal. Both arguments are wildcards and are never forced.
instance Eq (SumType '[]) where
  _ == _ = True

-- Stock-derived: 'Here' sorts before 'There', then payloads are compared.
deriving instance (Ord x, Ord (SumType xs)) => Ord (SumType (x ': xs))
-- Base case: consistent with the 'Eq' instance for 'SumType' @'[]@, every pair
-- compares as 'EQ' and neither argument is forced.
instance Ord (SumType '[]) where
  compare _ _ = EQ

-- Try the head type first (tagging the result with 'Here'). If that parser
-- fails, fall back to parsing the remaining types (tagging with 'There'), so
-- the first type in list order that parses wins.
instance (FromJSON x, FromJSON (SumType xs)) => FromJSON (SumType (x ': xs)) where
  parseJSON v = (Here <$> parseJSON v) <|> (There <$> parseJSON v)

-- Base case: no types are left to try, so parsing always fails.
instance FromJSON (SumType '[]) where
  parseJSON _ = fail "Could not parse sum type"

-- Encode only the payload of the active branch; the 'Here'/'There' position
-- itself is not written to the JSON output.
instance (ToJSON x, ToJSON (SumType xs)) => ToJSON (SumType (x ': xs)) where
  toJSON = \case
    Here x -> toJSON x
    There rest -> toJSON rest

-- Base case: 'SumType' @'[]@ has no constructors, so the empty case is total.
instance ToJSON (SumType '[]) where
  toJSON = \case {}

{- Extracting sum type branches -}

-- | Projection out of a 'SumType'. @fromSumType' p s@ returns the value stored
-- at index @n@ (the index comes from the proxy @p@) when @s@ is in that
-- branch, and 'Nothing' otherwise. The equality constraint ties @x@ to the
-- type found at index @n@ of @types@.
class FromSumType (n :: Nat) (types :: [Type]) (x :: Type) where
  fromSumType'
    :: 'Just x ~ GetIndex n types
    => proxy n
    -> SumType types
    -> Maybe x

-- Index 0: the requested branch is the head of the list. Marked OVERLAPPING so
-- it is preferred over the general index-@n@ instance when the index is 0.
instance {-# OVERLAPPING #-} FromSumType 0 (x ': xs) x where
  fromSumType' _ = \case
    Here x -> Just x
    There _ -> Nothing

-- Index n > 0: a 'Here' value cannot be at index n, and a 'There' value is
-- looked up at index n - 1 in the tail of the list.
instance {-# OVERLAPPABLE #-}
  ( FromSumType (n - 1) xs x
    -- Needed to discharge the method's @'Just x ~ GetIndex n types@
    -- precondition for the recursive call at index @n - 1@.
  , 'Just x ~ GetIndex (n - 1) xs
  ) => FromSumType n (_x ': xs) x where
  fromSumType' _ = \case
    Here _ -> Nothing
    There rest -> fromSumType' (Proxy @(n - 1)) rest

-- | Extract the value at index @n@ (0-based) from a 'SumType'. Returns
-- 'Nothing' when the value lives in a different branch.
--
-- The result type is determined by the index, and an index past the end of
-- the list is rejected at compile time via 'IsInRange'.
--
-- Example:
--
-- @
-- type Animal = SumType '[Owl, Cat, Toad]
-- let someAnimal = ... :: Animal
--
-- fromSumType (Proxy :: Proxy 0) someAnimal :: Maybe Owl
-- fromSumType (Proxy :: Proxy 1) someAnimal :: Maybe Cat
-- fromSumType (Proxy :: Proxy 2) someAnimal :: Maybe Toad
--
-- -- Compile-time error
-- -- fromSumType (Proxy :: Proxy 3) someAnimal
-- @
fromSumType
  :: ( IsInRange n types
     , 'Just result ~ GetIndex n types
     , FromSumType n types result
     )
  => proxy n -> SumType types -> Maybe result
fromSumType = fromSumType'

{- Helpers -}

-- | Holds when @idx@ is a valid 0-based index into @ts@. Otherwise it reduces
-- to a custom 'TypeError' that names both the index and the list.
type family IsInRange (idx :: Nat) (ts :: [Type]) :: Constraint where
  IsInRange idx ts =
    IsInRange'
      ( TypeError
          ( 'Text "Index "
              ':<>: 'ShowType idx
              ':<>: 'Text " does not exist in list: "
              ':<>: 'ShowType ts
          )
      )
      idx
      ts

-- | Worker for 'IsInRange'. Walks down the list while decrementing the index.
-- Running off the end of the list yields the supplied @err@ constraint, and
-- reaching index 0 on a non-empty list yields the trivial constraint @()@.
type family IsInRange' err (idx :: Nat) (ts :: [Type]) :: Constraint where
  IsInRange' err _   '[]         = err
  IsInRange' _   0   (_ ': _)    = ()
  IsInRange' err idx (_ ': rest) = IsInRange' err (idx - 1) rest

-- | Type-level lookup: the element at 0-based position @idx@ of @ts@, wrapped
-- in 'Just, or 'Nothing when @idx@ is past the end of the list.
-- The @'[]@ equation is apart from both cons equations, so listing it first
-- does not affect reduction.
type family GetIndex (idx :: Nat) (ts :: [Type]) :: Maybe Type where
  GetIndex _   '[]         = 'Nothing
  GetIndex 0   (t ': _)    = 'Just t
  GetIndex idx (_ ': rest) = GetIndex (idx - 1) rest