{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Downhill.BVar.Num
(
AsNum (..),
NumBVar,
numbvarValue,
var,
constant,
backpropNum
)
where
import Data.AffineSpace (AffineSpace (..))
import Data.Semigroup (Sum (Sum, getSum))
import Data.VectorSpace (AdditiveGroup (..), VectorSpace (..), zeroV)
import Downhill.BVar (BVar (bvarValue), backprop)
import qualified Downhill.BVar as BVar
import Downhill.Grad
( Dual (evalGrad),
HasGrad (Grad, Tang)
)
import Downhill.Linear.Expr (BasicVector (..))
import Downhill.Metric (MetricTensor (evalMetric))
-- | Newtype wrapper that lets an ordinary 'Num' type serve as the value,
-- gradient and tangent type of a 'BVar'.
--
-- The arithmetic classes ('Num', 'Fractional', 'Floating') are borrowed
-- unchanged from the wrapped type via @DerivingVia@. The vector-space and
-- gradient instances for @AsNum a@ are written by hand in this module.
newtype AsNum a = AsNum {unAsNum :: a}
  deriving stock (Show)
  deriving (Num) via a
  deriving (Fractional) via a
  deriving (Floating) via a
-- | Pairs a gradient with a tangent. For @AsNum a@ both are the same scalar
-- type, so the pairing is ordinary multiplication.
instance Num a => Dual (AsNum a) (AsNum a) where
  evalGrad = (*)
-- | A wrapped scalar serves as its own tangent space and its own gradient
-- space.
instance Num a => HasGrad (AsNum a) where
  type Tang (AsNum a) = AsNum a
  type Grad (AsNum a) = AsNum a
-- | The metric on a scalar space is itself a scalar @m@. Applying it to a
-- gradient @x@ yields the tangent @m * x@.
instance Num a => MetricTensor (AsNum a) (AsNum a) where
  evalMetric (AsNum m) (AsNum x) = AsNum (m * x)
-- | Additive-group structure taken directly from the wrapped 'Num' type,
-- through the derived 'Num' instance of 'AsNum'.
instance Num a => AdditiveGroup (AsNum a) where
  zeroV = AsNum 0
  (^+^) = (+)
  (^-^) = (-)
  negateV = negate
-- | A scalar is a one-dimensional vector space over itself. The scalar type
-- is @AsNum a@ and scaling is multiplication.
instance Num a => VectorSpace (AsNum a) where
  type Scalar (AsNum a) = AsNum a
  (*^) = (*)
-- | Builders are @'Sum' a@, so combining them with '<>' adds the wrapped
-- values. 'sumBuilder' unwraps the accumulated total back into 'AsNum', and
-- 'identityBuilder' wraps a single value as a builder.
instance Num a => BasicVector (AsNum a) where
  type VecBuilder (AsNum a) = Sum a
  sumBuilder = AsNum . getSum
  identityBuilder = Sum . unAsNum
-- | Points and displacements share the same scalar type. '.-.' subtracts
-- the wrapped values and '.+^' adds them.
instance Num a => AffineSpace (AsNum a) where
  type Diff (AsNum a) = AsNum a
  AsNum x .-. AsNum y = AsNum (x - y)
  AsNum x .+^ AsNum y = AsNum (x + y)
-- | A 'BVar' over a plain 'Num' scalar. Both the value type (what
-- 'bvarValue' returns) and the type that 'backprop' returns are @'AsNum' a@.
type NumBVar a = BVar (AsNum a) (AsNum a)
-- | Wrap a plain number in 'AsNum' and lift it with 'BVar.constant'.
--
-- NOTE(review): presumably the result contributes no gradient to
-- 'backpropNum'; confirm against "Downhill.BVar".
constant :: forall a. Num a => a -> NumBVar a
constant = BVar.constant @(AsNum a) @(AsNum a) . AsNum
-- | Wrap a plain number in 'AsNum' and lift it with 'BVar.var'. Its gradient
-- type is @'Grad' ('AsNum' a)@, which equals @'AsNum' a@.
--
-- NOTE(review): presumably this marks the input to differentiate with
-- respect to; confirm against "Downhill.BVar".
var :: Num a => a -> NumBVar a
var = BVar.var . AsNum
-- | Call 'backprop' on a 'NumBVar' and unwrap the result from 'AsNum'.
--
-- The output gradient is seeded with @1@, i.e. the derivative of the output
-- with respect to itself.
backpropNum :: forall a. Num a => NumBVar a -> a
backpropNum x = unAsNum (backprop @(AsNum a) @(AsNum a) x (AsNum 1))
-- | The plain value carried by a 'NumBVar': 'bvarValue' with the 'AsNum'
-- wrapper removed.
numbvarValue :: NumBVar a -> a
numbvarValue = unAsNum . bvarValue