{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

module Downhill.BVar
  ( BVar (..),
    var,
    constant,
    backprop,
    -- * Pattern synonyms
    pattern T2,
    pattern T3
  )
where

import Data.AdditiveGroup (AdditiveGroup)
import Data.AffineSpace (AffineSpace ((.+^), (.-.)))
import qualified Data.AffineSpace as AffineSpace
import Data.VectorSpace
  ( AdditiveGroup (..),
    InnerSpace ((<.>)),
    VectorSpace ((*^)),
  )
import qualified Data.VectorSpace as VectorSpace
import Downhill.Grad
  ( Dual (evalGrad),
    HasGrad (Grad, Tang),
    HasGradAffine, MScalar
  )
import Downhill.Linear.BackGrad
  ( BackGrad (..),
    realNode,
  )
import qualified Downhill.Linear.Backprop as BP
import Downhill.Linear.Expr (BasicVector, Expr (ExprVar))
import Downhill.Linear.Lift (lift2_dense)
import Prelude hiding (id, (.))
import qualified Downhill.Linear.Prelude as Linear

-- | A differentiable variable: a value paired with the node that carries
-- its derivative in the backward pass.
--
-- The parameter @r@ is the result type produced by 'backprop'; @a@ is the
-- type of the value itself.
data BVar r a = BVar
  { -- | The value of the variable.
    bvarValue :: a,
    -- | Gradient node; gradients flowing into it have type @'Grad' a@.
    bvarGrad :: BackGrad r (Grad a)
  }

-- | Vector addition on variables: every operation is applied to the value
-- and, separately, to the gradient node.
instance (AdditiveGroup b, HasGrad b) => AdditiveGroup (BVar r b) where
  zeroV = BVar zeroV zeroV
  negateV (BVar val grad) = BVar (negateV val) (negateV grad)
  BVar u du ^+^ BVar v dv = BVar (u ^+^ v) (du ^+^ dv)
  BVar u du ^-^ BVar v dv = BVar (u ^-^ v) (du ^-^ dv)

-- | Arithmetic on scalar variables. Each method computes the result value
-- together with the matching linear combination of the operand gradients.
instance (Num b, HasGrad b, MScalar b ~ b) => Num (BVar r b) where
  BVar x dx + BVar y dy = BVar (x + y) (dx ^+^ dy)
  BVar x dx - BVar y dy = BVar (x - y) (dx ^-^ dy)
  -- Product rule: d(x * y) = x dy + y dx.
  BVar x dx * BVar y dy = BVar (x * y) ((x *^ dy) ^+^ (y *^ dx))
  negate (BVar x dx) = BVar (negate x) (negateV dx)
  -- d|x| = signum x * dx.
  -- TODO: inefficiency: when signum x is 1 this multiplies by 1.
  abs (BVar x dx) = BVar (abs x) (signum x *^ dx)
  -- signum is piecewise constant; its gradient is taken to be zero.
  signum (BVar x _) = BVar (signum x) zeroV
  fromInteger n = BVar (fromInteger n) zeroV

-- | Square of a number: the value multiplied by itself.
sqr :: Num a => a -> a
sqr v = v * v

-- | Reciprocal square root: @recip (sqrt x)@.
rsqrt :: Floating a => a -> a
rsqrt v = recip $ sqrt v

-- | Division on scalar variables.
instance (Fractional b, HasGrad b, MScalar b ~ b) => Fractional (BVar r b) where
  fromRational q = BVar (fromRational q) zeroV
  -- d(1 / x) = -(1 / x^2) dx
  recip (BVar x dx) = BVar (recip x) (slope *^ dx)
    where
      slope = negate (recip (sqr x))
  -- Quotient rule: d(x / y) = dx / y - (x / y^2) dy
  BVar x dx / BVar y dy = BVar (x / y) (viaNumerator ^-^ viaDenominator)
    where
      viaNumerator = recip y *^ dx
      viaDenominator = (x / sqr y) *^ dy

-- | Elementary functions on scalar variables. Each method evaluates the
-- function on the value and scales the incoming gradient by the function's
-- derivative at that value (chain rule).
--
-- 'sqrt', 'tan', 'tanh' and '(**)' are defined explicitly rather than left
-- to the class defaults, because the defaults are assembled from other
-- 'BVar' operations:
--
-- * @tanh x = sinh x / cosh x@ yields NaN once @sinh x@ and @cosh x@
--   overflow (e.g. @tanh 1000@ on 'Double').
-- * @x ** y = exp (log x * y)@ yields NaN for a negative base
--   (e.g. @(-2) ** 2@).
-- * @sqrt x = x ** 0.5@ goes through @exp@/@log@ and loses accuracy.
instance (Floating b, HasGrad b, MScalar b ~ b) => Floating (BVar r b) where
  pi = BVar pi zeroV

  -- d(exp x) = exp x dx; the computed value doubles as the slope.
  exp (BVar x dx) = BVar y (y *^ dx)
    where
      y = exp x
  log (BVar x dx) = BVar (log x) (recip x *^ dx)
  -- d(sqrt x) = dx / (2 sqrt x)
  sqrt (BVar x dx) = BVar y (recip (2 * y) *^ dx)
    where
      y = sqrt x
  -- d(x ** y) = y x ** (y - 1) dx + log x (x ** y) dy
  BVar x dx ** BVar y dy =
    BVar z (((y * x ** (y - 1)) *^ dx) ^+^ ((log x * z) *^ dy))
    where
      z = x ** y

  sin (BVar x dx) = BVar (sin x) (cos x *^ dx)
  cos (BVar x dx) = BVar (cos x) (negate (sin x) *^ dx)
  -- d(tan x) = (1 + tan^2 x) dx
  tan (BVar x dx) = BVar y ((1 + sqr y) *^ dx)
    where
      y = tan x
  asin (BVar x dx) = BVar (asin x) (rsqrt (1 - sqr x) *^ dx)
  acos (BVar x dx) = BVar (acos x) (negate (rsqrt (1 - sqr x)) *^ dx)
  atan (BVar x dx) = BVar (atan x) (recip (1 + sqr x) *^ dx)

  sinh (BVar x dx) = BVar (sinh x) (cosh x *^ dx)
  cosh (BVar x dx) = BVar (cosh x) (sinh x *^ dx)
  -- d(tanh x) = (1 - tanh^2 x) dx
  tanh (BVar x dx) = BVar y ((1 - sqr y) *^ dx)
    where
      y = tanh x
  asinh (BVar x dx) = BVar (asinh x) (rsqrt (1 + sqr x) *^ dx)
  acosh (BVar x dx) = BVar (acosh x) (rsqrt (sqr x - 1) *^ dx)
  atanh (BVar x dx) = BVar (atanh x) (recip (1 - sqr x) *^ dx)

-- | Scaling a vector variable by a scalar variable. The result's gradient
-- node is built with 'lift2_dense' from the scalar's and the vector's nodes.
instance
  ( VectorSpace v,
    HasGrad v,
    Tang v ~ v,
    BasicVector (MScalar v),
    Grad (MScalar v) ~ MScalar v
  ) =>
  VectorSpace (BVar r v)
  where
  type Scalar (BVar r v) = BVar r (MScalar v)
  BVar s ds *^ BVar x dx = BVar (s *^ x) (lift2_dense toScalar toVector ds dx)
    where
      -- Gradient for the scalar operand: the output gradient evaluated on x.
      toScalar :: Grad v -> MScalar v
      toScalar g = evalGrad g x
      -- Gradient for the vector operand: the output gradient scaled by s.
      toVector :: Grad v -> Grad v
      toVector g = s *^ g

-- | Affine operations on variables: displacements are applied to the value
-- with the underlying affine operation and to the gradient nodes with
-- '^+^' / '^-^'.
instance (HasGrad p, HasGradAffine p) => AffineSpace (BVar r p) where
  type Diff (BVar r p) = BVar r (Tang p)
  BVar pt dpt .+^ BVar disp ddisp = BVar (pt .+^ disp) (dpt ^+^ ddisp)
  BVar pt dpt .-. BVar other dother = BVar (pt .-. other) (dpt ^-^ dother)

-- | Inner product of two vector variables. The result's gradient node
-- routes the incoming scalar gradient to both operands via 'lift2_dense'.
instance
  ( VectorSpace v,
    HasGrad v,
    Grad v ~ v,
    Tang v ~ v,
    BasicVector (MScalar v),
    Grad (MScalar v) ~ MScalar v,
    InnerSpace v,
    HasGrad (MScalar v)
  ) =>
  InnerSpace (BVar r v)
  where
  BVar x dx <.> BVar y dy = BVar (x <.> y) (lift2_dense towardX towardY dx dy)
    where
      -- d(x . y) / dx = y, scaled by the incoming scalar gradient.
      towardX :: MScalar v -> Grad v
      towardX s = s *^ y
      -- d(x . y) / dy = x, scaled by the incoming scalar gradient.
      towardY :: MScalar v -> Grad v
      towardY s = s *^ x

-- | Wrap a value as a variable whose gradient node is 'zeroV', i.e. a
-- variable with derivative zero.
constant :: forall r a. (BasicVector (Grad a), AdditiveGroup (Grad a)) => a -> BVar r a
constant value = BVar value zeroV

-- | Wrap a value as a variable with identity derivative. Its gradient node
-- is a 'realNode' built from 'ExprVar'.
var :: a -> BVar (Grad a) a
var value = BVar value node
  where
    node = realNode ExprVar

--backprop :: forall a p. (HasGrad p, BasicVector a) => BVar a p -> GradBuilder p -> a
--backprop (BVar _y0 x) = BP.backprop x

-- | Reverse mode differentiation.
--
-- Runs 'BP.backprop' on the variable's gradient node, seeded with the given
-- @'Grad' a@, and returns the resulting value of type @r@. The variable's
-- value is ignored.
backprop :: forall r a. (HasGrad a, BasicVector r) => BVar r a -> Grad a -> r
backprop (BVar _ node) = BP.backprop node


-- | Split a variable holding a pair into one variable per component. The
-- gradient node is taken apart with the 'Linear.T2' pattern.
splitPair :: (BasicVector (Grad a), BasicVector (Grad b)) => BVar r (a, b) -> (BVar r a, BVar r b)
splitPair (BVar (x, y) (Linear.T2 dx dy)) = (BVar x dx, BVar y dy)

-- | Bidirectional pattern for variables holding pairs.
--
-- Matching @T2 a b@ splits a @BVar r (a, b)@ with 'splitPair'. Using @T2@
-- as an expression combines two variables into one whose gradient node is
-- built with 'Linear.T2'.
pattern T2 :: forall r a b. (BasicVector (Grad a), BasicVector (Grad b)) => BVar r a -> BVar r b -> BVar r (a, b)
pattern T2 a b <- (splitPair -> (a, b))
  where
    T2 (BVar a da) (BVar b db) = BVar (a, b) (Linear.T2 da db)

-- Declare 'T2' a complete match so functions that match only on 'T2' do
-- not trigger -Wincomplete-patterns.
-- NOTE(review): assumes the 'Linear.T2' match inside 'splitPair' never
-- fails; confirm against Downhill.Linear.Prelude.
{-# COMPLETE T2 #-}

-- | Split a variable holding a triple into one variable per component. The
-- gradient node is taken apart with the 'Linear.T3' pattern.
splitTriple :: (BasicVector (Grad a), BasicVector (Grad b), BasicVector (Grad c)) => BVar r (a, b, c) -> (BVar r a, BVar r b, BVar r c)
splitTriple (BVar (x, y, z) (Linear.T3 dx dy dz)) =
  (BVar x dx, BVar y dy, BVar z dz)

-- | Bidirectional pattern for variables holding triples.
--
-- Matching @T3 a b c@ splits a @BVar r (a, b, c)@ with 'splitTriple'. Using
-- @T3@ as an expression combines three variables into one whose gradient
-- node is built with 'Linear.T3'.
pattern T3 ::
  forall r a b c.
  (BasicVector (Grad a), BasicVector (Grad b), BasicVector (Grad c)) =>
  BVar r a ->
  BVar r b ->
  BVar r c ->
  BVar r (a, b, c)
pattern T3 a b c <- (splitTriple -> (a, b, c))
  where
    T3 (BVar a da) (BVar b db) (BVar c dc) =
      BVar (a, b, c) (Linear.T3 da db dc)

-- Declare 'T3' a complete match so functions that match only on 'T3' do
-- not trigger -Wincomplete-patterns.
-- NOTE(review): assumes the 'Linear.T3' match inside 'splitTriple' never
-- fails; confirm against Downhill.Linear.Prelude.
{-# COMPLETE T3 #-}