module Numeric.AD.Internal.Sparse
( Index(..)
, emptyIndex
, addToIndex
, indices
, Sparse(..)
, apply
, vars
, d, d', ds
, skeleton
, spartial
, partial
, vgrad
, vgrad'
, vgrads
, Grad(..)
, Grads(..)
) where
import Prelude hiding (lookup)
import Control.Applicative hiding ((<**>))
import Control.Comonad.Cofree
import Control.Monad (join)
import Data.Data
import Data.IntMap (IntMap, mapWithKey, unionWith, findWithDefault, toAscList, singleton, insertWith, lookup)
import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric.AD.Internal.Combinators
import Numeric.AD.Jacobian
import Numeric.AD.Mode
-- | A map from 'Int' keys to occurrence counts: 'addToIndex' bumps the count
-- stored under a key and 'indices' expands every key as many times as its
-- count says.
newtype Index = Index (IntMap Int)
-- | The 'Index' that contains no keys at all.
emptyIndex :: Index
emptyIndex = Index noKeys
  where
    noKeys = IntMap.empty
-- | Record one more occurrence of the key @var@: a missing key is inserted
-- with count 1, an existing count is incremented by 1.
addToIndex :: Int -> Index -> Index
addToIndex var (Index counts) = Index bumped
  where
    bumped = IntMap.insertWith (+) var 1 counts
-- | Flatten an 'Index' into a list of keys in ascending order, each key
-- repeated as many times as its stored count.
indices :: Index -> [Int]
indices (Index counts) = concatMap expand (toAscList counts)
  where
    expand (var, count) = replicate count var
-- | A value together with a map of further 'Sparse' values keyed by 'Int'.
-- Elsewhere in this module the keys are variable positions (see 'vars') and
-- the entries are read back as partial derivatives (see 'd' and 'partial').
--
-- The first field (the primal value) is strict; the map is lazy.
-- 'Zero' stands for a value whose primal is 0 (see 'primal') and which
-- carries no map.
data Sparse a
  = Sparse !a (IntMap (Sparse a))
  | Zero
  deriving (Show, Data, Typeable)
-- | Discard every entry whose key is strictly smaller than @n@, keeping the
-- entries with keys @>= n@.
--
-- 'IntMap.split' excludes the pivot key itself, so the map is split at
-- @n - 1@ to make sure an entry stored under @n@ survives.
dropMap :: Int -> IntMap a -> IntMap a
dropMap n = snd . IntMap.split (n - 1)
-- | Product of two 'Sparse' values, as used when chaining a derivative
-- factor through the partial stored under key @n@.
--
-- The primals are multiplied. The partial maps of both operands are first
-- cut down with 'dropMap' to keys @>= n@, then each side is scaled by the
-- other operand's primal and the two maps are merged with '+'.
-- If either operand is 'Zero' the result is 'Zero'.
times :: Num a => Sparse a -> Int -> Sparse a -> Sparse a
times lhs n rhs = case (lhs, rhs) of
  (Zero, _) -> Zero
  (_, Zero) -> Zero
  (Sparse a as, Sparse b bs) -> Sparse (a * b) (unionWith (+) fromLeft fromRight)
    where
      fromLeft = fmap (^* b) (dropMap n as)
      fromRight = fmap (a *^) (dropMap n bs)
-- | Turn every element of a container into an independent variable.
--
-- Elements are numbered from 0 in traversal order; the element at position
-- @ix@ becomes a 'Sparse' with that element as primal and a single partial,
-- @auto 1@, stored under key @ix@.
vars :: (Traversable f, Num a) => f a -> f (Sparse a)
vars inputs = tagged
  where
    (_, tagged) = mapAccumL tag 0 inputs
    -- The bang keeps the running position evaluated while traversing.
    tag !ix x = (ix + 1, Sparse x (singleton ix (auto 1)))
-- | Run a function on the variables produced by 'vars' from a container of
-- plain values.
apply :: (Traversable f, Num a) => (f (Sparse a) -> b) -> f a -> b
apply f inputs = f (vars inputs)
-- | Replace every element of a container by its position, counting from 0
-- in traversal order. The elements themselves are never inspected.
skeleton :: Traversable f => f a -> f Int
skeleton shape = numbered
  where
    (_, numbered) = mapAccumL step 0 shape
    -- The bang keeps the running position evaluated while traversing.
    step !ix _ = (ix + 1, ix)
-- | Read the first-order partials out of a 'Sparse' value.
--
-- The container @fs@ only supplies the shape; its elements are ignored.
-- Position @ix@ (counting from 0 in traversal order) receives the primal of
-- the entry stored under key @ix@, or 0 when there is no such entry.
-- A 'Zero' argument yields 0 at every position.
d :: (Traversable f, Num a) => f b -> Sparse a -> f a
d fs s = case s of
  Zero -> 0 <$ fs
  Sparse _ da -> snd (mapAccumL step 0 fs)
    where
      step !ix _ = (ix + 1, firstOrder ix)
      firstOrder ix = maybe 0 primal (lookup ix da)
-- | Like 'd', but also return the primal value.
--
-- The first component is the primal of the 'Sparse' value (0 for 'Zero');
-- the second has the shape of @fs@ and holds, at position @ix@, the primal
-- of the entry stored under key @ix@, or 0 when there is no such entry.
d' :: (Traversable f, Num a) => f a -> Sparse a -> (a, f a)
d' fs s = case s of
  Zero -> (0, 0 <$ fs)
  Sparse a da -> (a, snd (mapAccumL step 0 fs))
    where
      step !ix _ = (ix + 1, maybe 0 primal (lookup ix da))
-- | Build the tower of all partials of a 'Sparse' value as a 'Cofree'.
--
-- The root holds the primal. Below it, the child at position @i@ of the
-- shape @fs@ holds the partial reached by adding @i@ to the 'Index'
-- accumulated on the way down (looked up with 'partial'), and has children
-- of its own built the same way. The definition has no terminating case.
-- For 'Zero' every node of the tower holds 0.
ds :: (Traversable f, Num a) => f b -> Sparse a -> Cofree f a
ds fs Zero = zeros
  where
    -- A tower that refers to itself for every child.
    zeros = 0 :< (zeros <$ fs)
ds fs s@(Sparse a _) = a :< fmap (tower emptyIndex) positions
  where
    positions = skeleton fs
    tower seen i =
      let seen' = addToIndex i seen
       in partial (indices seen') s :< fmap (tower seen') positions
-- | Follow a path of keys through the nested partial maps and return the
-- primal found at the end.
--
-- A key that is missing from a map is treated as @auto 0@, and reaching
-- 'Zero' gives 0 regardless of the remaining path.
partial :: Num a => [Int] -> Sparse a -> a
partial path s = case (path, s) of
  ([], Sparse a _) -> a
  (n : rest, Sparse _ da) -> partial rest (findWithDefault (auto 0) n da)
  (_, Zero) -> 0
-- | Like 'partial', but report absence instead of substituting 0.
--
-- Returns 'Nothing' as soon as a key of the path is missing from a map or a
-- 'Zero' is reached; otherwise 'Just' the primal at the end of the path.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial path s = case (path, s) of
  ([], Sparse a _) -> Just a
  (n : rest, Sparse _ da) -> lookup n da >>= spartial rest
  (_, Zero) -> Nothing
-- | The plain value carried by a 'Sparse'; 'Zero' counts as 0.
primal :: Num a => Sparse a -> a
primal s = case s of
  Sparse a _ -> a
  Zero -> 0
-- | Exponentiation of 'Sparse' values. This module hides the operator of
-- the same name exported by "Control.Applicative".
--
-- * A 'Zero' base gives the constant @0 ** primal y@.
-- * A 'Zero' exponent gives the constant 1.
-- * An exponent without partials (empty map) is treated as a constant @b@
--   and the power rule is used: the derivative factor is @b * z ** (b - 1)@.
-- * Otherwise both operands vary; with @z@ the result itself, the factors
--   are @yi * z / xi@ for the base and @z * log xi@ for the exponent.
(<**>) :: Floating a => Sparse a -> Sparse a -> Sparse a
Zero <**> y = auto (0 ** primal y)
_ <**> Zero = auto 1
x <**> y@(Sparse b bs)
  -- '<**>' has no fixity declaration, so it binds tighter than '*^'; the
  -- parentheses below only spell that grouping out.
  | IntMap.null bs = lift1 (** b) (\z -> b *^ (z <**> Sparse (b - 1) IntMap.empty)) x
  | otherwise = lift2_ (**) (\z xi yi -> (yi * z / xi, z * log xi)) x y
-- | 'Sparse' as an AD mode whose scalar type is the carried value type.
instance Num a => Mode (Sparse a) where
  type Scalar (Sparse a) = a

  -- A constant: the given primal and an empty map of partials.
  auto x = Sparse x IntMap.empty

  zero = Zero

  -- Scale by a scalar on the right, recursively through all partials.
  s ^* k = case s of
    Zero -> Zero
    Sparse a as -> Sparse (a * k) (fmap (^* k) as)

  -- Scale by a scalar on the left, recursively through all partials.
  k *^ s = case s of
    Zero -> Zero
    Sparse b bs -> Sparse (k * b) (fmap (k *^) bs)

  -- Divide by a scalar, recursively through all partials.
  s ^/ k = case s of
    Zero -> Zero
    Sparse a as -> Sparse (a / k) (fmap (^/ k) as)
infixr 6 <+>

-- | Sum of two 'Sparse' values: primals are added and the partial maps are
-- merged key by key, summing entries present on both sides with '<+>'
-- itself. 'Zero' on one side returns the other operand unchanged.
(<+>) :: Num a => Sparse a -> Sparse a -> Sparse a
lhs <+> rhs = case (lhs, rhs) of
  (Zero, _) -> rhs
  (_, Zero) -> lhs
  (Sparse a as, Sparse b bs) -> Sparse (a + b) (unionWith (<+>) as bs)
-- | Multiply every partial in a map by the derivative factor @dx@, handing
-- each entry's key to 'times'.
chainPartials :: Num a => Sparse a -> IntMap (Sparse a) -> IntMap (Sparse a)
chainPartials dx = mapWithKey (times dx)

-- | Lifting of scalar functions to 'Sparse' values.
--
-- In every method a 'Zero' argument contributes primal 0 and no partials;
-- when all arguments are 'Zero' the result is the constant @auto (f 0 ...)@.
instance Num a => Jacobian (Sparse a) where
  type D (Sparse a) = Sparse a

  -- Unary function whose derivative factor is supplied directly.
  unary f _ Zero = auto (f 0)
  unary f dadb (Sparse pb bs) = Sparse (f pb) (chainPartials dadb bs)

  -- Unary function whose derivative factor is computed from the argument.
  lift1 f _ Zero = auto (f 0)
  lift1 f df b@(Sparse pb bs) = Sparse (f pb) (chainPartials (df b) bs)

  -- As 'lift1', but the derivative function also receives the result.
  -- The result is defined in terms of itself, which relies on the partial
  -- map being built lazily.
  lift1_ f _ Zero = auto (f 0)
  lift1_ f df b@(Sparse pb bs) = result
    where
      result = Sparse (f pb) (chainPartials (df result b) bs)

  -- Binary function whose two derivative factors are supplied directly.
  binary f _ _ Zero Zero = auto (f 0 0)
  binary f _ dadc Zero (Sparse pc dc) = Sparse (f 0 pc) (chainPartials dadc dc)
  binary f dadb _ (Sparse pb db) Zero = Sparse (f pb 0) (chainPartials dadb db)
  binary f dadb dadc (Sparse pb db) (Sparse pc dc) =
    Sparse (f pb pc) (unionWith (<+>) (chainPartials dadb db) (chainPartials dadc dc))

  -- Binary function whose derivative factors are computed from the
  -- arguments; a 'Zero' argument is passed to @df@ as 'zero'.
  lift2 f _ Zero Zero = auto (f 0 0)
  lift2 f df Zero c@(Sparse pc dc) = Sparse (f 0 pc) (chainPartials dadc dc)
    where
      dadc = snd (df zero c)
  lift2 f df b@(Sparse pb db) Zero = Sparse (f pb 0) (chainPartials dadb db)
    where
      dadb = fst (df b zero)
  lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) merged
    where
      (dadb, dadc) = df b c
      merged = unionWith (<+>) (chainPartials dadb db) (chainPartials dadc dc)

  -- As 'lift2', but the derivative function also receives the result,
  -- which is again defined in terms of itself.
  lift2_ f _ Zero Zero = auto (f 0 0)
  lift2_ f df b@(Sparse pb db) Zero = result
    where
      result = Sparse (f pb 0) (chainPartials (fst (df result b zero)) db)
  lift2_ f df Zero c@(Sparse pc dc) = result
    where
      result = Sparse (f 0 pc) (chainPartials (snd (df result zero c)) dc)
  lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = result
    where
      (dadb, dadc) = df result b c
      result = Sparse (f pb pc) merged
      merged = unionWith (<+>) (chainPartials dadb db) (chainPartials dadc dc)
#define HEAD Sparse a
#include "instances.h"
-- | Relates a curried function over 'Sparse' values (@i@) to the curried
-- function types returned by 'vgrad' (@o@) and 'vgrad'' (@o'@), all over the
-- same scalar type @a@. The functional dependencies let any one of @i@, @o@
-- or @o'@ determine the remaining parameters.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
  -- | Feed a value of type @i@ its arguments from a list. In the instances
  -- of this module a plain 'Sparse' result ignores the list and a function
  -- consumes one list element per argument.
  pack :: i -> [Sparse a] -> Sparse a
  -- | Turn a function on argument lists into a value of type @o@. In the
  -- instances of this module each curried argument is prepended, in order,
  -- to the list finally passed to the wrapped function.
  unpack :: ([a] -> [a]) -> o
  -- | As 'unpack', for a wrapped function that returns a pair of a value
  -- and a list.
  unpack' :: ([a] -> (a, [a])) -> o'
-- | Base case: no arguments are left to supply or collect.
instance Num a => Grad (Sparse a) [a] (a, [a]) a where
  -- The value is already a result; the remaining argument list is ignored.
  pack result _ = result

  -- No more arguments to collect: finish the argument list with [].
  unpack k = k []

  unpack' k = k []
-- | Inductive case: one more curried argument.
instance Grad i o o' a => Grad (Sparse a -> i) (a -> o) (a -> o') a where
  -- Apply the function to the next list element and continue with the rest.
  -- Running out of list elements is reported as a logic error.
  pack f args = case args of
    x : rest -> pack (f x) rest
    [] -> error "Grad.pack: logic error"

  -- Accept one more argument and put it in front of the list that the
  -- wrapped function will eventually receive.
  unpack k x = unpack (\rest -> k (x : rest))

  unpack' k x = unpack' (\rest -> k (x : rest))
-- | Variadic gradient: given a curried function over 'Sparse' values,
-- return a curried function over plain values that yields the list of
-- first-order partials (as read by 'd'), one per argument.
vgrad :: Grad i o o' a => i -> o
vgrad f = unpack (gradientOf (pack f))
  where
    -- Evaluate @g@ on fresh variables and read off the partials.
    gradientOf g xs = d xs (apply g xs)
-- | Like 'vgrad', but the resulting function yields the primal value paired
-- with the list of first-order partials (as read by 'd'').
vgrad' :: Grad i o o' a => i -> o'
vgrad' f = unpack' (gradientOf' (pack f))
  where
    -- Evaluate @g@ on fresh variables and read off value and partials.
    gradientOf' g xs = d' xs (apply g xs)
-- | Relates a curried function over 'Sparse' values (@i@) to the curried
-- function type returned by 'vgrads' (@o@), both over the same scalar type
-- @a@. The functional dependencies let either of @i@ and @o@ determine the
-- remaining parameters.
class Num a => Grads i o a | i -> a o, o -> a i where
  -- | Feed a value of type @i@ its arguments from a list. In the instances
  -- of this module a plain 'Sparse' result ignores the list and a function
  -- consumes one list element per argument.
  packs :: i -> [Sparse a] -> Sparse a
  -- | Turn a function from argument lists to a 'Cofree' tower into a value
  -- of type @o@. In the instances of this module each curried argument is
  -- prepended, in order, to the list finally passed to the wrapped function.
  unpacks :: ([a] -> Cofree [] a) -> o
-- | Base case: no arguments are left to supply or collect.
instance Num a => Grads (Sparse a) (Cofree [] a) a where
  -- The value is already a result; the remaining argument list is ignored.
  packs result _ = result

  -- No more arguments to collect: finish the argument list with [].
  unpacks k = k []
-- | Inductive case: one more curried argument.
instance Grads i o a => Grads (Sparse a -> i) (a -> o) a where
  -- Apply the function to the next list element and continue with the rest.
  packs f (a : as) = packs (f a) as
  -- Running out of list elements is reported as a logic error. The message
  -- names this class and method rather than 'Grad' and 'pack'.
  packs _ [] = error "Grads.packs: logic error"

  -- Accept one more argument and put it in front of the list that the
  -- wrapped function will eventually receive.
  unpacks f a = unpacks (f . (a :))
-- | Variadic tower of derivatives: given a curried function over 'Sparse'
-- values, return a curried function over plain values that yields the
-- 'Cofree' tower built by 'ds'.
vgrads :: Grads i o a => i -> o
vgrads f = unpacks (towerOf (packs f))
  where
    -- Evaluate @g@ on fresh variables and build the tower of partials.
    towerOf g xs = ds xs (apply g xs)