module Numeric.AD.Internal.Kahn
( Kahn(..)
, Tape(..)
, partials
, partialArray
, partialMap
, derivative
, derivative'
, vgrad, vgrad'
, Grad(..)
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, primal
, var
, varId
) where
#if __GLASGOW_HASKELL__ < 710
import Prelude hiding (mapM)
import Control.Applicative (Applicative(..),(<$>))
import Data.Traversable (Traversable, mapM)
#endif
import Control.Monad.ST
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import Data.List (foldl')
import Data.Array.ST
import Data.Array
import Data.IntMap (IntMap, fromListWith, findWithDefault)
import Data.Graph (Vertex, transposeG, Graph)
import Data.Number.Erf
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
-- | One node of the expression graph recorded for differentiation.
--
-- The parameter @a@ is the numeric type; @t@ is the type used to refer to
-- child nodes (it is @'Kahn' a@ inside a 'Kahn' value, and 'Int' in the
-- tapes consumed by 'backPropagate').
data Tape a t
  = Zero
    -- ^ A value known to be zero; 'primal' reads it as @0@.
  | Lift !a
    -- ^ A constant carrying only its value (this is what 'auto' builds).
  | Var !a !Int
    -- ^ An input variable: its value and its integer identifier.
  | Binary !a a a t t
    -- ^ Result value, the partial with respect to the first operand, the
    -- partial with respect to the second operand, then the two operands.
  | Unary !a a t
    -- ^ Result value, the partial with respect to the operand, then the
    -- operand itself.
  deriving (Show, Data, Typeable)
-- | An expression graph whose nodes are 'Tape's and whose children are
-- themselves 'Kahn' values; this newtype ties the recursive knot of 'Tape'.
newtype Kahn a = Kahn (Tape a (Kahn a)) deriving (Show, Typeable)
-- | Lets a 'Kahn' expression be reified into an explicit graph: each node is
-- rebuilt as a 'Tape' whose children have been replaced by the result of
-- the supplied visiting action.
instance MuRef (Kahn a) where
  type DeRef (Kahn a) = Tape a

  mapDeRef visit (Kahn node) = case node of
    -- Leaves have no children, so the visitor is never invoked for them.
    Zero               -> pure Zero
    Lift x             -> pure (Lift x)
    Var x ident        -> pure (Var x ident)
    -- Interior nodes keep their value and partials and visit children
    -- left to right.
    Binary x gl gr l r -> Binary x gl gr <$> visit l <*> visit r
    Unary x g child    -> Unary x g <$> visit child
-- | 'Mode' operations for 'Kahn': recognising constants, lifting scalars,
-- and scaling a graph by a scalar.
instance Num a => Mode (Kahn a) where
  type Scalar (Kahn a) = a

  -- Only the dedicated 'Zero' node counts as a known zero.
  isKnownZero (Kahn node) = case node of
    Zero -> True
    _    -> False

  -- 'Zero' and 'Lift' nodes carry no dependence on any variable.
  isKnownConstant (Kahn node) = case node of
    Zero   -> True
    Lift _ -> True
    _      -> False

  auto x = Kahn (Lift x)
  zero = Kahn Zero

  -- Scaling by a scalar: the partial with respect to the graph operand is
  -- the (lifted) scalar itself, or its reciprocal for division.
  s *^ k = lift1 (s *) (const (auto s)) k
  k ^* s = lift1 (* s) (const (auto s)) k
  k ^/ s = lift1 (/ s) (const (auto (recip s))) k
-- | Addition of two 'Kahn' expressions; both partial derivatives are @1@.
(<+>) :: Num a => Kahn a -> Kahn a -> Kahn a
x <+> y = binary (+) 1 1 x y
-- | Exponentiation @x ** y@ of 'Kahn' expressions.
--
-- * A known-zero base gives the constant @0 ** primal y@.
-- * A known-zero exponent gives the constant @1@.
-- * A constant exponent @y@ only differentiates with respect to the base,
--   using @d\/dx (x ** y) = y * x ** (y - 1)@.
-- * Otherwise both partials are recorded:
--   @d\/dx = y * x ** (y - 1)@ and @d\/dy = (x ** y) * log x@, where the
--   already computed result @z = x ** y@ is reused for the second one.
(<**>) :: Floating a => Kahn a -> Kahn a -> Kahn a
Kahn Zero <**> y             = auto (0 ** primal y)
_         <**> Kahn Zero     = auto 1
x         <**> Kahn (Lift y) = lift1 (** y) (\z -> y *^ z ** Id (y - 1)) x
x         <**> y             = lift2_ (**) (\z xi yi -> (yi * xi ** (yi - 1), z * log xi)) x y
-- | The plain numeric value stored at the root of a 'Kahn' expression.
-- A 'Zero' node has value @0@; every other node stores its value as its
-- first field.
primal :: Num a => Kahn a -> a
primal (Kahn node) = case node of
  Zero             -> 0
  Lift x           -> x
  Var x _          -> x
  Binary x _ _ _ _ -> x
  Unary x _ _      -> x
-- | Node builders for 'Kahn'.  Constant operands ('Zero' / 'Lift') are
-- folded away where possible so that only operands that are neither end up
-- as children of a 'Unary' or 'Binary' node.
instance Num a => Jacobian (Kahn a) where
  type D (Kahn a) = Id a

  -- A constant operand produces a constant result; anything else records a
  -- 'Unary' node holding the supplied partial.
  unary f d b = case b of
    Kahn Zero     -> Kahn (Lift (f 0))
    Kahn (Lift x) -> Kahn (Lift (f x))
    _             -> case d of
      Id g -> Kahn (Unary (f (primal b)) g b)

  -- The partial is computed from the operand's value.
  lift1 f df b = unary f (df (Id (primal b))) b

  -- As 'lift1', but the partial may also use the result, which is computed
  -- once and shared.
  lift1_ f df b =
    let x = primal b
        y = f x
     in unary (const y) (df (Id y) (Id x)) b

  -- Both operands constant: fold to a constant.
  binary f _ _ (Kahn Zero) (Kahn Zero) = Kahn (Lift (f 0 0))
  binary f _ _ (Kahn Zero) (Kahn (Lift r)) = Kahn (Lift (f 0 r))
  binary f _ _ (Kahn (Lift l)) (Kahn Zero) = Kahn (Lift (f l 0))
  binary f _ _ (Kahn (Lift l)) (Kahn (Lift r)) = Kahn (Lift (f l r))
  -- Left operand constant: record only the dependence on the right one.
  binary f _ (Id gr) (Kahn Zero) r = Kahn (Unary (f 0 (primal r)) gr r)
  binary f _ (Id gr) (Kahn (Lift l)) r = Kahn (Unary (f l (primal r)) gr r)
  -- Right operand constant: record only the dependence on the left one.
  binary f (Id gl) _ l (Kahn Zero) = Kahn (Unary (f (primal l) 0) gl l)
  binary f (Id gl) _ l (Kahn (Lift r)) = Kahn (Unary (f (primal l) r) gl l)
  -- General case: a full 'Binary' node.
  binary f (Id gl) (Id gr) l r = Kahn (Binary (f (primal l) (primal r)) gl gr l r)

  -- Partials are computed from the two operand values.
  lift2 f df l r =
    let (gl, gr) = df (Id (primal l)) (Id (primal r))
     in binary f gl gr l r

  -- As 'lift2', but the partials may also use the shared result.
  lift2_ f df l r =
    let pl = primal l
        pr = primal r
        y = f pl pr
        (gl, gr) = df (Id y) (Id pl) (Id pr)
     in binary (\_ _ -> y) gl gr l r
-- | Multiplication of two 'Kahn' expressions: the partial with respect to
-- each operand is the value of the other operand.
mul :: Num a => Kahn a -> Kahn a -> Kahn a
mul = lift2 (*) (flip (,))
#define HEAD Kahn a
#include <instances.h>
-- | Sum of every partial contribution reported by 'partials', regardless of
-- which variable id it belongs to.
derivative :: Num a => Kahn a -> a
derivative tape = sum [ d | (_, d) <- partials tape ]
-- | The value of the expression paired with its 'derivative'.
derivative' :: Num a => Kahn a -> (a, a)
derivative' tape = (value, slope)
  where
    value = primal tape
    slope = derivative tape
-- | Push the sensitivity accumulated at vertex @v@ on to its children.
--
-- @vmap@ yields the node stored at a vertex together with the index of its
-- own slot in @ss@.  For a 'Unary' node the child's slot is increased by
-- @partial * sensitivity@; for a 'Binary' node the first child is updated
-- and written back before the second child's slot is read.  All other
-- nodes have no children and nothing is written.
backPropagate :: Num a => (Vertex -> (Tape a Int, Int, [Int])) -> STArray s Int a -> Vertex -> ST s ()
backPropagate vmap ss v =
  case vmap v of
    (Unary _ g child, self, _) -> do
      adjoint <- readArray ss self
      accumulate child g adjoint
    (Binary _ gl gr l r, self, _) -> do
      adjoint <- readArray ss self
      accumulate l gl adjoint
      accumulate r gr adjoint
    _ -> return ()
  where
    -- Add @g * adjoint@ to whatever is currently stored for @target@.
    accumulate target g adjoint = do
      old <- readArray ss target
      writeArray ss target (old + g * adjoint)
-- | Topologically sort an acyclic graph with a worklist (Kahn's algorithm).
--
-- Vertices with no incoming edge seed the worklist.  Popping a vertex emits
-- it and then queues each successor whose predecessors have all been
-- emitted.  Every vertex appears in the result exactly once, after all of
-- its predecessors.
--
-- The graph may contain duplicated edges (the same successor listed twice
-- for one vertex).  Such a successor is queued once per edge, so a vertex
-- that has already been emitted is skipped when it is popped again;
-- otherwise it, and everything reachable from it, would be emitted twice.
topSortAcyclic :: Graph -> [Vertex]
topSortAcyclic g = reverse $ runST $ do
  del <- newArray (bounds g) False :: ST s (STUArray s Int Bool)
  let tg = transposeG g
      starters = [ n | (n, []) <- assocs tg ]

      -- True when every vertex in the list has already been emitted.
      allEmitted [] = return True
      allEmitted (x:xs) = do
        done <- readArray del x
        if done then allEmitted xs else return False

      -- A vertex is ready once all of its predecessors have been emitted.
      ready m = allEmitted (tg ! m)

      loop [] rs = return rs
      loop (n:ns) rs = do
        seen <- readArray del n
        if seen
          then loop ns rs -- queued more than once via a duplicated edge
          else do
            writeArray del n True
            unlocked <- filterM ready (g ! n)
            loop (unlocked ++ ns) (n : rs)
  loop starters []
-- | Contributions to the partial derivatives of the output with respect to
-- the 'Var' nodes of @tape@.
--
-- The expression is reified into a graph, the output node is seeded with
-- sensitivity @1@, and sensitivities are propagated to children in
-- topological order.  One @(variable id, sensitivity)@ pair is produced per
-- 'Var' node of the reified graph, so the same variable id may occur more
-- than once; 'partialArray' and 'partialMap' add such entries together.
partials :: forall a. Num a => Kahn a -> [(Int, a)]
partials tape = [ let v = sensitivities ! ix in seq v (ident, v) | (ix, Var _ ident) <- xs ] where
  -- 'reifyGraph' lives in IO; its result is used as a pure value here.
  -- NOTE(review): this presumes the derivative sums do not depend on how
  -- much sharing reification happens to observe -- confirm.
  Reified.Graph xs start = unsafePerformIO $ reifyGraph tape
  -- Adjacency list: every node points at its operands.
  g = array xsBounds [ (i, successors t) | (i, t) <- xs ]
  vertexMap = array xsBounds xs
  vmap i = (vertexMap ! i, i, [])
  xsBounds = sbounds xs

  sensitivities = runSTArray $ do
    ss <- newArray xsBounds 0
    writeArray ss start 1
    forM_ (topSortAcyclic g) $
      backPropagate vmap ss
    return ss

  -- Smallest and largest node index present in the reified graph.
  sbounds ((a,_):as) = foldl' (\(lo,hi) (b,_) -> let lo' = min lo b; hi' = max hi b in lo' `seq` hi' `seq` (lo', hi')) (a,a) as
  -- The output node is expected to always be present, so this should be
  -- unreachable; fail with a recognisable message rather than 'undefined'.
  sbounds [] = error "Numeric.AD.Internal.Kahn.partials: reified graph has no nodes"

  -- Operands of a node, i.e. the vertices that receive its sensitivity.
  successors :: Tape a t -> [t]
  successors (Unary _ _ b) = [b]
  successors (Binary _ _ _ b c) = [b,c]
  successors _ = []
-- | Collect 'partials' into an 'Array' indexed by variable id over the given
-- bounds.  Entries for the same id are added together and ids with no
-- entry are @0@.
partialArray :: Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray vbounds = accumArray (+) 0 vbounds . partials
-- | Collect 'partials' into an 'IntMap' keyed by variable id, adding
-- together entries that share an id.
partialMap :: Num a => Kahn a -> IntMap a
partialMap tape = fromListWith (+) (partials tape)
-- | Relates a curried function over 'Kahn' values (@i@) to the curried
-- shapes of its gradient (@o@) and of its value-with-gradient (@o'@), all
-- over the scalar type @a@.  The functional dependencies make any one of
-- @i@, @o@ or @o'@ determine the other parameters.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
  -- | Apply the curried function to its arguments supplied as a list.
  pack :: i -> [Kahn a] -> Kahn a
  -- | Turn a list-to-list function into the curried gradient shape.
  unpack :: ([a] -> [a]) -> o
  -- | Turn a list-to-(value, list) function into the curried
  -- value-with-gradient shape.
  unpack' :: ([a] -> (a, [a])) -> o'
-- | Base case: no arguments remain.  'pack' returns the result and ignores
-- the argument list; 'unpack' and 'unpack'' run the list function on the
-- empty list.
instance Num a => Grad (Kahn a) [a] (a, [a]) a where
  pack = const
  unpack = ($ [])
  unpack' = ($ [])
-- | Inductive case: peel one argument off the curried function.  'pack'
-- feeds the head of the list to the function and recurses on the tail;
-- 'unpack' and 'unpack'' take one more curried argument and put it at the
-- front of the list handed to the list function.
instance Grad i o o' a => Grad (Kahn a -> i) (a -> o) (a -> o') a where
  pack f args = case args of
    x : rest -> pack (f x) rest
    []       -> error "Grad.pack: logic error"
  unpack k x = unpack (\rest -> k (x : rest))
  unpack' k x = unpack' (\rest -> k (x : rest))
-- | Gradient of a curried function over 'Kahn' values, returned as a curried
-- function that takes plain scalars and yields the list of partials.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack (gradientOf (pack i))
  where
    -- Bind the inputs to fresh variables, run the function, and read each
    -- variable's partial back out of the accumulated array.
    gradientOf f inputs =
      let (vars, varBounds) = bind inputs
       in unbind vars (partialArray varBounds (f vars))
-- | Like 'vgrad', but the curried result yields the function's value paired
-- with the list of partials.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' (valueAndGradient (pack i))
  where
    -- The result expression is built once and used for both the value and
    -- the partials.
    valueAndGradient f inputs =
      let (vars, varBounds) = bind inputs
          result = f vars
       in (primal result, unbind vars (partialArray varBounds result))
-- | Build a variable node from a value and an integer identifier.
var :: a -> Int -> Kahn a
var a = Kahn . Var a
-- | The identifier of a variable node.  Calls 'error' when the expression is
-- not a 'Var'.
varId :: Kahn a -> Int
varId (Kahn node) = case node of
  Var _ ident -> ident
  _           -> error "varId: not a Var"
-- | Replace every element of a structure by a fresh variable carrying that
-- element as its value.  Identifiers are handed out from @0@ upwards in
-- traversal order.  The second component is @(0, n)@ where @n@ is the final
-- counter value, i.e. the number of elements (one more than the largest
-- identifier assigned).
bind :: Traversable f => f a -> (f (Kahn a), (Int,Int))
bind xs = (bound, (0, next))
  where
    (bound, next) = runState (mapM fresh xs) 0
    -- Use the current counter as the identifier and step the counter,
    -- forcing the new counter before the pair is returned.
    fresh a = state (\n -> let n' = n + 1 in n' `seq` (var a n, n'))
-- | Replace every variable in the structure by the array entry stored under
-- its identifier.
unbind :: Functor f => f (Kahn a) -> Array Int a -> f a
unbind xs ys = (ys !) . varId <$> xs
-- | Like 'unbind', but combine each variable's value with its array entry
-- using the supplied function.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Kahn a) -> Array Int b -> f c
unbindWith f xs ys = combine <$> xs
  where
    combine v = f (primal v) (ys ! varId v)
-- | Replace every variable in the structure by the map entry stored under
-- its identifier, using @0@ when the identifier is absent.
unbindMap :: (Functor f, Num a) => f (Kahn a) -> IntMap a -> f a
unbindMap xs ys = entryFor <$> xs
  where
    entryFor v = findWithDefault 0 (varId v) ys
-- | Combine each variable's value with the map entry stored under its
-- identifier, substituting the given default when the identifier is absent.
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Kahn a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = combine <$> xs
  where
    combine v = f (primal v) (findWithDefault z (varId v) ys)