{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.Grad
  ( Dual (..),
    HasGrad (..), MScalar,
    GradBuilder,
    HasGradAffine,
  )
where

import Data.AffineSpace (AffineSpace (Diff))
import Data.Kind (Type)
import Data.VectorSpace (AdditiveGroup ((^+^), zeroV), VectorSpace(Scalar))
import Downhill.Linear.Expr (BasicVector (VecBuilder))
import GHC.Generics (Generic (Rep, from), K1 (K1), M1 (M1), U1 (U1), V1, (:*:) ((:*:)))

-- | Pairing between a vector type @v@ and its dual @dv@.
--
-- An element of @dv@ is a linear map @v -> Scalar v@; 'evalGrad' applies
-- such a map to a vector. Both types must share the same scalar type.
class
  ( Scalar v ~ Scalar dv,
    AdditiveGroup (Scalar v),
    VectorSpace v,
    VectorSpace dv
  ) =>
  Dual v dv
  where
  -- | @evalGrad dv v@ evaluates the covector @dv@ at the vector @v@.
  --
  -- The default implementation handles any pair of 'Generic' types with
  -- matching structure. Each field of @dv@ is paired with the corresponding
  -- field of @v@ (through that field's own 'Dual' instance), and the results
  -- are summed with '^+^'. Field-less constructors contribute 'zeroV'.
  --
  -- This method lives here rather than in 'HasGrad'. There, its type would
  -- mention the class parameter @p@ only through the type families
  -- 'Tang' / 'Grad', so @p@ would be ambiguous.
  evalGrad :: dv -> v -> Scalar v
  default evalGrad ::
    (GDual (Scalar v) (Rep v) (Rep dv), Generic dv, Generic v) =>
    dv ->
    v ->
    Scalar v
  evalGrad dv v = gevalGrad (from dv) (from v)

-- | Scalar type of the tangent space of @m@.
type MScalar m = Scalar (Tang m)

-- | Types that differentiable functions can be defined on.
--
-- The domain of a differentiable function need not be a vector space; it may
-- be any smooth manifold @p@. Each instance names a tangent space ('Tang')
-- and a cotangent space ('Grad'). The two must be 'Dual' to each other, so a
-- @Grad p@ value can be evaluated against a @Tang p@ value with 'evalGrad'.
class
  ( BasicVector (Grad p),
    Dual (Tang p) (Grad p),
    Scalar (Tang p) ~ Scalar (Grad p)
  ) =>
  HasGrad p
  where
  -- | Cotangent space of @p@.
  type Grad p :: Type

  -- | Tangent space of @p@.
  type Tang p :: Type

-- | The 'VecBuilder' associated with the cotangent space 'Grad' of @p@.
type GradBuilder p = VecBuilder (Grad p)

-- | Constraint bundle for an affine space @p@ that is differentiable.
--
-- Its difference type ('Diff') is its tangent space ('Tang'). That tangent
-- space is itself 'HasGrad', with the same tangent and cotangent types as @p@.
type HasGradAffine p =
  ( HasGrad p,
    AffineSpace p,
    Tang p ~ Diff p,
    HasGrad (Tang p),
    Tang (Tang p) ~ Tang p,
    Grad (Tang p) ~ Grad p
  )

-- | 'Integer' is self-dual: evaluating a covector is plain multiplication.
instance Dual Integer Integer where
  evalGrad = (*)

-- | 'Integer' serves as its own tangent and cotangent space.
instance HasGrad Integer where
  type Grad Integer = Integer
  type Tang Integer = Integer

-- | Pairs are dual componentwise. The pairing of a pair is the sum of the
-- pairings of its two components.
instance (Scalar a ~ Scalar b, Dual a da, Dual b db) => Dual (a, b) (da, db) where
  evalGrad (da, db) (x, y) = evalGrad da x ^+^ evalGrad db y

-- | Triples are dual componentwise. The pairing of a triple is the sum of the
-- pairings of its three components.
instance
  ( Scalar a ~ Scalar b,
    Scalar a ~ Scalar c,
    Dual a da,
    Dual b db,
    Dual c dc
  ) =>
  Dual (a, b, c) (da, db, dc)
  where
  evalGrad (da, db, dc) (x, y, z) =
    evalGrad da x ^+^ evalGrad db y ^+^ evalGrad dc z

-- | Pairs are differentiated componentwise: the tangent and cotangent spaces
-- are pairs of the components' spaces. Both components must share a scalar.
instance (HasGrad a, HasGrad b, MScalar a ~ MScalar b) => HasGrad (a, b) where
  type Tang (a, b) = (Tang a, Tang b)
  type Grad (a, b) = (Grad a, Grad b)

-- | Triples are differentiated componentwise: the tangent and cotangent spaces
-- are triples of the components' spaces. All components must share a scalar.
instance
  ( HasGrad a,
    HasGrad b,
    HasGrad c,
    MScalar a ~ MScalar b,
    MScalar a ~ MScalar c
  ) =>
  HasGrad (a, b, c)
  where
  type Tang (a, b, c) = (Tang a, Tang b, Tang c)
  type Grad (a, b, c) = (Grad a, Grad b, Grad c)

-- | 'Float' is self-dual: evaluating a covector is plain multiplication.
instance Dual Float Float where
  evalGrad = (*)

-- | 'Float' serves as its own tangent and cotangent space.
instance HasGrad Float where
  type Tang Float = Float
  type Grad Float = Float

-- | 'Double' is self-dual: evaluating a covector is plain multiplication.
instance Dual Double Double where
  evalGrad = (*)

-- | 'Double' serves as its own tangent and cotangent space.
instance HasGrad Double where
  type Tang Double = Double
  type Grad Double = Double

-- | Generic worker behind the default implementation of 'evalGrad'.
--
-- @gevalGrad drep rep@ pairs the generic representation of a dual vector
-- with the representation of a vector of the same shape, yielding an @s@.
-- NOTE(review): no instance for sums (@:+:@) is visible in this module, so
-- multi-constructor types appear unsupported by the generic default.
class GDual s f df where
  gevalGrad :: df p -> f p -> s

-- | A single field: delegate to the field types' own 'Dual' instance.
instance (s ~ Scalar v, Dual v dv) => GDual s (K1 x v) (K1 x dv) where
  gevalGrad (K1 dv) (K1 v) = evalGrad dv v

-- | Metadata wrappers are transparent. The dual side may carry different
-- metadata (@y'@ vs @y@), so a dual type need not reuse the datatype,
-- constructor, or field names of its vector type; only the shape must match.
instance (GDual s v dv) => GDual s (M1 x y v) (M1 x y' dv) where
  gevalGrad (M1 dv) (M1 v) = gevalGrad dv v

-- | Products of fields: pair each half separately and sum the results.
instance (AdditiveGroup s, GDual s u du, GDual s v dv) => GDual s (u :*: v) (du :*: dv) where
  gevalGrad (du :*: dv) (u :*: v) = gevalGrad du u ^+^ gevalGrad dv v

-- | Empty types have no values; the empty case is total by construction.
instance GDual s V1 V1 where
  gevalGrad = \case {}

-- | Constructors without fields contribute nothing to the pairing.
instance AdditiveGroup s => GDual s U1 U1 where
  gevalGrad U1 U1 = zeroV