{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Downhill.Grad
( Dual (..),
HasGrad (..), MScalar,
GradBuilder,
HasGradAffine,
)
where
import Data.AffineSpace (AffineSpace (Diff))
import Data.Kind (Type)
import Data.VectorSpace (AdditiveGroup ((^+^), zeroV), VectorSpace(Scalar))
import Downhill.Linear.Expr (BasicVector (VecBuilder))
import GHC.Generics (Generic (Rep, from), K1 (K1), M1 (M1), U1 (U1), V1, (:*:) ((:*:)))
-- | Pairing between a vector type @v@ and its gradient type @dv@.
--
-- Both types must be vector spaces over the same 'Scalar' type, and that
-- scalar type must be an 'AdditiveGroup' (pairings of components are
-- summed with '^+^').
class
  ( Scalar v ~ Scalar dv,
    AdditiveGroup (Scalar v),
    VectorSpace v,
    VectorSpace dv
  ) =>
  Dual v dv
  where
  -- | @evalGrad dv v@ evaluates the gradient @dv@ on the vector @v@,
  -- producing a scalar. In this module, scalar instances use
  -- multiplication and tuple instances sum the per-component results.
  evalGrad :: dv -> v -> Scalar v
  -- Default for types with 'Generic' instances: convert both arguments to
  -- their generic representations and let 'gevalGrad' sum the per-field
  -- 'evalGrad' results ('zeroV' for a constructor without fields).
  -- This module provides no 'GDual' instance for sum types (':+:'), so the
  -- default only applies to types with at most one constructor.
  default evalGrad ::
    (GDual (Scalar v) (Rep v) (Rep dv), Generic dv, Generic v) =>
    dv ->
    v ->
    Scalar v
  evalGrad dv v = gevalGrad (from dv) (from v)
-- | Associates a type @p@ with a tangent type ('Tang') and a gradient type
-- ('Grad').
--
-- The gradient type must be a 'BasicVector', the two types must be
-- related by 'Dual', and they must share the same 'Scalar' type.
class
  ( Scalar (Tang p) ~ Scalar (Grad p),
    Dual (Tang p) (Grad p),
    BasicVector (Grad p)
  ) =>
  HasGrad p
  where
  -- | Tangent type of @p@.
  type Tang p :: Type

  -- | Gradient type of @p@.
  type Grad p :: Type

-- | Scalar type associated with @p@, read off its tangent type.
type MScalar p = Scalar (Tang p)
-- | The 'VecBuilder' of the gradient type of @v@.
type GradBuilder v = VecBuilder (Grad v)

-- | Constraint bundle for an 'AffineSpace' @p@ whose 'Diff' type is its
-- tangent type. That tangent type must itself have a 'HasGrad' instance,
-- be its own tangent, and share @p@'s gradient type.
type HasGradAffine p =
  ( AffineSpace p,
    Tang p ~ Diff p,
    HasGrad p,
    HasGrad (Tang p),
    Grad (Tang p) ~ Grad p,
    Tang (Tang p) ~ Tang p
  )
-- | 'Integer' is its own gradient type; evaluating a gradient on a vector
-- is plain multiplication.
instance Dual Integer Integer where
  evalGrad = (*)

-- | 'Integer' is its own tangent and gradient type.
instance HasGrad Integer where
  type Tang Integer = Integer
  type Grad Integer = Integer
-- | Pairs are handled componentwise: the result is the sum ('^+^') of
-- evaluating each gradient component on the matching vector component.
-- Both components must share a scalar type.
instance
  (Scalar a ~ Scalar b, Dual a da, Dual b db) =>
  Dual (a, b) (da, db)
  where
  evalGrad (da, db) (x, y) = evalGrad da x ^+^ evalGrad db y
-- | Triples are handled componentwise: the result is the sum ('^+^') of
-- evaluating each gradient component on the matching vector component.
-- All three components must share a scalar type.
instance
  ( Scalar a ~ Scalar b,
    Scalar a ~ Scalar c,
    Dual a da,
    Dual b db,
    Dual c dc
  ) =>
  Dual (a, b, c) (da, db, dc)
  where
  evalGrad (da, db, dc) (x, y, z) =
    evalGrad da x ^+^ evalGrad db y ^+^ evalGrad dc z
-- | The tangent and gradient types of a pair are the pairs of the
-- components' tangent and gradient types. Both components must agree on
-- 'MScalar'.
instance
  ( MScalar a ~ MScalar b,
    HasGrad a,
    HasGrad b
  ) =>
  HasGrad (a, b)
  where
  type Tang (a, b) = (Tang a, Tang b)
  type Grad (a, b) = (Grad a, Grad b)
-- | The tangent and gradient types of a triple are the triples of the
-- components' tangent and gradient types. All components must agree on
-- 'MScalar'.
instance
  ( MScalar a ~ MScalar b,
    MScalar a ~ MScalar c,
    HasGrad a,
    HasGrad b,
    HasGrad c
  ) =>
  HasGrad (a, b, c)
  where
  type Tang (a, b, c) = (Tang a, Tang b, Tang c)
  type Grad (a, b, c) = (Grad a, Grad b, Grad c)
-- | 'Float' is its own gradient type; evaluating a gradient on a vector is
-- plain multiplication.
instance Dual Float Float where
  evalGrad = (*)

-- | 'Float' is its own tangent and gradient type.
instance HasGrad Float where
  type Grad Float = Float
  type Tang Float = Float
-- | 'Double' is its own gradient type; evaluating a gradient on a vector
-- is plain multiplication.
instance Dual Double Double where
  evalGrad = (*)

-- | 'Double' is its own tangent and gradient type.
instance HasGrad Double where
  type Grad Double = Double
  type Tang Double = Double
-- | Generic worker behind the default 'evalGrad'.
--
-- @gevalGrad@ takes the generic representation of a gradient (@dv p@) and
-- of a vector (@v p@) and produces a scalar @s@. Instances walk the
-- 'GHC.Generics' representation structure.
class GDual s (v :: Type -> Type) (dv :: Type -> Type) where
  gevalGrad :: dv p -> v p -> s
-- | A single field: defer to the field type's own 'Dual' instance. The
-- result type @s@ is fixed to the field's 'Scalar' type.
instance (s ~ Scalar v, Dual v dv) => GDual s (K1 x v) (K1 x dv) where
  gevalGrad (K1 dv) (K1 v) = evalGrad dv v
-- | Metadata wrapper: unwrap both sides and recurse.
--
-- The metadata parameters of the two sides may differ (@y@ vs @y'@);
-- presumably this lets a type be paired with a gradient type that carries
-- different datatype/constructor/selector names.
instance (GDual s v dv) => GDual s (M1 x y v) (M1 x y' dv) where
  gevalGrad (M1 dv) (M1 v) = gevalGrad dv v
-- | Product of fields: evaluate each half independently and add the
-- results with '^+^'.
instance
  (AdditiveGroup s, GDual s u du, GDual s v dv) =>
  GDual s (u :*: v) (du :*: dv)
  where
  gevalGrad (du :*: dv) (u :*: v) = gevalGrad du u ^+^ gevalGrad dv v
-- | Representation of a type with no constructors: 'V1' has no values, so
-- matching on the first argument with an empty case is exhaustive.
instance GDual s V1 V1 where
  gevalGrad v1 = case v1 of {}
-- | Constructor without fields: contributes nothing, i.e. 'zeroV'.
instance AdditiveGroup s => GDual s U1 U1 where
  gevalGrad U1 U1 = zeroV