{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Internal 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2021
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  FlexibleInstances, DeriveFunctor, ScopedTypeVariables, RankNTypes
--
-- Expression tree for Symbolic Regression
--
-----------------------------------------------------------------------------

module Data.SRTree.Internal
         ( SRTree(..)
         , Function(..)
         , Op(..)
         , param
         , var
         , arity
         , getChildren
         , countNodes
         , countVarNodes
         , countConsts
         , countParams
         , countOccurrences
         , deriveBy
         , deriveByVar
         , deriveByParam
         , derivative
         , forwardMode
         , gradParams
         , evalFun
         , evalOp
         , inverseFunc
         , evalTree
         , relabelParams
         , constsToParam
         , floatConstsToParam
         )
         where

import Data.SRTree.Recursion ( Fix(Fix), cata, mutu, cataM )

import qualified Data.Vector as V
import Data.Vector ((!))
import Control.Monad.State

import Debug.Trace (trace)

-- | Tree structure to be used with Symbolic Regression algorithms.
--
-- An 'SRTree' is a single layer of an n-ary expression tree whose children
-- have type @val@; its fixed point, @Fix SRTree@, is the full expression tree.
data SRTree val
  = Var Int           -- ^ index of the variables
  | Param Int         -- ^ index of the parameter
  | Const Double      -- ^ constant value, can be converted to a parameter
  | Uni Function val  -- ^ univariate function
  | Bin Op val val    -- ^ binary operator
  deriving (Show, Eq, Ord, Functor)

-- | Supported binary operators.
--
-- Constructor order determines the derived 'Ord' and 'Enum' instances
-- (@fromEnum Add == 0@ up to @fromEnum Power == 4@).
data Op
  = Add    -- ^ addition
  | Sub    -- ^ subtraction
  | Mul    -- ^ multiplication
  | Div    -- ^ division
  | Power  -- ^ exponentiation
  deriving (Show, Read, Eq, Ord, Enum)

-- | Supported univariate functions.
--
-- Constructor order determines the derived 'Ord' and 'Enum' instances.
data Function
  = Id      -- ^ identity
  | Abs     -- ^ absolute value
  | Sin     -- ^ sine
  | Cos     -- ^ cosine
  | Tan     -- ^ tangent
  | Sinh    -- ^ hyperbolic sine
  | Cosh    -- ^ hyperbolic cosine
  | Tanh    -- ^ hyperbolic tangent
  | ASin    -- ^ inverse sine
  | ACos    -- ^ inverse cosine
  | ATan    -- ^ inverse tangent
  | ASinh   -- ^ inverse hyperbolic sine
  | ACosh   -- ^ inverse hyperbolic cosine
  | ATanh   -- ^ inverse hyperbolic tangent
  | Sqrt    -- ^ square root
  | Cbrt    -- ^ cubic root
  | Square  -- ^ square
  | Log     -- ^ natural logarithm
  | Exp     -- ^ exponential
  deriving (Show, Read, Eq, Ord, Enum)

-- | Create a tree consisting of a single 'Var' node for the given variable
-- index.
var :: Int -> Fix SRTree
var = Fix . Var

-- | Create a tree consisting of a single 'Param' node for the given parameter
-- index.
param :: Int -> Fix SRTree
param = Fix . Param

-- | Arithmetic on expression trees.
--
-- Each operator applies a little algebraic simplification while building the
-- tree:
--
-- * additive identity: @0 + t = t@, @t + 0 = t@, @t - 0 = t@,
--   @0 - t = negate t@;
-- * multiplicative identity and annihilator: @1 * t = t@, @t * 1 = t@,
--   @0 * t = 0@, @t * 0 = 0@ (checked before constant folding, so even a
--   NaN constant multiplied by 0 yields 0);
-- * two 'Const' operands are folded into a single 'Const'.
--
-- Anything else becomes a 'Bin' node. 'signum' only folds constants: for any
-- non-'Const' tree it returns the constant 0.
instance Num (Fix SRTree) where
  l + r = case (l, r) of
    (Fix (Const 0), _)               -> r
    (_, Fix (Const 0))               -> l
    (Fix (Const c1), Fix (Const c2)) -> Fix (Const (c1 + c2))
    _                                -> Fix (Bin Add l r)
  {-# INLINE (+) #-}

  l - r = case (l, r) of
    (_, Fix (Const 0))               -> l
    (Fix (Const 0), _)               -> negate r
    (Fix (Const c1), Fix (Const c2)) -> Fix (Const (c1 - c2))
    _                                -> Fix (Bin Sub l r)
  {-# INLINE (-) #-}

  l * r = case (l, r) of
    (Fix (Const 0), _)               -> Fix (Const 0)
    (_, Fix (Const 0))               -> Fix (Const 0)
    (Fix (Const 1), _)               -> r
    (_, Fix (Const 1))               -> l
    (Fix (Const c1), Fix (Const c2)) -> Fix (Const (c1 * c2))
    _                                -> Fix (Bin Mul l r)
  {-# INLINE (*) #-}

  abs t = Fix (Uni Abs t)
  {-# INLINE abs #-}

  negate t = case t of
    Fix (Const x) -> Fix (Const (negate x))
    -- non-constants are negated by multiplying with the constant -1
    _             -> Fix (Const (-1)) * t
  {-# INLINE negate #-}

  signum t = case t of
    Fix (Const x) -> Fix (Const (signum x))
    _             -> Fix (Const 0)

  fromInteger n = Fix (Const (fromInteger n))
  {-# INLINE fromInteger #-}

-- | Division on expression trees.
--
-- @t / 1@ simplifies to @t@, and two constants are folded into one (using
-- 'Double' division, so a zero divisor yields an infinite or NaN constant).
-- Otherwise a 'Div' node is built. Rational literals become 'Const' nodes.
instance Fractional (Fix SRTree) where
  l / r = case (l, r) of
    (_, Fix (Const 1))               -> l
    (Fix (Const c1), Fix (Const c2)) -> Fix (Const (c1 / c2))
    _                                -> Fix (Bin Div l r)
  {-# INLINE (/) #-}

  fromRational q = Fix (Const (fromRational q))
  {-# INLINE fromRational #-}

-- | Transcendental functions on expression trees.
--
-- Unary functions wrap their argument in the matching 'Uni' node without any
-- simplification. @t ** 1@ simplifies to @t@ and @t ** 0@ to @1@; other powers
-- build a 'Power' node. 'logBase' follows the Prelude convention (base first):
-- @logBase b x = log x / log b@.
instance Floating (Fix SRTree) where
  pi      = Fix $ Const pi
  {-# INLINE pi #-}
  exp     = Fix . Uni Exp
  {-# INLINE exp #-}
  log     = Fix . Uni Log
  {-# INLINE log #-}
  sqrt    = Fix . Uni Sqrt
  {-# INLINE sqrt #-}
  sin     = Fix . Uni Sin
  {-# INLINE sin #-}
  cos     = Fix . Uni Cos
  {-# INLINE cos #-}
  tan     = Fix . Uni Tan
  {-# INLINE tan #-}
  asin    = Fix . Uni ASin
  {-# INLINE asin #-}
  acos    = Fix . Uni ACos
  {-# INLINE acos #-}
  atan    = Fix . Uni ATan
  {-# INLINE atan #-}
  sinh    = Fix . Uni Sinh
  {-# INLINE sinh #-}
  cosh    = Fix . Uni Cosh
  {-# INLINE cosh #-}
  tanh    = Fix . Uni Tanh
  {-# INLINE tanh #-}
  asinh   = Fix . Uni ASinh
  {-# INLINE asinh #-}
  acosh   = Fix . Uni ACosh
  {-# INLINE acosh #-}
  atanh   = Fix . Uni ATanh
  {-# INLINE atanh #-}

  l ** Fix (Const 1) = l
  _ ** Fix (Const 0) = Fix (Const 1)
  l ** r             = Fix $ Bin Power l r
  {-# INLINE (**) #-}

  -- The first argument is the base (Prelude convention), so the logarithm of
  -- 1 is 0 regardless of the base, and the general case is log x / log b.
  logBase _ (Fix (Const 1)) = Fix (Const 0)
  logBase b x               = log x / log b
  {-# INLINE logBase #-}

-- | Arity of the root node: 0 for 'Var', 'Param' and 'Const', 1 for 'Uni'
-- and 2 for 'Bin'. Only the root is inspected; children are never visited.
arity :: Fix SRTree -> Int
arity (Fix node) = case node of
  Uni {} -> 1
  Bin {} -> 2
  _      -> 0
{-# INLINE arity #-}

-- | Direct children of the root node, from left to right. Leaf nodes
-- ('Var', 'Param', 'Const') have no children and yield the empty list.
getChildren :: Fix SRTree -> [Fix SRTree]
getChildren (Fix node) = case node of
  Uni _ t   -> [t]
  Bin _ l r -> [l, r]
  _         -> []
{-# INLINE getChildren #-}

-- | Total number of nodes (leaves and internal nodes) in a tree.
countNodes :: Fix SRTree -> Int
countNodes = cata step
  where
    -- every node contributes 1, plus whatever its children counted
    step :: SRTree Int -> Int
    step (Uni _ t)   = 1 + t
    step (Bin _ l r) = 1 + l + r
    step _           = 1
{-# INLINE countNodes #-}

-- | Number of 'Var' nodes in a tree, counting repeated variables every time
-- they occur.
countVarNodes :: Fix SRTree -> Int
countVarNodes = cata step
  where
    step :: SRTree Int -> Int
    step (Var _)     = 1
    step (Uni _ t)   = t
    step (Bin _ l r) = l + r
    step _           = 0
{-# INLINE countVarNodes #-}

-- | Number of 'Param' nodes in a tree, counting repeated parameter indices
-- every time they occur.
countParams :: Fix SRTree -> Int
countParams = cata step
  where
    step :: SRTree Int -> Int
    step (Param _)   = 1
    step (Uni _ t)   = t
    step (Bin _ l r) = l + r
    step _           = 0
{-# INLINE countParams #-}

-- | Number of 'Const' nodes in a tree.
countConsts :: Fix SRTree -> Int
countConsts = cata step
  where
    step :: SRTree Int -> Int
    step (Const _)   = 1
    step (Uni _ t)   = t
    step (Bin _ l r) = l + r
    step _           = 0
{-# INLINE countConsts #-}

-- | Count how many 'Var' nodes refer to the variable with index @ix@.
--
-- Each node yields an 'Int' count that is summed on the way up, so no
-- intermediate lists are built or concatenated.
countOccurrences :: Int -> Fix SRTree -> Int
countOccurrences ix = cata alg
  where
    alg :: SRTree Int -> Int
    alg (Var iy)
      | iy == ix  = 1
      | otherwise = 0
    alg (Uni _ t)   = t
    alg (Bin _ l r) = l + r
    alg _           = 0   -- Param and Const never match a variable
{-# INLINE countOccurrences #-}

-- | Evaluates the tree given a vector of variable values, a vector of
-- parameter values, and a function that converts a 'Double' into whatever
-- type the variables have. The conversion is applied to parameters and
-- constants, which makes the function useful when each variable holds many
-- values (e.g. a whole dataset column).
--
-- 'Var' and 'Param' indices are looked up with 'Data.Vector.!', so an
-- out-of-range index raises an error.
evalTree :: (Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> a
evalTree xss params f = cata go
  where
    go node = case node of
      Var ix     -> xss ! ix
      Param ix   -> f (params ! ix)
      Const c    -> f c
      Uni g t    -> evalFun g t
      Bin op l r -> evalOp op l r
{-# INLINE evalTree #-}

-- | Interpret a binary operator as the corresponding arithmetic function.
evalOp :: Floating a => Op -> a -> a -> a
evalOp op = case op of
  Add   -> (+)
  Sub   -> (-)
  Mul   -> (*)
  Div   -> (/)
  Power -> (**)
{-# INLINE evalOp #-}

-- | Interpret a 'Function' as the corresponding unary numeric function.
evalFun :: Floating a => Function -> a -> a
evalFun fun = case fun of
  Id     -> id
  Abs    -> abs
  Sin    -> sin
  Cos    -> cos
  Tan    -> tan
  Sinh   -> sinh
  Cosh   -> cosh
  Tanh   -> tanh
  ASin   -> asin
  ACos   -> acos
  ATan   -> atan
  ASinh  -> asinh
  ACosh  -> acosh
  ATanh  -> atanh
  Sqrt   -> sqrt
  Cbrt   -> cbrt
  Square -> (^ (2 :: Integer))
  Log    -> log
  Exp    -> exp
{-# INLINE evalFun #-}

-- | Real cubic root. The cube root of the magnitude is taken with '**' and the
-- sign is reapplied, so negative inputs give negative results instead of NaN.
cbrt :: Floating val => val -> val
cbrt x = sign * magnitude ** third
  where
    sign      = signum x
    magnitude = abs x
    third     = 1 / 3
{-# INLINE cbrt #-}

-- | Returns the inverse of a function.
--
-- For functions that are not injective over the whole real line (e.g. 'Sin',
-- 'Cos', 'Cosh', 'Square'), the returned function inverts them only on a
-- restricted domain (e.g. 'ACos' undoes 'Cos' on [0, pi]).
--
-- This is a partial function: it calls 'error' for functions whose inverse is
-- not representable as a 'Function' ('Abs' and 'Cbrt').
inverseFunc :: Function -> Function
inverseFunc Id     = Id
inverseFunc Sin    = ASin
inverseFunc Cos    = ACos
inverseFunc Tan    = ATan
inverseFunc Sinh   = ASinh
inverseFunc Cosh   = ACosh
inverseFunc Tanh   = ATanh
inverseFunc ASin   = Sin
inverseFunc ACos   = Cos
inverseFunc ATan   = Tan
inverseFunc ASinh  = Sinh
inverseFunc ACosh  = Cosh
inverseFunc ATanh  = Tanh
inverseFunc Sqrt   = Square
inverseFunc Square = Sqrt
inverseFunc Log    = Exp
inverseFunc Exp    = Log
inverseFunc x      = error $ show x ++ " has no support for inverse function"
{-# INLINE inverseFunc #-}

-- | Creates the symbolic partial derivative of a tree by variable @dx@ (if
-- @p@ is 'False') or by parameter @dx@ (if @p@ is 'True').
--
-- Implemented as a mutual recursion ('mutu'). At every node, @alg1@ receives
-- each child as a pair (derivative of the child, the child itself) and returns
-- the node's derivative. @alg2@ rebuilds the child subtrees so @alg1@ can
-- refer to them. The derivative is assembled with the 'Num', 'Fractional' and
-- 'Floating' instances of @Fix SRTree@, so their simplifications apply.
deriveBy :: Bool -> Int -> Fix SRTree -> Fix SRTree
deriveBy p dx = fst (mutu alg1 alg2)
  where
      alg1 (Var ix)
        | not p, ix == dx = 1
        | otherwise       = 0
      alg1 (Param ix)
        | p, ix == dx = 1
        | otherwise   = 0
      alg1 (Const _) = 0
      -- chain rule: f'(t) * t'
      alg1 (Uni f (dt, t)) = derivative f t * dt
      alg1 (Bin op (dl, l) (dr, r)) = case op of
        Add   -> dl + dr
        Sub   -> dl - dr
        Mul   -> dl * r + l * dr
        Div   -> (dl * r - l * dr) / r ** 2
        -- d(l^r) = l^(r-1) * (r * l' + l * log l * r')
        Power -> l ** (r - 1) * (r * dl + l * log l * dr)

      -- rebuild the node, keeping only the reconstructed children
      alg2 :: SRTree (a, Fix SRTree) -> Fix SRTree
      alg2 = Fix . fmap snd

-- | A plain list of values on which the numeric instances for 'Tape' operate
-- element-wise. 'forwardMode' uses it to carry one entry per model parameter
-- (its tapes are built with @replicate (V.length theta)@).
newtype Tape a = Tape
  { untape :: [a]  -- ^ the underlying entries
  }
  deriving (Show, Functor)

-- | Element-wise arithmetic on tapes.
--
-- Binary operators combine tapes with 'zipWith', so a result is only as long
-- as its shorter operand. Integer literals are therefore broadcast:
-- @fromInteger k@ is an infinite tape of @k@, so @t - 1@ or @2 * t@ keeps the
-- length of @t@. A one-element literal tape would silently truncate every such
-- result to a single entry. This is the same convention as
-- 'Control.Applicative.ZipList' (@pure = repeat@).
--
-- NOTE: a tape built only from literals is infinite; do not 'show' or fully
-- consume it. Results stay finite whenever at least one operand is finite.
instance Num a => Num (Tape a) where
  Tape xs + Tape ys = Tape (zipWith (+) xs ys)
  Tape xs - Tape ys = Tape (zipWith (-) xs ys)
  Tape xs * Tape ys = Tape (zipWith (*) xs ys)
  abs               = fmap abs
  signum            = fmap signum
  negate            = fmap negate
  fromInteger       = Tape . repeat . fromInteger
-- | Element-wise floating-point functions on tapes.
--
-- Unary functions are applied to every entry through the derived 'Functor'.
-- '(**)' combines two tapes with 'zipWith'. 'pi' is broadcast to an infinite
-- tape so that combining it with a tape of any length does not truncate the
-- result, as a one-element tape would through 'zipWith'.
instance Floating a => Floating (Tape a) where
  pi    = Tape (repeat pi)
  exp   = fmap exp
  log   = fmap log
  sqrt  = fmap sqrt
  sin   = fmap sin
  cos   = fmap cos
  tan   = fmap tan
  asin  = fmap asin
  acos  = fmap acos
  atan  = fmap atan
  sinh  = fmap sinh
  cosh  = fmap cosh
  tanh  = fmap tanh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh
  Tape xs ** Tape ys = Tape (zipWith (**) xs ys)
-- | Element-wise division on tapes.
--
-- Rational literals are broadcast to an infinite tape. For example, the @2.0@
-- in @t ** 2.0@ then matches every entry of @t@ instead of truncating the
-- 'zipWith'-based result to a single element.
instance Fractional a => Fractional (Tape a) where
  fromRational      = Tape . repeat . fromRational
  Tape xs / Tape ys = Tape (zipWith (/) xs ys)
  recip             = fmap recip

-- | Calculates the numerical derivative of a tree with respect to its
-- parameters using forward mode.
--
-- Arguments: a vector of variable values @xss@ (indexed by 'Var'), a vector of
-- parameter values @theta@ (indexed by 'Param'), and a function @f@ that
-- changes a 'Double' value to the type of the variable values.
--
-- Every intermediate result is a 'Tape' with one entry per parameter
-- (@n = V.length theta@):
--
--   * the value tape repeats the node's value @n@ times;
--   * the derivative tape holds the node's partial derivative with respect to
--     each parameter.
--
-- The returned list is the derivative tape of the root, one entry per
-- parameter.
--
-- Constants inside the derivative formulas are never written as 'Tape'
-- literals. They are either applied at the element type (via 'fmap') or built
-- as full-length tapes, because tape arithmetic uses 'zipWith' and a shorter
-- literal tape would silently truncate the gradient.
forwardMode :: forall a. (Show a, Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> [a]
forwardMode xss theta f = untape . fst (mutu alg1 alg2)
  where
      n :: Int
      n = V.length theta

      -- Replicate a value once per parameter.
      repMat :: a -> Tape a
      repMat v = Tape (replicate n v)

      zeroes, ones :: Tape a
      zeroes = repMat (f 0)
      ones   = repMat (f 1)

      -- Value tapes of the variables and of the parameters (O(1) indexing).
      tapeXs, tapeTheta :: V.Vector (Tape a)
      tapeXs    = V.map repMat xss
      tapeTheta = V.map (repMat . f) theta

      -- Derivative tape of parameter @ix@: the @ix@-th unit vector.
      paramVec :: V.Vector (Tape a)
      paramVec = V.generate n $ \ix ->
        Tape [if ix == iy then f 1 else f 0 | iy <- [0 .. n - 1]]

      -- Derivative algebra. Each child is (derivative tape, value tape).
      alg1 :: SRTree (Tape a, Tape a) -> Tape a
      alg1 (Var _)         = zeroes
      alg1 (Param ix)      = paramVec ! ix
      alg1 (Const _)       = zeroes
      -- Chain rule; derivative is evaluated entry by entry at the element type.
      alg1 (Uni g t)       = fmap (derivative g) (snd t) * fst t
      alg1 (Bin Add l r)   = fst l + fst r
      alg1 (Bin Sub l r)   = fst l - fst r
      alg1 (Bin Mul l r)   = fst l * snd r + snd l * fst r
      alg1 (Bin Div l r)   = (fst l * snd r - snd l * fst r) / (snd r * snd r)
      -- d(u^v) = u^(v-1) * (v * u' + u * log u * v')
      alg1 (Bin Power l r) = snd l ** (snd r - ones)
                             * (snd r * fst l + snd l * log (snd l) * fst r)

      -- Value algebra: evaluates each node on tapes.
      alg2 :: SRTree (Tape a, Tape a) -> Tape a
      alg2 (Var ix)     = tapeXs ! ix
      alg2 (Param ix)   = tapeTheta ! ix
      alg2 (Const c)    = repMat (f c)
      alg2 (Uni g t)    = fmap (evalFun g) (snd t)
      alg2 (Bin op l r) = evalOp op (snd l) (snd r)

-- | The function `gradParams` calculates the numerical gradient of the tree and
-- evaluates the tree at the same time, returning @(value, gradient)@. It should
-- be significantly faster than `forwardMode`.
--
-- It assumes that each parameter has a unique occurrence in the expression.
-- The gradient holds one entry per 'Param' node, in left-to-right order
-- (children are concatenated left before right). It therefore lines up with
-- @theta@ only when the parameters are numbered in that order, e.g. by
-- 'relabelParams'.
gradParams  :: (Show a, Num a, Floating a) => V.Vector a -> V.Vector Double -> (Double -> a) -> Fix SRTree -> (a, [a])
gradParams xss theta f = cata alg
  where
      alg (Var ix)   = (xss ! ix, [])
      alg (Param ix) = (f (theta ! ix), [1])
      alg (Const c)  = (f c, [])
      -- chain rule: scale every partial of the child by g'(v)
      alg (Uni g (v, gs)) = (evalFun g v, map (* derivative g v) gs)
      alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, l ++ r)
      alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, l ++ map negate r)
      alg (Bin Mul (v1, l) (v2, r)) = (v1 * v2, map (* v2) l ++ map (* v1) r)
      alg (Bin Div (v1, l) (v2, r)) =
        ( v1 / v2
        , map (/ v2) l ++ map ((/ v2 ^ 2) . (* v1) . negate) r )
      -- d(u^v) = u^(v-1) * (v * u' + u * log u * v')
      alg (Bin Power (v1, l) (v2, r)) =
        ( v1 ** v2
        , map (* (v1 ** (v2 - 1))) (map (* v2) l ++ map ((* v1) . (* log v1)) r) )


-- | Derivative of each univariate 'Function', as a function of its argument.
--
-- Squares are written as @(^2)@, i.e. repeated multiplication, and never as
-- @(** 2.0)@, for two reasons:
--
--   * The default 'Floating' implementation of '(**)' is
--     @exp (log x * y)@, which is NaN for a negative base. That would break
--     'Tan' and 'Tanh' for instances that do not override '(**)'.
--   * @(^2)@ introduces no numeric literal, so it is also safe for element-wise
--     container types.
derivative :: Floating a => Function -> a -> a
derivative Id      = const 1
derivative Abs     = \x -> x / abs x
derivative Sin     = cos
derivative Cos     = negate . sin
derivative Tan     = recip . (^2) . cos
derivative Sinh    = cosh
derivative Cosh    = sinh
derivative Tanh    = (1 -) . (^2) . tanh
derivative ASin    = recip . sqrt . (1 -) . (^2)
derivative ACos    = negate . recip . sqrt . (1 -) . (^2)
derivative ATan    = recip . (1 +) . (^2)
derivative ASinh   = recip . sqrt . (1 +) . (^2)
derivative ACosh   = \x -> 1 / (sqrt (x - 1) * sqrt (x + 1))
derivative ATanh   = recip . (1 -) . (^2)
derivative Sqrt    = recip . (2 *) . sqrt
derivative Cbrt    = recip . (3 *) . cbrt . (^2)
derivative Square  = (2 *)
derivative Exp     = exp
derivative Log     = recip
{-# INLINE derivative #-}

-- | Symbolic derivative of a tree with respect to the variable with index
-- @dx@. This is 'deriveBy' with its flag set to 'False' ('deriveByParam'
-- passes 'True').
deriveByVar :: Int -> Fix SRTree -> Fix SRTree
deriveByVar dx tree = deriveBy False dx tree

-- | Symbolic derivative of a tree with respect to the parameter with index
-- @dx@. This is 'deriveBy' with its flag set to 'True' ('deriveByVar'
-- passes 'False').
deriveByParam :: Int -> Fix SRTree -> Fix SRTree
deriveByParam dx tree = deriveBy True dx tree

-- | Relabel the parameters incrementally, starting from 0. Parameters are
-- numbered in visiting order, with the left child before the right child.
-- Variables, constants and the shape of the tree are left unchanged.
relabelParams :: Fix SRTree -> Fix SRTree
relabelParams tree = evalState (cataM sequenceNode renumber tree) 0
  where
      -- Run the effects held by a node's children, left child first.
      sequenceNode :: Applicative m => SRTree (m x) -> m (SRTree x)
      sequenceNode node = case node of
        Var ix       -> pure (Var ix)
        Param ix     -> pure (Param ix)
        Const c      -> pure (Const c)
        Uni g mt     -> Uni g <$> mt
        Bin op ml mr -> Bin op <$> ml <*> mr

      -- Give each parameter the current counter value, then advance the counter.
      renumber :: SRTree (Fix SRTree) -> State Int (Fix SRTree)
      renumber (Var ix)  = pure (var ix)
      renumber (Param _) = state (\next -> (param next, next + 1))
      renumber node      = pure (Fix node)

-- | Change every constant value into a parameter. Returns the changed tree,
-- with parameters relabeled from 0 by 'relabelParams', together with the list
-- of parameter values in the same left-to-right order. A replaced constant
-- contributes its own value; a parameter already present in the tree
-- contributes @1.0@.
constsToParam :: Fix SRTree -> (Fix SRTree, [Double])
constsToParam tree = case cata step tree of
    (tree', values) -> (relabelParams tree', values)
  where
      step :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
      step (Var ix)                     = (Fix (Var ix), [])
      step (Param ix)                   = (Fix (Param ix), [1.0])
      step (Const c)                    = (Fix (Param 0), [c])
      step (Uni g ~(t, vs))             = (Fix (Uni g t), vs)
      step (Bin op ~(l, ls) ~(r, rs))   = (Fix (Bin op l r), ls <> rs)

-- | Same as `constsToParam`, but constants with an integral value (those where
-- @floor c == ceiling c@) are kept as constants. Only the remaining constants
-- become parameters. Pre-existing parameters contribute @1.0@ to the returned
-- list of values.
floatConstsToParam :: Fix SRTree -> (Fix SRTree, [Double])
floatConstsToParam tree = case cata step tree of
    (tree', values) -> (relabelParams tree', values)
  where
      step :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
      step (Var ix)   = (Fix (Var ix), [])
      step (Param ix) = (Fix (Param ix), [1.0])
      step (Const c)
        | (floor c :: Integer) == ceiling c = (Fix (Const c), [])
        | otherwise                         = (Fix (Param 0), [c])
      step (Uni g ~(t, vs))           = (Fix (Uni g t), vs)
      step (Bin op ~(l, ls) ~(r, rs)) = (Fix (Bin op l r), ls <> rs)