{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-------------------------------------------------------------------------------

module Numeric.AD.Internal.Forward.Double
  ( ForwardDouble(..)
  , bundle
  , unbundle
  , apply
  , bind
  , bind'
  , bindWith
  , bindWith'
  , transposeWith
  ) where

import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
import Control.Monad (join)
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A forward-mode dual number specialised to 'Double': a 'primal' value
-- paired with the 'tangent' (derivative component) that is carried along
-- with it. Both fields are strict and unpacked.
data ForwardDouble = ForwardDouble
  { primal  :: {-# UNPACK #-} !Double -- ^ the value being computed
  , tangent :: {-# UNPACK #-} !Double -- ^ the derivative component
  }
  deriving (Read, Show)

-- | Split a 'ForwardDouble' into its primal value and its tangent.
unbundle :: ForwardDouble -> (Double, Double)
unbundle (ForwardDouble a da) = (a, da)
{-# INLINE unbundle #-}

-- | Build a 'ForwardDouble' from a primal value and a tangent
-- (the inverse of 'unbundle').
bundle :: Double -> Double -> ForwardDouble
bundle = ForwardDouble
{-# INLINE bundle #-}

-- | Run @f@ on @a@ lifted to a dual number whose tangent is seeded with @1@,
-- i.e. the input is treated as the variable being differentiated.
apply :: (ForwardDouble -> b) -> Double -> b
apply f a = f (bundle a 1)
{-# INLINE apply #-}

-- | Forward mode over a 'Double' scalar. Constants are represented by a
-- zero tangent; scaling by a scalar scales primal and tangent alike.
instance Mode ForwardDouble where
  type Scalar ForwardDouble = Double

  -- A lifted constant has no derivative component.
  auto = flip ForwardDouble 0

  zero = ForwardDouble 0 0

  -- Literal patterns on 'Double' compare with '==', so @-0.0@ matches too.
  isKnownZero (ForwardDouble 0 0) = True
  isKnownZero _ = False

  asKnownConstant (ForwardDouble x 0) = Just x
  asKnownConstant _ = Nothing

  isKnownConstant (ForwardDouble _ 0) = True
  isKnownConstant _ = False

  a *^ ForwardDouble b db = ForwardDouble (a * b) (a * db)
  ForwardDouble a da ^* b = ForwardDouble (a * b) (da * b)
  ForwardDouble a da ^/ b = ForwardDouble (a / b) (da / b)

-- | Componentwise addition of dual numbers: primals add, tangents add.
(<+>) :: ForwardDouble -> ForwardDouble -> ForwardDouble
ForwardDouble a da <+> ForwardDouble b db = ForwardDouble (a + b) (da + db)

-- | Chain-rule lifting of scalar functions. Partial derivatives are passed
-- around as 'Id'-wrapped 'Double's and multiplied into the incoming tangents.
instance Jacobian ForwardDouble where
  type D ForwardDouble = Id Double

  -- Derivative supplied directly: tangent = f'(b) * db.
  unary f (Id dadb) (ForwardDouble b db) = ForwardDouble (f b) (dadb * db)

  -- Derivative computed from the input value @b@.
  lift1 f df (ForwardDouble b db) = ForwardDouble (f b) (dadb * db)
    where
      Id dadb = df (Id b)

  -- Derivative computed from both the result @a@ and the input @b@.
  lift1_ f df (ForwardDouble b db) = ForwardDouble a da
    where
      a = f b
      Id da = df (Id a) (Id b) ^* db

  -- Both partials supplied directly; tangents combine by the chain rule.
  binary f (Id dadb) (Id dadc) (ForwardDouble b db) (ForwardDouble c dc) =
    ForwardDouble (f b c) (dadb * db + dc * dadc)

  -- Both partials computed from the inputs @b@ and @c@.
  lift2 f df (ForwardDouble b db) (ForwardDouble c dc) = ForwardDouble a da
    where
      a = f b c
      (Id dadb, Id dadc) = df (Id b) (Id c)
      da = dadb * db + dc * dadc

  -- Both partials computed from the result @a@ and the inputs @b@ and @c@.
  lift2_ f df (ForwardDouble b db) (ForwardDouble c dc) = ForwardDouble a da
    where
      a = f b c
      (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
      da = dadb * db + dc * dadc

#define HEAD ForwardDouble
#define BODY1(x)
#define BODY2(x,y)
#define NO_Bounded
#include "instances.h"

-- | For every position @i@ of @as@, call @f@ on a copy of @as@ in which the
-- @i@-th element has a unit tangent and every other element is a constant
-- ('auto'). The results are returned in the shape of @as@.
bind :: Traversable f => (f ForwardDouble -> b) -> f Double -> f b
bind f as = snd $ mapAccumL outer (0 :: Int) as
  where
    -- i: index of the element being perturbed.
    outer !i _ = (i + 1, f $ snd $ mapAccumL (inner i) 0 as)
    -- j: index of the element currently being rebuilt.
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)

-- | Like 'bind', but also returns a single result as the first component:
-- the result of the last perturbed call, or @f@ applied to all-constant
-- inputs when @as@ is empty.
bind' :: Traversable f => (f ForwardDouble -> b) -> f Double -> (b, f b)
bind' f as = dropIx $ mapAccumL outer (0 :: Int, b0) as
  where
    -- State: (index of the element being perturbed, most recent result).
    outer (!i, _) _ =
      let b = f $ snd $ mapAccumL (inner i) (0 :: Int) as
      in ((i + 1, b), b)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)
    -- Fallback for an empty @as@: every input treated as a constant.
    b0 = f (auto <$> as)
    dropIx ((_, b), bs) = (b, bs)

-- | Like 'bind', but combines each original input element with the result
-- of perturbing it, using @g@.
bindWith :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> f c
bindWith g f as = snd $ mapAccumL outer (0 :: Int) as
  where
    -- i: index of the element being perturbed; a: its primal value.
    outer !i a = (i + 1, g a $ f $ snd $ mapAccumL (inner i) 0 as)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)

-- | Like 'bindWith', but also returns a single result as the first
-- component: the result of the last perturbed call, or @f@ applied to
-- all-constant inputs when @as@ is empty.
bindWith' :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> (b, f c)
bindWith' g f as = dropIx $ mapAccumL outer (0 :: Int, b0) as
  where
    -- State: (index of the element being perturbed, most recent result).
    outer (!i, _) a =
      let b = f $ snd $ mapAccumL (inner i) (0 :: Int) as
      in ((i + 1, b), g a b)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)
    -- Fallback for an empty @as@: every input treated as a constant.
    b0 = f (auto <$> as)
    dropIx ((_, b), bs) = (b, bs)

-- | Traverse @g b@ left to right, passing @f@ each element together with the
-- matching "column" of @as@: the element at the same position of every
-- inner @g a@.
--
-- Partial: every inner @g a@ must have at least as many elements as the
-- traversed @g b@. Otherwise an error is raised when the missing column
-- element is demanded.
transposeWith :: (Functor f, Foldable f, Traversable g) => (b -> f a -> c) -> f (g a) -> g b -> g c
transposeWith f as = snd . mapAccumL go xss0
  where
    -- Emit the current column, then advance every row by one element.
    go xss b = (drop 1 <$> xss, f b (column <$> xss))
    -- Avoids partial 'head' so a ragged input yields a descriptive error.
    column (x:_) = x
    column [] = error "Numeric.AD.Internal.Forward.Double.transposeWith: ragged input (an inner container is shorter than the traversal)"
    xss0 = toList <$> as

-- | Multiplication lifted with 'lift2': the partial derivatives of @x * y@
-- with respect to @x@ and @y@ are @y@ and @x@ respectively.
mul :: ForwardDouble -> ForwardDouble -> ForwardDouble
mul = lift2 (*) (\x y -> (y, x))