{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Sparse
  ( Monomial(..)
  , emptyMonomial
  , addToMonomial
  , indices
  , Sparse(..)
  , apply
  , vars
  , d, d', ds
  , skeleton
  , spartial
  , partial
  , vgrad
  , vgrad'
  , vgrads
  , Grad(..)
  , Grads(..)
  , terms
  , primal
  ) where

import Prelude hiding (lookup)
import Control.Comonad.Cofree
import Control.Monad (join, guard)
import Data.Data
import Data.IntMap (IntMap, unionWith, findWithDefault, singleton, lookup)
import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | We only store partials in sorted order, so the map contained in a partial
-- will only contain partials whose keys are equal to or greater than the key
-- under which that partial was found. This is key to efficiently computing
-- sparse Hessians: there are only @n + k - 1@ choose @k@ distinct @k@th-order
-- partial derivatives of a function with @n@ inputs.
data Sparse a
  = Sparse !a (IntMap (Sparse a))
    -- ^ A (strict) primal value together with its first-order partials,
    -- keyed by variable index. Each partial is itself a 'Sparse' value whose
    -- own map holds the next order of partials.
  | Zero
    -- ^ A value known to be zero, together with all of its partials.
  deriving (Show, Data, Typeable)

-- | Seed every element of a container as an independent variable.
-- The element at position @i@ (counting from 0 in traversal order) becomes a
-- 'Sparse' value whose only first-order partial, stored under key @i@, is 1.
vars :: (Traversable f, Num a) => f a -> f (Sparse a)
vars = snd . mapAccumL seed 0
  where
    -- the strict counter is the variable index handed to the next element
    seed !i x = (i + 1, Sparse x (singleton i (auto 1)))
{-# INLINE vars #-}

-- | Run a function of 'Sparse' variables on plain values, seeding the values
-- as independent variables with 'vars' first.
apply :: (Traversable f, Num a) => (f (Sparse a) -> b) -> f a -> b
apply = (. vars)
{-# INLINE apply #-}

-- | Extract the first-order gradient from a 'Sparse' result.
-- The container @fs@ only supplies the shape. Position @i@ (in traversal
-- order, counting from 0) receives the primal of the partial stored under
-- variable index @i@, or 0 when no such partial is stored.
d :: (Traversable f, Num a) => f b -> Sparse a -> f a
d fs Zero = 0 <$ fs
d fs (Sparse _ da) = snd (mapAccumL step 0 fs)
  where
    -- a missing key behaves like 'Zero', whose primal is 0
    step !i _ = (i + 1, primal (findWithDefault Zero i da))
{-# INLINE d #-}

-- | Like 'd', but also return the primal value next to the gradient.
-- A 'Zero' input yields a primal of 0 and an all-zero gradient.
--
-- The gradient is delegated to 'd' rather than duplicating its
-- index-by-index lookup traversal.
d' :: (Traversable f, Num a) => f a -> Sparse a -> (a, f a)
d' fs s = case s of
  Zero       -> (0, d fs s)
  Sparse a _ -> (a, d fs s)
{-# INLINE d' #-}

-- | Unfold every mixed partial of a 'Sparse' result into a lazily built
-- 'Cofree' tree. The root holds the primal. Following child @i@ of a node
-- differentiates once more with respect to variable @i@, where children are
-- laid out in the shape of @fs@. The tree has unbounded depth and is
-- produced on demand.
ds :: (Traversable f, Num a) => f b -> Sparse a -> Cofree f a
ds fs Zero = zeros
  where
    -- every partial of zero is zero, so the tree is a single repeating node
    zeros = 0 :< (zeros <$ fs)
ds fs s@(Sparse a _) = a :< (descend emptyMonomial <$> slots)
  where
    slots = skeleton fs
    -- @mono@ records how often each variable has been differentiated so far
    descend mono i = partial (indices mono') s :< (descend mono' <$> slots)
      where
        mono' = addToMonomial i mono
{-# INLINE ds #-}

-- | Descend through the nested partial maps following the given variable
-- indices. Return the sub-tower found at the end of the path, or 'Zero' as
-- soon as the path leaves the stored partials.
partialS :: Num a => [Int] -> Sparse a -> Sparse a
partialS ixs s = case ixs of
  [] -> s
  n : rest -> case s of
    Sparse _ da -> partialS rest (findWithDefault Zero n da)
    Zero        -> Zero
{-# INLINE partialS #-}

-- | Read off the partial derivative selected by a list of variable indices.
-- The indices are followed through the nested maps in the order given.
-- Because partials are stored in sorted key order, callers should supply
-- them in non-decreasing order. A path that is not stored yields 0.
--
-- This is defined via 'partialS', so a missing entry is represented by the
-- shared 'Zero' constructor. The previous version allocated a fresh
-- @auto 0@ node and duplicated 'partialS''s traversal.
partial :: Num a => [Int] -> Sparse a -> a
partial ns s = primal (partialS ns s)
{-# INLINE partial #-}

-- | Like 'partial', but return 'Nothing' instead of 0 when the requested
-- partial is not stored, or when the tower is 'Zero'.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial ixs s = case (ixs, s) of
  ([],       Sparse a _ ) -> Just a
  (n : rest, Sparse _ da) -> lookup n da >>= spartial rest
  (_,        Zero       ) -> Nothing
{-# INLINE spartial #-}

-- | The value at the root of a 'Sparse' tower; 'Zero' has primal 0.
primal :: Num a => Sparse a -> a
primal s = case s of
  Sparse a _ -> a
  Zero       -> 0

instance Num a => Mode (Sparse a) where
  type Scalar (Sparse a) = a

  -- A constant: the given primal with no partials at all.
  auto x = Sparse x IntMap.empty

  zero = Zero

  isKnownZero s = case s of
    Zero       -> True
    Sparse _ _ -> False

  -- A node with an empty partial map depends on no variable.
  isKnownConstant s = case s of
    Zero       -> True
    Sparse _ m -> IntMap.null m

  asKnownConstant s = case s of
    Zero -> Just 0
    Sparse x m
      | IntMap.null m -> Just x
      | otherwise     -> Nothing

  -- Scaling by a scalar is applied to the primal and, recursively, to every
  -- stored partial.
  s ^* k = case s of
    Zero        -> Zero
    Sparse x xs -> Sparse (x * k) (IntMap.map (^* k) xs)

  k *^ s = case s of
    Zero        -> Zero
    Sparse x xs -> Sparse (k * x) (IntMap.map (k *^) xs)

  s ^/ k = case s of
    Zero        -> Zero
    Sparse x xs -> Sparse (x / k) (IntMap.map (^/ k) xs)

infixr 6 <+>

-- | Add two 'Sparse' towers. Primals are summed and the partial maps are
-- merged key by key, adding recursively where both sides have an entry.
-- 'Zero' acts as the identity.
(<+>) :: Num a => Sparse a -> Sparse a -> Sparse a
x <+> y = case (x, y) of
  (Zero, _)                  -> y
  (_, Zero)                  -> x
  (Sparse p ps, Sparse q qs) -> Sparse (p + q) (unionWith (<+>) ps qs)

-- The instances for Jacobian for Sparse and Tower are almost identical;
-- could easily be made exactly equal by small changes.
instance Num a => Jacobian (Sparse a) where
  type D (Sparse a) = Sparse a

  -- Unary lift with a fixed derivative: every first-order partial of the
  -- argument is scaled by @dfdx@.
  unary f _ Zero = auto (f 0)
  unary f dfdx (Sparse px dxs) = Sparse (f px) (IntMap.map (* dfdx) dxs)

  -- Unary lift whose derivative is computed from the argument.
  lift1 f _ Zero = auto (f 0)
  lift1 f df x@(Sparse px dxs) = Sparse (f px) (IntMap.map (* df x) dxs)

  -- Like 'lift1', but the derivative may also consult the result; @out@ is
  -- defined lazily in terms of itself.
  lift1_ f _ Zero = auto (f 0)
  lift1_ f df x@(Sparse px dxs) = out
    where
      out = Sparse (f px) (IntMap.map (df out x *) dxs)

  -- Binary lift with fixed partial derivatives: the scaled partial maps of
  -- both arguments are summed.
  binary f _ _ Zero Zero = auto (f 0 0)
  binary f _ dfdy Zero (Sparse py dys) =
    Sparse (f 0 py) (IntMap.map (dfdy *) dys)
  binary f dfdx _ (Sparse px dxs) Zero =
    Sparse (f px 0) (IntMap.map (dfdx *) dxs)
  binary f dfdx dfdy (Sparse px dxs) (Sparse py dys) =
    Sparse (f px py)
      (unionWith (<+>) (IntMap.map (dfdx *) dxs) (IntMap.map (dfdy *) dys))

  -- Binary lift whose partial derivatives are computed from the arguments.
  lift2 f _ Zero Zero = auto (f 0 0)
  lift2 f df Zero y@(Sparse py dys) =
    Sparse (f 0 py) (IntMap.map (dfdy *) dys)
    where
      dfdy = snd (df zero y)
  lift2 f df x@(Sparse px dxs) Zero =
    Sparse (f px 0) (IntMap.map (* dfdx) dxs)
    where
      dfdx = fst (df x zero)
  lift2 f df x@(Sparse px dxs) y@(Sparse py dys) = Sparse (f px py) merged
    where
      (dfdx, dfdy) = df x y
      merged =
        unionWith (<+>) (IntMap.map (dfdx *) dxs) (IntMap.map (dfdy *) dys)

  -- Like 'lift2', but the derivatives may also consult the result; @out@ is
  -- defined lazily in terms of itself.
  lift2_ f _ Zero Zero = auto (f 0 0)
  lift2_ f df x@(Sparse px dxs) Zero = out
    where
      out = Sparse (f px 0) (IntMap.map (fst (df out x zero) *) dxs)
  lift2_ f df Zero y@(Sparse py dys) = out
    where
      out = Sparse (f 0 py) (IntMap.map (* snd (df out zero y)) dys)
  lift2_ f df x@(Sparse px dxs) y@(Sparse py dys) = out
    where
      (dfdx, dfdy) = df out x y
      out = Sparse (f px py) merged
      merged =
        unionWith (<+>) (IntMap.map (dfdx *) dxs) (IntMap.map (dfdy *) dys)

#define HEAD (Sparse a)
#include "instances.h"

-- | Curried functions of 'Sparse' arguments, related to the curried result
-- shapes produced by 'vgrad' (@o@) and 'vgrad'' (@o'@).
class Num a => Grad i o o' a
    | i -> a o o'
    , o -> a i o'
    , o' -> a i o
  where
    -- | Feed a list of arguments, one per curried parameter, to @i@.
    pack :: i -> [Sparse a] -> Sparse a
    -- | Curry a list-consuming gradient function into the shape @o@.
    unpack :: ([a] -> [a]) -> o
    -- | Like 'unpack', for results that also carry the primal value.
    unpack' :: ([a] -> (a, [a])) -> o'

-- Base case: no parameters remain. The input already is the result, and the
-- accumulated argument list is closed off with @[]@.
instance Num a => Grad (Sparse a) [a] (a, [a]) a where
  pack = const
  unpack = ($ [])
  unpack' = ($ [])

-- Inductive case: one more curried 'Sparse' parameter.
instance Grad i o o' a => Grad (Sparse a -> i) (a -> o) (a -> o') a where
  -- apply the function to the next argument and recurse on the remainder
  pack f args = case args of
    x : rest -> pack (f x) rest
    []       -> error "Grad.pack: logic error"
  -- remember the next argument by prepending it to the eventual list
  unpack f x = unpack (\rest -> f (x : rest))
  unpack' f x = unpack' (\rest -> f (x : rest))

-- | Gradient of a curried function of 'Sparse' arguments. The result is a
-- curried function that takes the argument values and returns the list of
-- first-order partials, one per argument, in argument order.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack gradAt
  where
    gradAt xs = d xs (apply (pack i) xs)
{-# INLINE vgrad #-}

-- | Like 'vgrad', but the curried result also returns the function's value
-- alongside the list of first-order partials.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' gradAt'
  where
    gradAt' xs = d' xs (apply (pack i) xs)
{-# INLINE vgrad' #-}

-- | Curried functions of 'Sparse' arguments, related to the curried result
-- shape @o@ produced by 'vgrads'.
class Num a => Grads i o a
    | i -> a o
    , o -> a i
  where
    -- | Feed a list of arguments, one per curried parameter, to @i@.
    packs :: i -> [Sparse a] -> Sparse a
    -- | Curry a list-consuming derivative-tower function into the shape @o@.
    unpacks :: ([a] -> Cofree [] a) -> o

-- Base case: no parameters remain. The input already is the result, and the
-- accumulated argument list is closed off with @[]@.
instance Num a => Grads (Sparse a) (Cofree [] a) a where
  packs = const
  unpacks = ($ [])

-- Inductive case: one more curried 'Sparse' parameter.
instance Grads i o a => Grads (Sparse a -> i) (a -> o) a where
  -- Apply the function to the next argument and recurse on the remainder.
  packs f (a:as) = packs (f a) as
  -- 'vgrads' supplies one argument per curried parameter, so running out
  -- early indicates an internal bug. The message names this method; it
  -- previously said "Grad.pack", a copy-paste from the 'Grad' instance.
  packs _ [] = error "Grads.packs: logic error"
  -- Remember the next argument by prepending it to the eventual list.
  unpacks f a = unpacks (f . (a:))

-- | All-orders derivative tower of a curried function of 'Sparse'
-- arguments. The result is a curried function that takes the argument
-- values and returns the lazily unfolded tree of mixed partials (see 'ds').
vgrads :: Grads i o a => i -> o
vgrads i = unpacks towerAt
  where
    towerAt xs = ds xs (apply (packs i) xs)
{-# INLINE vgrads #-}

-- | Is this the structural 'Zero' constructor? A 'Sparse' node whose
-- primal happens to equal 0 is not reported as zero.
isZero :: Sparse a -> Bool
isZero s = case s of
  Zero       -> True
  Sparse _ _ -> False

-- | Multiply two 'Sparse' towers by building each mixed partial of the
-- product directly from the partials of the two factors. Returns 'Zero' if
-- either factor is 'Zero'.
--
-- NOTE(review): this relies on 'terms' (from
-- "Numeric.AD.Internal.Sparse.Common") enumerating every split of a monomial
-- into a pair of monomials with its integer coefficient, i.e. a
-- Leibniz-style product rule. Confirm against that module.
mul :: Num a => Sparse a -> Sparse a -> Sparse a
mul Zero _ = Zero
mul _ Zero = Zero
mul f@(Sparse _ fm) g@(Sparse _ gm) =
  Sparse (primal f * primal g) (derivs 0 emptyMonomial)
  where
    -- Largest first-order variable key in either factor (-1 when neither
    -- has any partials).
    kMax = maxKey fm `max` maxKey gm
    maxKey m = maybe (-1) (fst . fst) (IntMap.maxViewWithKey m)

    -- Partial map of the product below monomial @mi@. Only keys @v@ and
    -- above are produced, matching the sorted storage order.
    derivs v mi = IntMap.unions [child w | w <- [v .. kMax]]
      where
        child w
          -- every contributing product has a Zero factor: prune this branch
          | and zeros = IntMap.empty
          | otherwise = IntMap.singleton w (Sparse (sum vals) (derivs w mi'))
          where
            mi' = addToMonomial w mi
            (zeros, vals) = unzip (map term (terms mi'))
            term (coeff, mif, mig) =
              ( isZero fder || isZero gder
              , fromIntegral coeff * primal fder * primal gder
              )
              where
                fder = partialS (indices mif) f
                gder = partialS (indices mig) g