{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.Metric
  ( MetricTensor (..)
  )
where

import Data.VectorSpace ((^+^))
import Downhill.Grad (Dual (evalGrad), HasGrad (Grad, Tang), MScalar)

-- | @MetricTensor@ converts gradients into tangent vectors.
--
-- Strictly speaking it is the /inverse/ of a metric tensor: it maps the
-- cotangent space (where gradients live) into the tangent space. Gradient
-- descent never needs the metric itself, only this inverse direction.
class Dual (Tang p) (Grad p) => MetricTensor p g where
  -- | Apply the metric @m@ to a gradient, producing a tangent vector.
  --
  -- @m@ must be symmetric:
  --
  -- @evalGrad x (evalMetric m y) = evalGrad y (evalMetric m x)@
  evalMetric :: g -> Grad p -> Tang p

  -- | Inner product of two gradients under the metric @m@:
  --
  -- @innerProduct m x y = evalGrad x (evalMetric m y)@
  innerProduct :: g -> Grad p -> Grad p -> MScalar p
  innerProduct m x y = evalGrad @(Tang p) @(Grad p) x yVec
    where
      -- @y@ pushed into the tangent space, so @x@ can be evaluated on it.
      yVec = evalMetric @p m y

  -- | Squared norm of a gradient under the metric @m@:
  --
  -- @sqrNorm m x = innerProduct m x x@
  sqrNorm :: g -> Grad p -> MScalar p
  sqrNorm m x = innerProduct @p m x x

-- | Scalar metric on 'Integer': the metric is a single coefficient that
-- multiplies the gradient.
instance MetricTensor Integer Integer where
  evalMetric = (*)

-- | Product metric on pairs.
--
-- Each component of the gradient is mapped by its own metric. 'sqrNorm' is
-- the sum of the components' squared norms, computed directly rather than
-- through the default 'innerProduct'.
instance
  ( MScalar a ~ MScalar b,
    MetricTensor a ma,
    MetricTensor b mb
  ) =>
  MetricTensor (a, b) (ma, mb)
  where
  evalMetric (metricA, metricB) (gradA, gradB) =
    (evalMetric @a metricA gradA, evalMetric @b metricB gradB)

  sqrNorm (metricA, metricB) (gradA, gradB) = normA ^+^ normB
    where
      normA = sqrNorm @a metricA gradA
      normB = sqrNorm @b metricB gradB

-- | Product metric on triples.
--
-- Each component of the gradient is mapped by its own metric. 'sqrNorm' is
-- the sum of the components' squared norms, computed directly rather than
-- through the default 'innerProduct'.
instance
  ( MScalar a ~ MScalar b,
    MScalar a ~ MScalar c,
    MetricTensor a ma,
    MetricTensor b mb,
    MetricTensor c mc
  ) =>
  MetricTensor (a, b, c) (ma, mb, mc)
  where
  evalMetric (metricA, metricB, metricC) (gradA, gradB, gradC) =
    ( evalMetric @a metricA gradA,
      evalMetric @b metricB gradB,
      evalMetric @c metricC gradC
    )

  -- Summed left to right, matching the left associativity of '^+^'.
  sqrNorm (metricA, metricB, metricC) (gradA, gradB, gradC) =
    (normA ^+^ normB) ^+^ normC
    where
      normA = sqrNorm @a metricA gradA
      normB = sqrNorm @b metricB gradB
      normC = sqrNorm @c metricC gradC

-- | Scalar metric on 'Float': the metric is a single coefficient that
-- multiplies the gradient.
instance MetricTensor Float Float where
  evalMetric = (*)

-- | Scalar metric on 'Double': the metric is a single coefficient that
-- multiplies the gradient.
instance MetricTensor Double Double where
  evalMetric = (*)

-- | Tag type selecting the identity (L2) metric.
data L2 = L2

-- | Identity metric: a gradient is returned unchanged as a tangent vector.
-- It is only available when gradients and tangent vectors share one type
-- (@Grad p ~ Tang p@).
instance (Dual (Tang p) (Grad p), Grad p ~ Tang p) => MetricTensor p L2 where
  -- Matching on the 'L2' constructor forces the metric argument, as the
  -- original pattern-matching definition did.
  evalMetric metric grad = case metric of
    L2 -> grad