{-# Language DataKinds                #-}
{-# Language InstanceSigs             #-}
{-# Language ScopedTypeVariables      #-}
{-# Language StandaloneKindSignatures #-}
{-# Language TypeApplications         #-}
{-# Language TypeOperators            #-}
{-# Language UndecidableInstances     #-}

module Deriving.On (On(..)) where

import Data.Function (on)
import Data.Hashable (Hashable(..))
import Data.Kind     (Type)
import Data.Ord      (comparing)
import GHC.Records   (HasField(..))
import GHC.TypeLits  (Symbol)

-- | A helper newtype for 'DerivingVia' that derives /non-structural/
-- instances: rather than inspecting every component of a record, the
-- derived instances inspect one named field only.
--
-- A value of type @'On' User "userID"@ is compared, ordered and hashed
-- using nothing but its @"userID"@ record field. That field is projected
-- out with 'HasField' from @GHC.Records@.
--
-- @
-- {-# Language DataKinds     #-}
-- {-# Language DerivingVia   #-}
-- {-# Language TypeOperators #-}
--
-- import Deriving.On
-- import Data.Hashable
--
-- data User = User
--   { name   :: String
--   , age    :: Int
--   , userID :: Integer
--   }
--   deriving (Eq, Ord, Hashable)
--   via User `On` "userID"
-- @
--
-- Two users that share a @userID@ are indistinguishable to these
-- instances, even when their other fields differ:
--
-- @
-- >> alice = User "Alice" 50 0xDEADBEAF
-- >> bob   = User "Bob"   20 0xDEADBEAF
-- >>
-- >> alice == bob
-- True
-- >> alice <= bob
-- True
-- >> hash alice == hash bob
-- True
-- @
type    On :: Type -> Symbol -> Type
newtype On a field = On a

-- | Two wrapped values are equal exactly when their @field@ components
-- (projected via 'getField') are equal; every other part of the
-- underlying value is ignored.
instance (HasField field a b, Eq b) => Eq (a `On` field) where
  (==) :: a `On` field -> a `On` field -> Bool
  On a1 == On a2 = ((==) `on` getField @field) a1 a2

-- | Orders wrapped values solely by their @field@ component (projected
-- via 'getField'). This agrees with the 'Eq' instance, because both
-- inspect the same field.
instance (HasField field a b, Ord b) => Ord (a `On` field) where
  compare :: a `On` field -> a `On` field -> Ordering
  On a1 `compare` On a2 = comparing (getField @field) a1 a2

-- | Hashes only the @field@ component (projected via 'getField'). Values
-- that are equal under the 'Eq' instance therefore receive equal hashes.
instance (HasField field a b, Hashable b) => Hashable (a `On` field) where
  hashWithSalt :: Int -> a `On` field -> Int
  hashWithSalt salt (On a) = hashWithSalt salt (getField @field a)