{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language ViewPatterns #-}
module Data.SRTree.Internal
( SRTree(..)
, Function(..)
, Op(..)
, param
, var
, arity
, getChildren
, countNodes
, countVarNodes
, countConsts
, countParams
, countOccurrences
, deriveBy
, deriveByVar
, deriveByParam
, derivative
, forwardMode
, gradParamsFwd
, gradParamsRev
, evalFun
, evalOp
, inverseFunc
, evalTree
, relabelParams
, constsToParam
, floatConstsToParam
, paramsToConst
, Fix (..)
)
where
import Data.SRTree.Recursion ( Fix (..), cata, mutu, accu, cataM )
import qualified Data.Vector as V
import Data.Vector ((!))
import Control.Monad.State
import qualified Data.DList as DL
import Data.Bifunctor (second)
import Debug.Trace (trace)
-- | Base functor of a symbolic-regression expression tree.
--
-- The type parameter @val@ stands for the child positions, so the
-- recursive tree is obtained as @'Fix' 'SRTree'@.
data SRTree val
  = Var Int          -- ^ variable leaf, identified by an index
  | Param Int        -- ^ parameter leaf, identified by an index
  | Const Double     -- ^ literal constant leaf
  | Uni Function val -- ^ unary function applied to a single child
  | Bin Op val val   -- ^ binary operator applied to a left and a right child
  deriving (Show, Eq, Ord, Functor)
-- | Binary operators that can label a 'Bin' node.
data Op
  = Add
  | Sub
  | Mul
  | Div
  | Power
  deriving (Show, Read, Eq, Ord, Enum)
-- | Unary functions that can label a 'Uni' node.
data Function
  = Id
  | Abs
  | Sin
  | Cos
  | Tan
  | Sinh
  | Cosh
  | Tanh
  | ASin
  | ACos
  | ATan
  | ASinh
  | ACosh
  | ATanh
  | Sqrt
  | Cbrt
  | Square
  | Log
  | Exp
  deriving (Show, Read, Eq, Ord, Enum)
-- | Build a leaf tree holding the variable with the given index.
var :: Int -> Fix SRTree
var = Fix . Var
-- | Build a leaf tree holding the parameter with the given index.
param :: Int -> Fix SRTree
param = Fix . Param
-- | Arithmetic on expression trees builds new trees.
--
-- The operators act as smart constructors: neutral and absorbing
-- constants are simplified away and two constant operands are folded
-- into a single 'Const'. Clause order matters, the more specific
-- patterns are tried first.
instance Num (Fix SRTree) where
  -- 0 + b = b, a + 0 = a, constants are folded
  Fix (Const 0) + b             = b
  a             + Fix (Const 0) = a
  Fix (Const x) + Fix (Const y) = Fix (Const (x + y))
  a             + b             = Fix (Bin Add a b)
  {-# INLINE (+) #-}

  -- a - 0 = a, 0 - b = negate b, constants are folded
  a             - Fix (Const 0) = a
  Fix (Const 0) - b             = negate b
  Fix (Const x) - Fix (Const y) = Fix (Const (x - y))
  a             - b             = Fix (Bin Sub a b)
  {-# INLINE (-) #-}

  -- 0 absorbs, 1 is neutral, constants are folded
  Fix (Const 0) * _             = Fix (Const 0)
  _             * Fix (Const 0) = Fix (Const 0)
  Fix (Const 1) * b             = b
  a             * Fix (Const 1) = a
  Fix (Const x) * Fix (Const y) = Fix (Const (x * y))
  a             * b             = Fix (Bin Mul a b)
  {-# INLINE (*) #-}

  abs t = Fix (Uni Abs t)
  {-# INLINE abs #-}

  -- constants are negated in place, anything else becomes (-1) * t
  negate (Fix (Const x)) = Fix (Const (negate x))
  negate t               = Fix (Const (-1)) * t
  {-# INLINE negate #-}

  -- only defined meaningfully for constants; every other tree maps to 0
  signum (Fix (Const x)) = Fix (Const (signum x))
  signum _               = Fix (Const 0)

  fromInteger n = Fix (Const (fromInteger n))
  {-# INLINE fromInteger #-}
-- | Division on expression trees builds a 'Div' node, except that
-- dividing by the constant 1 returns the numerator unchanged and a
-- quotient of two constants is folded into one 'Const'.
instance Fractional (Fix SRTree) where
  a             / Fix (Const 1) = a
  Fix (Const x) / Fix (Const y) = Fix (Const (x / y))
  a             / b             = Fix (Bin Div a b)
  {-# INLINE (/) #-}

  fromRational q = Fix (Const (fromRational q))
  {-# INLINE fromRational #-}
-- | Transcendental functions on expression trees.
--
-- Each unary method wraps its argument in the matching 'Uni' node
-- without further simplification. '**' and 'logBase' additionally
-- simplify a few constant cases.
instance Floating (Fix SRTree) where
  pi = Fix (Const pi)
  {-# INLINE pi #-}
  exp = Fix . Uni Exp
  {-# INLINE exp #-}
  log = Fix . Uni Log
  {-# INLINE log #-}
  sqrt = Fix . Uni Sqrt
  {-# INLINE sqrt #-}
  sin = Fix . Uni Sin
  {-# INLINE sin #-}
  cos = Fix . Uni Cos
  {-# INLINE cos #-}
  tan = Fix . Uni Tan
  {-# INLINE tan #-}
  asin = Fix . Uni ASin
  {-# INLINE asin #-}
  acos = Fix . Uni ACos
  {-# INLINE acos #-}
  atan = Fix . Uni ATan
  {-# INLINE atan #-}
  sinh = Fix . Uni Sinh
  {-# INLINE sinh #-}
  cosh = Fix . Uni Cosh
  {-# INLINE cosh #-}
  tanh = Fix . Uni Tanh
  {-# INLINE tanh #-}
  asinh = Fix . Uni ASinh
  {-# INLINE asinh #-}
  acosh = Fix . Uni ACosh
  {-# INLINE acosh #-}
  atanh = Fix . Uni ATanh
  {-# INLINE atanh #-}

  -- l ** 1 = l and l ** 0 = 1; the exponent-1 rule is tried first
  l ** Fix (Const 1) = l
  _ ** Fix (Const 0) = Fix (Const 1)
  l ** r             = Fix (Bin Power l r)
  {-# INLINE (**) #-}

  -- Follows the Prelude contract: the first argument is the base and
  -- the second is the value, i.e. @logBase base x = log x / log base@.
  -- The logarithm of 1 is 0 in any base, so that case is simplified.
  logBase _    (Fix (Const 1)) = Fix (Const 0)
  logBase base x               = log x / log base
  {-# INLINE logBase #-}
-- | Number of children of the root node: 0 for leaves ('Var', 'Param',
-- 'Const'), 1 for 'Uni' and 2 for 'Bin'.
arity :: Fix SRTree -> Int
arity (Fix node) = case node of
  Var {}   -> 0
  Param {} -> 0
  Const {} -> 0
  Uni {}   -> 1
  Bin {}   -> 2
{-# INLINE arity #-}
-- | Immediate subtrees of the root, in left-to-right order.
-- Leaves have no children.
getChildren :: Fix SRTree -> [Fix SRTree]
getChildren (Fix node) = case node of
  Var {}    -> []
  Param {}  -> []
  Const {}  -> []
  Uni _ t   -> [t]
  Bin _ l r -> [l, r]
{-# INLINE getChildren #-}
-- | Total number of nodes in the tree (leaves and internal nodes).
countNodes :: Fix SRTree -> Int
countNodes = cata size
  where
    -- every node counts for one, plus whatever its children counted
    size node = case node of
      Uni _ t   -> 1 + t
      Bin _ l r -> 1 + l + r
      _         -> 1
{-# INLINE countNodes #-}
-- | Number of 'Var' leaves in the tree (repeated variables are counted
-- once per occurrence).
countVarNodes :: Fix SRTree -> Int
countVarNodes = cata tally
  where
    tally node = case node of
      Var {}    -> 1
      Param {}  -> 0
      Const {}  -> 0
      Uni _ t   -> t
      Bin _ l r -> l + r
{-# INLINE countVarNodes #-}
-- | Number of 'Param' leaves in the tree.
countParams :: Fix SRTree -> Int
countParams = cata tally
  where
    tally node = case node of
      Var {}    -> 0
      Param {}  -> 1
      Const {}  -> 0
      Uni _ t   -> t
      Bin _ l r -> l + r
{-# INLINE countParams #-}
-- | Number of 'Const' leaves in the tree.
countConsts :: Fix SRTree -> Int
countConsts = cata tally
  where
    tally node = case node of
      Var {}    -> 0
      Param {}  -> 0
      Const {}  -> 1
      Uni _ t   -> t
      Bin _ l r -> l + r
{-# INLINE countConsts #-}
-- | Number of 'Var' leaves whose index equals the given one.
countOccurrences :: Int -> Fix SRTree -> Int
countOccurrences ix = cata hits
  where
    hits node = case node of
      Var iy
        | ix == iy  -> 1
        | otherwise -> 0
      Param {}  -> 0
      Const {}  -> 0
      Uni _ t   -> t
      Bin _ l r -> l + r
{-# INLINE countOccurrences #-}
-- | Evaluate a tree bottom-up.
--
-- * @xss@ supplies the value of each variable, looked up by the 'Var' index.
-- * @params@ supplies the value of each parameter, looked up by the
--   'Param' index and then converted with @f@.
-- * @f@ lifts a 'Double' (a parameter value or a 'Const') into the
--   result type.
--
-- Lookups use 'V.!', so every index in the tree must be in bounds for
-- the corresponding vector.
evalTree :: (Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> a
evalTree xss params f = cata step
  where
    step node = case node of
      Var ix     -> xss V.! ix
      Param ix   -> f (params V.! ix)
      Const c    -> f c
      Uni g t    -> evalFun g t
      Bin op l r -> evalOp op l r
{-# INLINE evalTree #-}
-- | Map a binary operator tag to the numeric operation it denotes.
evalOp :: Floating a => Op -> a -> a -> a
evalOp op = case op of
  Add   -> (+)
  Sub   -> (-)
  Mul   -> (*)
  Div   -> (/)
  Power -> (**)
{-# INLINE evalOp #-}
-- | Map a unary function tag to the numeric function it denotes.
-- 'Cbrt' uses the local 'cbrt' helper and 'Square' is @(^ 2)@.
evalFun :: Floating a => Function -> a -> a
evalFun fn = case fn of
  Id     -> id
  Abs    -> abs
  Sin    -> sin
  Cos    -> cos
  Tan    -> tan
  Sinh   -> sinh
  Cosh   -> cosh
  Tanh   -> tanh
  ASin   -> asin
  ACos   -> acos
  ATan   -> atan
  ASinh  -> asinh
  ACosh  -> acosh
  ATanh  -> atanh
  Sqrt   -> sqrt
  Cbrt   -> cbrt
  Square -> (^ 2)
  Log    -> log
  Exp    -> exp
{-# INLINE evalFun #-}
-- | Cube root, computed as the sign of the argument times the
-- @1/3@ power of its absolute value.
cbrt :: Floating val => val -> val
cbrt x = signum x * (magnitude ** (1 / 3))
  where
    magnitude = abs x
{-# INLINE cbrt #-}
-- | Tag of the function that undoes the given one.
--
-- Supported pairs: 'Id' with itself, each of 'Sin', 'Cos', 'Tan',
-- 'Sinh', 'Cosh', 'Tanh' with its arc counterpart (in both directions),
-- 'Sqrt' with 'Square', and 'Log' with 'Exp'.
--
-- Any other function ('Abs', 'Cbrt') has no inverse among the available
-- tags and raises an 'error'.
inverseFunc :: Function -> Function
inverseFunc fn = case fn of
  Id     -> Id
  Sin    -> ASin
  Cos    -> ACos
  Tan    -> ATan
  Sinh   -> ASinh
  Cosh   -> ACosh
  Tanh   -> ATanh
  ASin   -> Sin
  ACos   -> Cos
  ATan   -> Tan
  ASinh  -> Sinh
  ACosh  -> Cosh
  ATanh  -> Tanh
  Sqrt   -> Square
  Square -> Sqrt
  Log    -> Exp
  Exp    -> Log
  _      -> error $ show fn ++ " has no support for inverse function"
{-# INLINE inverseFunc #-}
-- | Symbolic derivative of a tree with respect to one variable or one
-- parameter.
--
-- The flag @p@ selects what the index @dx@ refers to: when @p@ is
-- 'True' the derivative is taken with respect to @'Param' dx@, otherwise
-- with respect to @'Var' dx@.
--
-- Two results are computed together for every subtree: its derivative
-- and the subtree itself, since the product, quotient, power and chain
-- rules need the original operands.
deriveBy :: Bool -> Int -> Fix SRTree -> Fix SRTree
deriveBy p dx = fst (mutu diff rebuild)
  where
    -- derivative of a leaf: 1 if it is the differentiation target, else 0
    hit isTarget ix
      | isTarget && ix == dx = 1
      | otherwise            = 0

    -- each child arrives as (derivative, original subtree)
    diff (Var ix)                 = hit (not p) ix
    diff (Param ix)               = hit p ix
    diff (Const _)                = 0
    diff (Uni f (dt, t))          = derivative f t * dt -- chain rule
    diff (Bin op (dl, l) (dr, r)) = case op of
      Add   -> dl + dr
      Sub   -> dl - dr
      Mul   -> dl * r + l * dr
      Div   -> (dl * r - l * dr) / r ** 2
      Power -> l ** (r - 1) * (r * dl + l * log l * dr)

    -- reassemble the original subtree from the children's originals
    rebuild (Var ix)               = var ix
    rebuild (Param ix)             = param ix
    rebuild (Const c)              = Fix (Const c)
    rebuild (Uni f (_, t))         = Fix (Uni f t)
    rebuild (Bin op (_, l) (_, r)) = Fix (Bin op l r)
-- | A list of values that is combined element-wise by its numeric instances.
-- 'forwardMode' uses one slot per parameter.
--
-- Binary operations pair the slots of both operands with 'zipWith', so the
-- result is as long as the shorter operand.  Numeric literals ('fromInteger',
-- 'fromRational') and 'pi' therefore denote a /constant, infinite/ tape: they
-- adapt to the length of whatever finite tape they are combined with instead
-- of truncating it to a single slot.  As a consequence, never 'show' or fully
-- traverse a tape that was built from literals only.
newtype Tape a = Tape { untape :: [a] } deriving (Show, Functor)

-- | Slot-wise arithmetic; integer literals broadcast to every slot.
instance Num a => Num (Tape a) where
  Tape x + Tape y = Tape (zipWith (+) x y)
  Tape x - Tape y = Tape (zipWith (-) x y)
  Tape x * Tape y = Tape (zipWith (*) x y)
  abs             = fmap abs
  signum          = fmap signum
  -- An infinite constant tape, so that @tape + 1@ keeps the length of @tape@.
  fromInteger     = Tape . repeat . fromInteger
  negate          = fmap (* (-1))

-- | Slot-wise transcendental functions; 'pi' broadcasts to every slot.
instance Floating a => Floating (Tape a) where
  -- An infinite constant tape, for the same reason as 'fromInteger'.
  pi    = Tape (repeat pi)
  exp   = fmap exp
  log   = fmap log
  sqrt  = fmap sqrt
  sin   = fmap sin
  cos   = fmap cos
  tan   = fmap tan
  asin  = fmap asin
  acos  = fmap acos
  atan  = fmap atan
  sinh  = fmap sinh
  cosh  = fmap cosh
  tanh  = fmap tanh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh
  Tape x ** Tape y = Tape (zipWith (**) x y)

-- | Slot-wise division; rational literals broadcast to every slot.
instance Fractional a => Fractional (Tape a) where
  -- An infinite constant tape, so that @tape ** 2.0@ keeps the length of @tape@.
  fromRational    = Tape . repeat . fromRational
  Tape x / Tape y = Tape (zipWith (/) x y)
  recip           = fmap recip
-- | Forward-mode differentiation of a tree with respect to all parameters.
--
-- Arguments, in order: the variable values @xss@ (indexed by the index stored
-- in 'Var'), the parameter values @theta@ (indexed by the index stored in
-- 'Param'), a conversion @f@ from 'Double' into the working type, and the tree.
--
-- The result has one entry per element of @theta@; entry @i@ is the partial
-- derivative of the tree with respect to parameter @i@.
--
-- Two tapes are computed together with 'mutu': @alg1@ yields the derivative
-- tape of a node and @alg2@ its value replicated over every slot.  Constants
-- that take part in tape arithmetic are built with @repMat@ (never with bare
-- numeric literals) and unary derivatives are mapped slot by slot, so that no
-- operand can be shorter than the number of parameters.
forwardMode :: (Show a, Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> [a]
forwardMode xss theta f = untape . fst (mutu alg1 alg2)
  where
    n = V.length theta

    -- Replicate a scalar over every slot of a tape.
    repMat v = Tape $ replicate n v
    zeroes   = repMat $ f 0
    ones     = repMat $ f 1
    twos     = repMat $ f 2

    -- paramVec ! ix is the one-hot tape: derivative of parameter ix
    -- with respect to every parameter.
    paramVec = V.generate n oneHot
    oneHot ix = Tape [ if ix == iy then f 1 else f 0 | iy <- [0 .. n - 1] ]

    -- Derivative tape of a node; each child is (derivative tape, value tape).
    alg1 (Var _)         = zeroes
    alg1 (Param ix)      = paramVec ! ix
    alg1 (Const _)       = zeroes
    -- chain rule; 'derivative' is applied per slot, as 'evalFun' is in alg2
    alg1 (Uni g t)       = fmap (derivative g) (snd t) * fst t
    alg1 (Bin Add l r)   = fst l + fst r
    alg1 (Bin Sub l r)   = fst l - fst r
    alg1 (Bin Mul l r)   = (fst l * snd r) + (snd l * fst r)
    alg1 (Bin Div l r)   = ((fst l * snd r) - (snd l * fst r)) / snd r ** twos
    alg1 (Bin Power l r) =
      snd l ** (snd r - ones) * ((snd r * fst l) + (snd l * log (snd l) * fst r))

    -- Value tape of a node: the node's value in every slot.
    alg2 (Var ix)     = repMat (xss ! ix)
    alg2 (Param ix)   = repMat (f (theta ! ix))
    alg2 (Const c)    = repMat $ f c
    alg2 (Uni g t)    = fmap (evalFun g) (snd t)
    alg2 (Bin op l r) = evalOp op (snd l) (snd r)
-- | Evaluate a tree and collect its parameter gradient in one bottom-up pass.
--
-- Arguments, in order: the variable values @xss@, the parameter values
-- @theta@, a conversion @f@ from 'Double' into the working type, and the tree.
--
-- Returns the value of the tree together with one gradient entry per 'Param'
-- node, listed in left-to-right order of occurrence in the tree.
gradParamsFwd :: (Show a, Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> (a, [a])
gradParamsFwd xss theta f tree = second DL.toList (cata step tree)
  where
    -- Multiply every pending gradient entry by a local derivative.
    scale k = DL.map (* k)

    -- Each child arrives as (value, gradient entries of the parameters below).
    step node = case node of
      Var ix   -> (xss ! ix, DL.empty)
      Param ix -> (f (theta ! ix), DL.singleton 1)
      Const c  -> (f c, DL.empty)
      Uni g (v, gs) ->
        (evalFun g v, scale (derivative g v) gs)
      Bin Add (x, gx) (y, gy) ->
        (x + y, DL.append gx gy)
      Bin Sub (x, gx) (y, gy) ->
        (x - y, DL.append gx (DL.map negate gy))
      Bin Mul (x, gx) (y, gy) ->
        (x * y, DL.append (scale y gx) (scale x gy))
      Bin Div (x, gx) (y, gy) ->
        -- d(x/y)/dy = -(x / y^2)
        (x / y, DL.append (DL.map (/ y) gx) (scale (negate (x / y ^ 2)) gy))
      Bin Power (x, gx) (y, gy) ->
        -- x ** (y - 1) is the factor shared by both partial derivatives
        (x ** y, scale (x ** (y - 1)) (DL.append (scale y gx) (scale (x * log x) gy)))
-- | Base functor of a tree of values: every node carries a value of type @a@
-- and has zero, one or two children of type @b@.
data TupleF a b
  = S a      -- ^ node without children
  | T a b    -- ^ node with a single child
  | B a b b  -- ^ node with two children
  deriving Functor

-- | Fixed point of 'TupleF': a tree with a value of type @a@ at every node.
type Tuple a = Fix (TupleF a)
-- | Evaluate a tree and compute its parameter gradient in reverse mode.
--
-- Arguments, in order: the variable values @xss@, the parameter values
-- @theta@, a conversion @f@ from 'Double' into the working type, and the tree.
--
-- A first pass ('cata') stores the value of every node in a tree of the same
-- shape; a second pass ('accu') pushes the accumulated derivative from the
-- root down to the leaves, reading the stored values on the way.
--
-- Returns the value of the tree together with one gradient entry per 'Param'
-- node, listed in left-to-right order of occurrence in the tree.
gradParamsRev :: forall a . (Show a, Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> (a, [a])
gradParamsRev xss theta f t = (valueOf cache, DL.toList grads)
  where
    -- Values of all nodes, laid out in the shape of the expression tree.
    cache :: Fix (TupleF a)
    cache = cata record t

    -- Gradient entries, seeded with derivative 1 at the root.
    grads :: DL.DList a
    grads = accu pushDown collect t (1, cache)

    -- Value stored at the root of a cached subtree.
    valueOf :: Fix (TupleF a) -> a
    valueOf (Fix node) = case node of
      S x     -> x
      T x _   -> x
      B x _ _ -> x

    -- Cached child of a unary node (only defined for 'T').
    childOf :: Fix (TupleF a) -> Fix (TupleF a)
    childOf (Fix (T _ sub)) = sub

    -- Cached children of a binary node (only defined for 'B').
    childrenOf :: Fix (TupleF a) -> (Fix (TupleF a), Fix (TupleF a))
    childrenOf (Fix (B _ l r)) = (l, r)

    -- Forward pass: store the value of the node on top of its children's caches.
    record :: SRTree (Fix (TupleF a)) -> Fix (TupleF a)
    record node = case node of
      Var ix     -> leaf (xss ! ix)
      Param ix   -> leaf (f (theta ! ix))
      Const c    -> leaf (f c)
      Uni g sub  -> Fix (T (evalFun g (valueOf sub)) sub)
      Bin op l r -> Fix (B (evalOp op (valueOf l) (valueOf r)) l r)
      where
        leaf = Fix . S

    -- Downward step: attach to every child the derivative that reaches it
    -- and the cache that belongs to it.
    pushDown :: SRTree x -> (a, Fix (TupleF a)) -> SRTree (x, (a, Fix (TupleF a)))
    pushDown (Var ix)    (_, _) = Var ix
    pushDown (Param ix)  (_, _) = Param ix
    pushDown (Const c)   (_, _) = Const c
    pushDown (Uni g sub) (dx, seen) =
      Uni g (sub, (dx * derivative g (valueOf inner), inner))
      where
        inner = childOf seen
    pushDown (Bin op l r) (dx, seen) =
      case childrenOf seen of
        (cl, cr) ->
          let (dl, dr) = adjoints op dx (valueOf cl) (valueOf cr)
          in  Bin op (l, (dl, cl)) (r, (dr, cr))

    -- Derivatives passed to the left and right operand of a binary operator,
    -- given the incoming derivative and both operand values.
    adjoints :: Op -> a -> a -> a -> (a, a)
    adjoints op dx vl vr = case op of
      Add   -> (dx, dx)
      Sub   -> (dx, negate dx)
      Mul   -> (dx * vr, dx * vl)
      Div   -> (dx / vr, dx * negate (vl / vr ^ 2))
      Power ->
        -- dx * vl ** (vr - 1) is shared by both partial derivatives
        let shared = dx * vl ** (vr - 1)
        in  (shared * vr, shared * (vl * log vl))

    -- Upward step: a parameter emits the derivative that reached it, inner
    -- nodes concatenate what their children emitted.
    collect :: SRTree (DL.DList a) -> (a, Fix (TupleF a)) -> DL.DList a
    collect node s = case node of
      Var _     -> DL.empty
      Param _   -> DL.singleton (fst s)
      Const _   -> DL.empty
      Uni _ gs  -> gs
      Bin _ l r -> DL.append l r
-- | First derivative of a unary 'Function', as a function of its argument.
--
-- The derivative of 'Abs' is computed as @x / abs x@, which is @0 / 0@ at
-- zero.
derivative :: Floating a => Function -> a -> a
derivative fn = case fn of
  Id     -> \_ -> 1
  Abs    -> \x -> x / abs x
  Sin    -> \x -> cos x
  Cos    -> \x -> negate (sin x)
  Tan    -> \x -> recip (cos x ** 2.0)
  Sinh   -> \x -> cosh x
  Cosh   -> \x -> sinh x
  Tanh   -> \x -> 1 - tanh x ** 2.0
  ASin   -> \x -> recip (sqrt (1 - x ^ 2))
  ACos   -> \x -> negate (recip (sqrt (1 - x ^ 2)))
  ATan   -> \x -> recip (1 + x ^ 2)
  ASinh  -> \x -> recip (sqrt (1 + x ^ 2))
  ACosh  -> \x -> 1 / (sqrt (x - 1) * sqrt (x + 1))
  ATanh  -> \x -> recip (1 - x ^ 2)
  Sqrt   -> \x -> recip (2 * sqrt x)
  Cbrt   -> \x -> recip (3 * cbrt (x ^ 2))
  Square -> \x -> 2 * x
  Exp    -> \x -> exp x
  Log    -> \x -> recip x
{-# INLINE derivative #-}
-- | Symbolic derivative of a tree with respect to the variable with the given
-- index ('deriveBy' with its flag set to 'False').
deriveByVar :: Int -> Fix SRTree -> Fix SRTree
deriveByVar ix = deriveBy False ix
-- | Symbolic derivative of a tree with respect to the parameter with the
-- given index ('deriveBy' with its flag set to 'True').
deriveByParam :: Int -> Fix SRTree -> Fix SRTree
deriveByParam ix = deriveBy True ix
-- | Renumber the parameters of a tree as @0, 1, 2, ...@ in left-to-right
-- order of occurrence.  The previous parameter indices are discarded; every
-- other node is rebuilt unchanged.
relabelParams :: Fix SRTree -> Fix SRTree
relabelParams t = evalState (cataM sequenceLayer renumber t) 0
  where
    -- Run the effects of the children of one layer from left to right.
    sequenceLayer layer = case layer of
      Var ix       -> pure (Var ix)
      Param ix     -> pure (Param ix)
      Const c      -> pure (Const c)
      Uni g mt     -> Uni g <$> mt
      Bin op ml mr -> Bin op <$> ml <*> mr

    -- The state is the next free parameter index.
    renumber :: SRTree (Fix SRTree) -> State Int (Fix SRTree)
    renumber (Var ix)  = pure (var ix)
    renumber (Param _) = do
      next <- get
      modify (+ 1)
      pure (param next)
    renumber node      = pure (Fix node)
-- | Turn every constant of the tree into a parameter.
--
-- Returns the rewritten tree together with the list of parameter values:
-- each 'Const' contributes its own value, each pre-existing 'Param'
-- contributes @1.0@, and variables contribute nothing. For binary nodes
-- the values of the left subtree come before those of the right subtree.
-- The rewritten tree is finally passed through 'relabelParams', so the
-- placeholder index @0@ given to former constants is replaced there.
constsToParam :: Fix SRTree -> (Fix SRTree, [Double])
constsToParam tree =
  case cata collect tree of
    (replaced, values) -> (relabelParams replaced, values)
  where
    -- The tuple patterns are irrefutable so that the children are only
    -- demanded when the result is, exactly as with 'fst' / 'snd'.
    collect :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
    collect (Var ix)                   = (Fix (Var ix), [])
    collect (Param ix)                 = (Fix (Param ix), [1.0])
    collect (Const c)                  = (Fix (Param 0), [c])
    collect (Uni f ~(child, vs))       = (Fix (Uni f child), vs)
    collect (Bin op ~(l, ls) ~(r, rs)) = (Fix (Bin op l r), ls ++ rs)
-- | Turn every non-integral constant of the tree into a parameter.
--
-- Works like a constant-to-parameter conversion restricted to constants
-- whose 'floor' and 'ceiling' differ: those become @Param 0@ placeholders
-- and contribute their value to the returned list, while constants whose
-- 'floor' equals their 'ceiling' stay as 'Const' and contribute nothing.
-- Each pre-existing 'Param' contributes @1.0@; variables contribute
-- nothing. For binary nodes the left subtree's values precede the right
-- subtree's. The rewritten tree is finally passed through 'relabelParams'.
floatConstsToParam :: Fix SRTree -> (Fix SRTree, [Double])
floatConstsToParam tree =
  case cata collect tree of
    (replaced, values) -> (relabelParams replaced, values)
  where
    -- The tuple patterns are irrefutable so that the children are only
    -- demanded when the result is, exactly as with 'fst' / 'snd'.
    collect :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
    collect (Var ix)   = (Fix (Var ix), [])
    collect (Param ix) = (Fix (Param ix), [1.0])
    collect (Const c)
      | isWhole c = (Fix (Const c), [])
      | otherwise = (Fix (Param 0), [c])
    collect (Uni f ~(child, vs))       = (Fix (Uni f child), vs)
    collect (Bin op ~(l, ls) ~(r, rs)) = (Fix (Bin op l r), ls ++ rs)

    -- A value counts as whole when rounding down and rounding up agree.
    isWhole :: Double -> Bool
    isWhole x = (floor x :: Integer) == ceiling x
-- | Replace every parameter of the tree by the constant stored at the
-- parameter's index in @theta@.
--
-- Variables, constants and operator nodes are rebuilt unchanged. The
-- lookup is lazy: a parameter whose index has no entry in @theta@ only
-- raises an error once the resulting constant is actually evaluated, and
-- the error names this function and the offending index.
paramsToConst :: [Double] -> Fix SRTree -> Fix SRTree
paramsToConst theta = cata alg
  where
    alg :: SRTree (Fix SRTree) -> Fix SRTree
    alg (Param ix) = Fix (Const (thetaAt ix))
    alg other      = Fix other

    -- Total replacement for @theta !! ix@. 'drop' silently accepts
    -- negative counts, so those are rejected explicitly.
    thetaAt :: Int -> Double
    thetaAt ix
      | ix < 0    = missing ix
      | otherwise = case drop ix theta of
          (v : _) -> v
          []      -> missing ix

    missing :: Int -> Double
    missing ix =
      error ("paramsToConst: no value in theta for parameter index " ++ show ix)