{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module InfBackprop.Common
(
Backprop (MkBackprop),
call,
forward,
backward,
StartBackprop,
startBackprop,
forwardBackward,
numba,
numbaN,
derivative,
derivativeN,
BackpropFunc,
const,
pureBackprop,
)
where
import Control.Arrow (Kleisli (Kleisli))
import Control.CatBifunctor (CatBiFunctor, first, (***))
import Control.Category (Category, id, (.), (>>>))
import GHC.Natural (Natural)
import IsomorphismClass (IsomorphicTo)
import IsomorphismClass.Extra ()
import IsomorphismClass.Isomorphism (Isomorphism, iso)
import NumHask (one, zero)
import NumHask.Algebra.Additive (Additive)
import NumHask.Algebra.Ring (Distributive)
import NumHask.Extra ()
import Prelude (Monad, flip, fromIntegral, iterate, pure, (!!), ($))
import qualified Prelude as P
-- | A differentiable morphism from @input@ to @output@ in the category @cat@.
--
-- Besides the plain morphism ('call'), a value carries a forward morphism
-- that additionally emits a cache of some hidden (existential) type, and a
-- backward morphism that consumes an output-side value together with that
-- cache and returns an input-side value. 'forward' and 'backward' are
-- themselves 'Backprop' values, so the result of backpropagation can be
-- backpropagated again (see 'numbaN').
data Backprop cat input output = forall cache.
  MkBackprop
  { -- | The underlying morphism of @cat@.
    call :: cat input output,
    -- | Forward pass: the output paired with a cache for the backward pass.
    forward :: Backprop cat input (output, cache),
    -- | Backward pass: maps an output-side value and the cache back to the
    -- input side.
    backward :: Backprop cat (output, cache) input
  }
-- | Composes two 'Backprop' morphisms in diagrammatic order: the first
-- argument runs first, then the second (the same order as '>>>').
--
-- The forward pass runs @forwardF@, then @forwardG@ on the intermediate
-- value, and reassociates so that the two caches are bundled as
-- @(hG, hF)@. The backward pass undoes that reassociation, runs
-- @backwardG@ and finally @backwardF@.
composition' ::
  forall cat x y z.
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y ->
  Backprop cat y z ->
  Backprop cat x z
composition'
  ( MkBackprop
      callF
      (forwardF :: Backprop cat x (y, hF))
      (backwardF :: Backprop cat (y, hF) x)
    )
  ( MkBackprop
      callG
      (forwardG :: Backprop cat y (z, hG))
      (backwardG :: Backprop cat (z, hG) y)
    ) =
    MkBackprop call_ forward_ backward_
  where
    call_ :: cat x z
    call_ = callF >>> callG

    -- x  ->  (y, hF)  ->  ((z, hG), hF)  ->  (z, (hG, hF))
    forward_ :: Backprop cat x (z, (hG, hF))
    forward_ =
      (forwardF `composition'` first forwardG)
        `composition'` (iso :: Backprop cat ((z, hG), hF) (z, (hG, hF)))

    -- (z, (hG, hF))  ->  ((z, hG), hF)  ->  (y, hF)  ->  x
    backward_ :: Backprop cat (z, (hG, hF)) x
    backward_ =
      ( (iso :: Backprop cat (z, (hG, hF)) ((z, hG), hF))
          `composition'` first backwardG
      )
        `composition'` backwardF
-- | Lifts a type isomorphism @x ≅ y@ into a 'Backprop' morphism of @cat@.
--
-- The cache is the unit type @()@. The forward pass applies the isomorphism
-- and pairs the result with @()@. The backward pass drops the @()@ and
-- applies the isomorphism in the opposite direction (@y -> x@).
--
-- The nested 'forward' and 'backward' are built from 'iso' at 'Backprop',
-- and that instance is defined as 'iso'' itself. The definition is
-- therefore recursive and relies on laziness.
iso' ::
  forall cat x y.
  (IsomorphicTo x y, Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y
iso' =
  MkBackprop
    call_
    (forward_ :: Backprop cat x (y, ()))
    (backward_ :: Backprop cat (y, ()) x)
  where
    call_ :: cat x y
    call_ = iso

    forward_ :: Backprop cat x (y, ())
    forward_ =
      (iso :: Backprop cat x y)
        `composition'` (iso :: Backprop cat y (y, ()))

    backward_ :: Backprop cat (y, ()) x
    backward_ =
      (iso :: Backprop cat (y, ()) y)
        `composition'` (iso :: Backprop cat y x)
-- | 'Backprop' morphisms form a category.
--
-- The identity is the lifted identity isomorphism ('iso''). Composition is
-- 'composition'' with its arguments swapped into the right-to-left order
-- expected by '(.)'.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Category (Backprop cat)
  where
  id = iso'
  (.) = flip composition'
-- | Any type isomorphism lifts to a 'Backprop' morphism via 'iso''.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Isomorphism (Backprop cat)
  where
  iso = iso'
-- | Parallel composition of 'Backprop' morphisms on pairs.
--
-- The forward pass runs both forward passes side by side. It then regroups
-- @((y1, h1), (y2, h2))@ into @((y1, y2), (h1, h2))@, so the composite cache
-- is the pair of the two caches. The backward pass regroups back and runs
-- both backward passes side by side.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  CatBiFunctor (,) (Backprop cat)
  where
  (***)
    ( MkBackprop
        call1
        (forward1 :: Backprop cat x1 (y1, h1))
        (backward1 :: Backprop cat (y1, h1) x1)
      )
    ( MkBackprop
        call2
        (forward2 :: Backprop cat x2 (y2, h2))
        (backward2 :: Backprop cat (y2, h2) x2)
      ) =
      MkBackprop call12 forward12 backward12
    where
      call12 :: cat (x1, x2) (y1, y2)
      call12 = call1 *** call2

      -- Parenthesised explicitly so the grouping does not depend on the
      -- relative fixities of '***' and '>>>'.
      forward12 :: Backprop cat (x1, x2) ((y1, y2), (h1, h2))
      forward12 =
        (forward1 *** forward2)
          >>> (iso :: Backprop cat ((y1, h1), (y2, h2)) ((y1, y2), (h1, h2)))

      backward12 :: Backprop cat ((y1, y2), (h1, h2)) (x1, x2)
      backward12 =
        (iso :: Backprop cat ((y1, y2), (h1, h2)) ((y1, h1), (y2, h2)))
          >>> (backward1 *** backward2)
-- | Runs the forward pass of the second argument, applies the first argument
-- to the forward output while passing the cache through untouched, and
-- finishes with the backward pass.
--
-- With @dy = 'startBackprop'@ this is exactly 'numba'. The 'call' field of
-- the second argument is not used.
forwardBackward ::
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat y y ->
  Backprop cat x y ->
  Backprop cat x x
forwardBackward dy (MkBackprop _ fwd bwd) = fwd >>> first dy >>> bwd
-- | Categories and types that provide a seed 'Backprop' morphism for
-- starting backpropagation. 'numba' passes 'startBackprop' to
-- 'forwardBackward'.
class Distributive x => StartBackprop cat x where
  -- | Seed morphism applied to the forward output. For plain functions it
  -- is the constant 'one'.
  startBackprop :: Backprop cat x x
-- | Backpropagates through a 'Backprop' morphism, using 'startBackprop' as
-- the seed.
--
-- The result is again a 'Backprop' morphism (from @x@ to @x@), so it can
-- itself be differentiated; 'derivative' extracts its 'call'.
numba ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  Backprop cat x x
numba = forwardBackward startBackprop
-- | Applies 'numba' @n@ times to @f@. For @n = 0@ the result is @f@ itself.
numbaN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  Natural ->
  Backprop cat x x ->
  Backprop cat x x
numbaN n f =
  -- NOTE(review): 'fromIntegral' narrows the 'Natural' to 'Int', so counts
  -- above @maxBound :: Int@ would not index as intended.
  iterate numba f !! fromIntegral n
-- | The underlying @cat@ morphism of 'numba', i.e. 'call' applied to the
-- result of backpropagating with the 'startBackprop' seed.
derivative ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  cat x x
derivative = call . numba
-- | The underlying @cat@ morphism of @'numbaN' n@, i.e. 'numba' applied
-- @n@ times.
derivativeN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  Natural ->
  Backprop cat x x ->
  cat x x
derivativeN n = call . numbaN n
-- | 'Backprop' morphisms over ordinary functions @(->)@.
type BackpropFunc = Backprop (->)
-- | For plain functions the backpropagation seed is the constant 'one'.
instance (Distributive x) => StartBackprop (->) x where
  startBackprop = const one
-- | Differentiable constant function: ignores its argument and returns @c@.
--
-- The forward pass returns @c@ paired with a unit cache. The backward pass
-- ignores the incoming value and returns 'zero', because the output does
-- not depend on the input.
const ::
  forall c x.
  (Additive c, Additive x) =>
  c ->
  BackpropFunc x c
const c = MkBackprop call' forward' backward'
  where
    call' :: x -> c
    call' = P.const c

    forward' :: BackpropFunc x (c, ())
    forward' = const c >>> (iso :: BackpropFunc c (c, ()))

    backward' :: BackpropFunc (c, ()) x
    backward' = (iso :: BackpropFunc (c, ()) c) >>> const zero
-- | Lifts a 'BackpropFunc' into the Kleisli category of a monad @m@.
--
-- Each underlying function result is wrapped with 'pure'. 'forward' and
-- 'backward' are lifted recursively in the same way, so 'pure' is the only
-- monadic operation used.
pureBackprop :: forall a b m. Monad m => Backprop (->) a b -> Backprop (Kleisli m) a b
pureBackprop
  ( MkBackprop
      (call'' :: a -> b)
      (forward'' :: Backprop (->) a (b, c))
      (backward'' :: Backprop (->) (b, c) a)
    ) =
    MkBackprop call' forward' backward'
  where
    call' :: Kleisli m a b
    call' = Kleisli $ pure . call''

    forward' :: Backprop (Kleisli m) a (b, c)
    forward' = pureBackprop forward''

    backward' :: Backprop (Kleisli m) (b, c) a
    backward' = pureBackprop backward''
-- | In a Kleisli category the seed is the plain-function seed ('const'
-- 'one'), lifted with 'pureBackprop'.
instance (Distributive x, Monad m) => StartBackprop (Kleisli m) x where
  startBackprop = pureBackprop startBackprop