{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE EmptyCase             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TupleSections         #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_HADDOCK not-home       #-}

-- |
-- Module      : Numeric.Backprop.Internal
-- Copyright   : (c) Justin Le 2023
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Provides the types and instances used for the graph
-- building/back-propagation for the library.

module Numeric.Backprop.Internal (
    BVar
  , W
  , backpropWithN, evalBPN
  , constVar
  , liftOp, liftOp1, liftOp2, liftOp3
  , viewVar, setVar, sequenceVar, collectVar, previewVar, toListOfVar
  , coerceVar
  -- * Func wrappers
  , ZeroFunc(..), zfNum, zeroFunc
  , AddFunc(..), afNum, addFunc
  , OneFunc(..), ofNum, oneFunc
  -- * Debug
  , debugSTN
  , debugIR
  ) where

import           Control.DeepSeq
import           Control.Exception
import           Control.Monad
import           Control.Monad.ST
import           Control.Monad.Trans.State
import           Data.Bifunctor
import           Data.Coerce
import           Data.Foldable
import           Data.Function
import           Data.Functor.Identity
import           Data.IORef
import           Data.Kind
import           Data.Maybe
import           Data.Monoid hiding        (Any(..))
import           Data.Proxy
import           Data.Reflection
import           Data.Type.Util
import           Data.Typeable
import           Data.Vinyl.Core
import           GHC.Exts                  (Any)
import           GHC.Generics              as G
import           Lens.Micro
import           Lens.Micro.Extras
import           Numeric.Backprop.Class
import           Numeric.Backprop.Op
import           System.IO.Unsafe
import           Unsafe.Coerce
import qualified Data.Vector               as V
import qualified Data.Vector.Mutable       as MV
import qualified Data.Vinyl.Recursive      as VR
import qualified Data.Vinyl.XRec           as X

-- | A function that "zeroes out" every component of a value: @'const' 0@
-- for a scalar, and for a vector or matrix a function that sets each
-- component to zero (the additive identity).
--
-- It should be idempotent: applying it twice gives the same result as
-- applying it once.
--
-- A type should ideally have exactly one 'ZeroFunc'; the 'Backprop'
-- typeclass is how that coherence requirement is expressed.
--
-- @since 0.2.0.0
newtype ZeroFunc a = ZF { runZF :: a -> a }

-- | Add together two values of a type.  This is used to combine
-- contributions of gradients, so it should ideally be
-- information-preserving.
--
-- See the laws for 'Backprop' for the laws this is expected to satisfy.
-- Namely, it should be commutative and associative, with the output of
-- any valid 'ZeroFunc' acting as its identity.
--
-- Each type should ideally only have one 'AddFunc'.  This coherence
-- constraint is given by the typeclass 'Backprop'.
--
-- @since 0.2.0.0
newtype AddFunc a = AF { runAF :: a -> a -> a }

-- | A function that sets every component of a value to "one": @'const' 1@
-- for a scalar, and for a vector or matrix a function that sets each
-- component to one (the multiplicative identity).
--
-- It should be idempotent: applying it twice gives the same result as
-- applying it once.
--
-- A type should ideally have exactly one 'OneFunc'; the 'Backprop'
-- typeclass is how that coherence requirement is expressed.
--
-- @since 0.2.0.0
newtype OneFunc a = OF { runOF :: a -> a }

-- | The canonical 'ZeroFunc' for any type with a 'Num' instance: it
-- ignores its argument and returns @0@.
--
-- @since 0.2.0.0
zfNum :: Num a => ZeroFunc a
zfNum = ZF (\_ -> 0)
{-# INLINE zfNum #-}

-- | The canonical 'AddFunc' for any type with a 'Num' instance: plain
-- addition with '+'.
--
-- @since 0.2.0.0
afNum :: Num a => AddFunc a
afNum = AF (\x y -> x + y)
{-# INLINE afNum #-}

-- | The canonical 'OneFunc' for any type with a 'Num' instance: it
-- ignores its argument and returns @1@.
--
-- @since 0.2.0.0
ofNum :: Num a => OneFunc a
ofNum = OF (\_ -> 1)
{-# INLINE ofNum #-}

-- | A @'BVar' s a@ is a value of type @a@ that can be "backpropagated".
--
-- Functions referring to 'BVar's are tracked by the library and can be
-- automatically differentiated to get their gradients and results.
--
-- For simple numeric values, you can use its 'Num', 'Fractional', and
-- 'Floating' instances to manipulate them as if they were the numbers they
-- represent.
--
-- If @a@ contains items, the items can be accessed and extracted using
-- lenses. A @'Lens'' b a@ can be used to access an @a@ inside a @b@, using
-- '^^.' ('Numeric.Backprop.viewVar'):
--
-- @
-- ('^.')  ::        b -> 'Lens'' b a ->        a
-- ('^^.') :: 'BVar' s b -> 'Lens'' b a -> 'BVar' s a
-- @
--
-- There is also '^^?' ('Numeric.Backprop.previewVar'), to use a 'Prism''
-- or 'Traversal'' to extract a target that may or may not be present
-- (which can implement pattern matching), '^^..'
-- ('Numeric.Backprop.toListOfVar') to use a 'Traversal'' to extract /all/
-- targets inside a 'BVar', and '.~~' ('Numeric.Backprop.setVar') to set
-- and update values inside a 'BVar'.
--
-- If you have control over your data type definitions, you can also use
-- 'Numeric.Backprop.splitBV' and 'Numeric.Backprop.joinBV' to manipulate
-- data types by easily extracting fields out of a 'BVar' of data types and
-- creating 'BVar's of data types out of 'BVar's of their fields.  See
-- "Numeric.Backprop#hkd" for a tutorial on this use pattern.
--
-- For more complex operations, libraries can provide functions on 'BVar's
-- using 'Numeric.Backprop.liftOp' and related functions.  This is how you
-- can create primitive functions that users can use to manipulate your
-- library's values.  See
-- <https://backprop.jle.im/08-equipping-your-library.html> for a detailed
-- guide.
--
-- For example, the /hmatrix/ library has a matrix-vector multiplication
-- function, @#> :: L m n -> R n -> R m@.
--
-- A library could instead provide a function
-- @#> :: 'BVar' s (L m n) -> 'BVar' s (R n) -> 'BVar' s (R m)@, which the
-- user can then use to manipulate their 'BVar's of @L m n@s and @R n@s,
-- etc.
--
-- See "Numeric.Backprop#liftops" and documentation for
-- 'Numeric.Backprop.liftOp' for more information.
--
data BVar s a = BV { _bvRef :: !(BRef s)  -- ^ tape reference, or 'BRC' for a constant
                   , _bvVal :: !a         -- ^ the value carried by the variable
                   }

-- | Kinds are spelled out explicitly; both parameters are of kind 'Type'.
--
-- @since 0.1.5.1
deriving instance Typeable (BVar (s :: Type) (a :: Type))

-- | Uses the default (identity) higher-kinded-data encoding.
--
-- @since 0.2.6.3
instance X.IsoHKD (BVar s) a

-- | Identifies what a 'BVar' refers to.  The @s@ parameter is phantom and
-- only ties the reference to a particular tape.
data BRef (s :: Type)
    = BRInp !Int   -- ^ presumably the n-th input of the function being differentiated (built outside this section)
    | BRIx  !Int   -- ^ index of a node recorded on the tape (see 'insertNode')
    | BRC          -- ^ a constant that receives no gradient (see 'constVar')
  deriving (Show, Generic)

-- | Relies on the 'Generic'-based default 'rnf'.
instance NFData (BRef s)

-- | Fully evaluates both the reference and the value held by the 'BVar'.
instance NFData a => NFData (BVar s a) where
    rnf (BV ref val) = rnf ref `seq` rnf val

-- | Returns the wrapped value when the 'BVar' is a constant (its reference
-- is 'BRC'), and 'Nothing' for any variable that lives on the tape.
bvConst :: BVar s a -> Maybe a
bvConst (BV ref val) = case ref of
    BRC -> Just val
    _   -> Nothing
{-# INLINE bvConst #-}

-- | Evaluates a 'BVar' to weak head normal form and deeply evaluates its
-- reference.  The carried value is only in WHNF (through the strict
-- field); it is not deeply forced.
forceBVar :: BVar s a -> ()
forceBVar (BV ref _) = rnf ref
{-# INLINE forceBVar #-}

-- | One input edge of a 'TapeNode'.  A contribution of type @a@ is routed
-- back into an input variable whose type @b@ may differ (e.g. the whole
-- structure a lens was focused into), so @s@ and @b@ are existential.
data InpRef (a :: Type) where
    IR :: { _irIx    :: !(BVar s b)
            -- ^ the input variable
          , _irAdd   :: !(a -> b -> b)
            -- ^ add an @a@ contribution into an existing @b@
          , _irEmbed :: !(a -> b)
            -- ^ turn an @a@ contribution into a standalone @b@
          }
       -> InpRef a

-- | Forces an 'InpRef' by forcing the input variable it points to
-- (see 'forceBVar').
forceInpRef :: InpRef a -> ()
forceInpRef IR{_irIx = var} = forceBVar var
{-# INLINE forceInpRef #-}

-- | Debugging string for an 'InpRef': the 'show'n 'BRef' of the input
-- variable it refers to.
debugIR :: InpRef a -> String
debugIR (IR var _ _) = show (_bvRef var)

-- | A node recorded on the tape.  It stores one 'InpRef' per input and a
-- function taking the gradient of the node's output to a gradient for
-- each input.  The list of input types @as@ is existential.
data TapeNode (a :: Type) where
    TN :: { _tnInputs :: !(Rec InpRef as)
          , _tnGrad   :: !(a -> Rec Identity as)
          }
       -> TapeNode a

-- | Forces a 'TapeNode' and runs 'forceInpRef' on every one of its inputs.
--
-- The inputs are walked with explicit strict recursion.  The former
-- @'VR.rfoldMap' forceInpRef inps@ folded into the @()@ monoid.  That
-- monoid's combining operation is @_ '<>' _ = ()@, which ignores both
-- arguments, so 'forceInpRef' was never actually evaluated by that fold.
forceTapeNode :: TapeNode a -> ()
forceTapeNode (TN inps !_) = go inps
  where
    go :: Rec InpRef bs -> ()
    go RNil      = ()
    go (r :& rs) = forceInpRef r `seq` go rs
{-# INLINE forceTapeNode #-}

-- | A 'TapeNode' with its output type hidden.  This lets nodes of
-- different types share the single list stored in a 'W'.
data SomeTapeNode :: Type where
    STN :: { _stnNode :: !(TapeNode a) } -> SomeTapeNode

-- | Forces the 'TapeNode' hidden inside a 'SomeTapeNode' (via
-- 'forceTapeNode').
forceSomeTapeNode :: SomeTapeNode -> ()
forceSomeTapeNode = \case
    STN node -> forceTapeNode node

-- | Debugging string for a 'SomeTapeNode': the 'show'n list of
-- 'debugIR' strings, one for each input of the wrapped node.
debugSTN :: SomeTapeNode -> String
debugSTN (STN (TN inputs _)) =
    show (VR.rfoldMap (\ir -> [debugIR ir]) inputs)

-- | An ephemeral Wengert tape in the environment.  It is used internally
-- to keep track of the computational graph of variables.
--
-- The 'IORef' holds the number of nodes recorded so far, together with
-- the recorded nodes, most recently inserted first (see 'insertNode').
--
-- For the end user, one can just imagine @'Reifies' s 'W'@ as a required
-- constraint on @s@ that allows backpropagation to work.
newtype W = W { wRef :: IORef (Int, [SomeTapeNode]) }

-- | A fresh, empty tape: zero nodes recorded.
initWengert :: IO W
initWengert = do
    ref <- newIORef (0, [])
    pure (W ref)
{-# INLINE initWengert #-}

-- | Atomically pushes a node onto the tape.  It returns a 'BVar' holding
-- the given value and referring to the node's index ('BRIx').  The index
-- is the node count before insertion; the count is then incremented.
insertNode
    :: TapeNode a
    -> a                    -- ^ val
    -> W
    -> IO (BVar s a)
insertNode tn !x !w = do
    ix <- atomicModifyIORef' (wRef w) push
    pure (BV (BRIx ix) x)
  where
    push :: (Int, [SomeTapeNode]) -> ((Int, [SomeTapeNode]), Int)
    push (!n, !t) =
      let n' = n + 1
          t' = STN tn : t
      in  forceTapeNode tn `seq` n' `seq` t' `seq` ((n', t'), n)
{-# INLINE insertNode #-}

-- | Wraps a value as a constant 'BVar' (reference 'BRC').
--
-- Constants are never treated as inputs, and no gradient is
-- backpropagated into them.
constVar :: a -> BVar s a
constVar x = BV BRC x
{-# INLINE constVar #-}

-- | Applies an 'Op' to a record of 'BVar's.
--
-- If every input is a constant, the op is simply evaluated and the result
-- is returned as a constant.  Otherwise a node is recorded on the tape;
-- it holds one 'InpRef' per input and the op's gradient function.
liftOp_
    :: forall s as b. Reifies s W
    => Rec AddFunc as
    -> Op as b
    -> Rec (BVar s) as
    -> IO (BVar s b)
liftOp_ afs o !vs = maybe recordNode (pure . constVar . evalOp o) allConst
  where
    allConst :: Maybe (Rec Identity as)
    allConst = rtraverse (fmap Identity . bvConst) vs
    recordNode :: IO (BVar s b)
    recordNode = insertNode node out (reflect (Proxy @s))
    (out, grad) = runOpWith o (VR.rmap (Identity . _bvVal) vs)
    node :: TapeNode b
    node = TN { _tnInputs = VR.rzipWith mkInput afs vs
              , _tnGrad   = grad
              }
    mkInput :: forall a. AddFunc a -> BVar s a -> InpRef a
    mkInput af !v = forceBVar v `seq` IR v (runAF af) id
    {-# INLINE mkInput #-}
{-# INLINE liftOp_ #-}

-- | 'Numeric.Backprop.liftOp', but with an explicit 'add' ('AddFunc') for
-- every input.  No 'ZeroFunc' is needed.
--
-- Runs 'liftOp_' under 'unsafePerformIO' so that the tape can be extended
-- from pure code.
liftOp
    :: forall as b s. Reifies s W
    => Rec AddFunc as
    -> Op as b
    -> Rec (BVar s) as
    -> BVar s b
liftOp afs o !vs = unsafePerformIO $ liftOp_ afs o vs
{-# INLINE liftOp #-}

-- | Single-input specialisation of 'liftOp_'.  A constant input is
-- evaluated directly; anything else records a one-input node on the tape.
liftOp1_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> Op '[a] b
    -> BVar s a
    -> IO (BVar s b)
liftOp1_ af o v = case bvConst v of
    Just x  -> pure (constVar (evalOp o (Identity x :& RNil)))
    Nothing -> forceBVar v `seq` insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o (Identity (_bvVal v) :& RNil)
    node :: TapeNode b
    node = TN { _tnInputs = IR v (runAF af) id :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp1_ #-}

-- | 'Numeric.Backprop.liftOp1', but with an explicit 'add' ('AddFunc') for
-- the input.  No 'ZeroFunc' is needed.
--
-- Runs 'liftOp1_' under 'unsafePerformIO' so that the tape can be
-- extended from pure code.
liftOp1
    :: forall a b s. Reifies s W
    => AddFunc a
    -> Op '[a] b
    -> BVar s a
    -> BVar s b
liftOp1 af o !v = unsafePerformIO $ liftOp1_ af o v
{-# INLINE liftOp1 #-}

-- | Two-input specialisation of 'liftOp_'.  When both inputs are
-- constants the op is evaluated directly; otherwise a two-input node is
-- recorded on the tape.
liftOp2_
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> IO (BVar s c)
liftOp2_ afa afb o v u = case (bvConst v, bvConst u) of
    (Just x, Just y) ->
      pure . constVar . evalOp o $ Identity x :& Identity y :& RNil
    _ ->
      forceBVar v `seq` forceBVar u `seq` insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o $
      Identity (_bvVal v) :& Identity (_bvVal u) :& RNil
    node :: TapeNode c
    node = TN { _tnInputs = IR v (runAF afa) id
                         :& IR u (runAF afb) id
                         :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp2_ #-}

-- | 'Numeric.Backprop.liftOp2', but with an explicit 'add' ('AddFunc') for
-- each input.  No 'ZeroFunc' is needed.
--
-- Runs 'liftOp2_' under 'unsafePerformIO' so that the tape can be
-- extended from pure code.
liftOp2
    :: forall a b c s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> Op '[a,b] c
    -> BVar s a
    -> BVar s b
    -> BVar s c
liftOp2 afa afb o !v !u = unsafePerformIO $ liftOp2_ afa afb o v u
{-# INLINE liftOp2 #-}

-- | Three-input specialisation of 'liftOp_'.  When all inputs are
-- constants the op is evaluated directly; otherwise a three-input node is
-- recorded on the tape.
liftOp3_
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> IO (BVar s d)
liftOp3_ afa afb afc o v u w = case (bvConst v, bvConst u, bvConst w) of
    (Just x, Just y, Just z) ->
      pure . constVar . evalOp o $ Identity x :& Identity y :& Identity z :& RNil
    _ ->
      forceBVar v `seq` forceBVar u `seq` forceBVar w
        `seq` insertNode node out (reflect (Proxy @s))
  where
    (out, grad) = runOpWith o $
      Identity (_bvVal v) :& Identity (_bvVal u) :& Identity (_bvVal w) :& RNil
    node :: TapeNode d
    node = TN { _tnInputs = IR v (runAF afa) id
                         :& IR u (runAF afb) id
                         :& IR w (runAF afc) id
                         :& RNil
              , _tnGrad   = grad
              }
{-# INLINE liftOp3_ #-}

-- | 'Numeric.Backprop.liftOp3', but with an explicit 'add' ('AddFunc') for
-- each input.  No 'ZeroFunc' is needed.
--
-- Runs 'liftOp3_' under 'unsafePerformIO' so that the tape can be
-- extended from pure code.
liftOp3
    :: forall a b c d s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> Op '[a,b,c] d
    -> BVar s a
    -> BVar s b
    -> BVar s c
    -> BVar s d
liftOp3 afa afb afc o !v !u !w = unsafePerformIO $ liftOp3_ afa afb afc o v u w
{-# INLINE liftOp3 #-}

-- TODO: can we get the zero and add func from the bvar?
-- | Records the lens projection @'_bvVal' v '^.' l@ as a new tape node
-- with @v@ as its only input.
--
-- A contribution @d :: a@ is added into an existing @b@ with
-- @'over' l ('runAF' af d)@.  It is turned into a standalone @b@ by
-- setting the focus of @'runZF' z@ applied to the original value.
viewVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s b
    -> IO (BVar s a)
viewVar_ af z l v = forceBVar v `seq` insertNode node focus (reflect (Proxy @s))
  where
    whole :: b
    whole = _bvVal v
    focus :: a
    focus = whole ^. l
    addInto :: a -> b -> b
    addInto = over l . runAF af
    embed :: a -> b
    embed d = set l d (runZF z whole)
    node :: TapeNode a
    node = TN { _tnInputs = IR v addInto embed :& RNil
              , _tnGrad   = \d -> Identity d :& RNil
              }
{-# INLINE viewVar_ #-}

-- | 'Numeric.Backprop.viewVar', but with explicit 'add' and 'zero'.
--
-- A pure wrapper that runs 'viewVar_' under 'unsafePerformIO' so that
-- the tape can be extended from pure code.
viewVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Lens' b a
    -> BVar s b
    -> BVar s a
viewVar af z l !v = unsafePerformIO (viewVar_ af z l v)
{-# INLINE viewVar #-}

-- TODO: can zero and add func be gotten from the input bvars?
-- | Records @'set' l ('_bvVal' w) ('_bvVal' v)@ as a new tape node whose
-- inputs are @w@ (the new focus) and @v@ (the structure being updated).
--
-- A gradient @d :: b@ is split with the lens.  The focused part goes to
-- @w@; the focus is replaced by its zeroed version (@'runZF' za@) and the
-- resulting @b@ goes to @v@.
setVar_
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> IO (BVar s b)
setVar_ afa afb za l w v =
    forceBVar v `seq` forceBVar w `seq` insertNode node updated (reflect (Proxy @s))
  where
    updated :: b
    updated = set l (_bvVal w) (_bvVal v)
    node :: TapeNode b
    node = TN { _tnInputs = IR w (runAF afa) id
                         :& IR v (runAF afb) id
                         :& RNil
              , _tnGrad   = splitGrad
              }
    splitGrad :: b -> Rec Identity '[a, b]
    splitGrad d = Identity dw :& Identity dv :& RNil
      where
        (dw, dv) = l (\x -> (x, runZF za x)) d
{-# INLINE setVar_ #-}

-- | 'Numeric.Backprop.setVar', but with explicit 'add' and 'zero'.
--
-- A pure wrapper that runs 'setVar_' under 'unsafePerformIO' so that the
-- tape can be extended from pure code.
setVar
    :: forall a b s. Reifies s W
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> Lens' b a
    -> BVar s a
    -> BVar s b
    -> BVar s b
setVar afa afb za l !w !v = unsafePerformIO (setVar_ afa afb za l w v)
{-# INLINE setVar #-}

-- | 'Numeric.Backprop.sequenceVar', but with explicit 'add' and 'zero'.
--
-- Implemented by handing 'traverse' to 'traverseVar'' (run under
-- 'unsafePerformIO').  The 'ZeroFunc' for the whole container maps the
-- element-level 'ZeroFunc' over every element.
sequenceVar
    :: forall t a s. (Reifies s W, Traversable t)
    => AddFunc a
    -> ZeroFunc a
    -> BVar s (t a)
    -> t (BVar s a)
sequenceVar af z !v = unsafePerformIO (traverseVar' af zeroAll id traverse v)
  where
    zeroAll :: ZeroFunc (t a)
    zeroAll = ZF (fmap (runZF z))
{-# INLINE sequenceVar #-}

-- TODO: can add funcs and zeros be had from bvars and Functor instance?
-- | Gather a container of 'BVar's into a single 'BVar' of the container of
-- their values.  This records one tape node whose inputs are the elements,
-- in 'toList' order.
--
-- Backward pass: the container gradient is flattened with 'toList' and
-- paired with the elements by 'zipVecList'.  Any element that receives no
-- gradient value ('Nothing') falls back to 'zero' of its own value.
collectVar_
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> t (BVar s a)
    -> IO (BVar s (t a))
collectVar_ af z !vs = withVec (toList vs) $ \(elems :: VecT n (BVar s) a) -> do
    let node :: TapeNode (t a)
        node = TN
          { _tnInputs = vecToRec (vmap (\v -> IR v (runAF af) id) elems)
          , _tnGrad   = vecToRec
                      . zipVecList (\v -> Identity . fromMaybe (runZF z (_bvVal v))) elems
                      . toList
          }
    -- Make sure every element has been pushed onto the tape first.
    for_ vs (evaluate . forceBVar)
    insertNode node (fmap _bvVal vs) (reflect (Proxy :: Proxy s))
{-# INLINE collectVar_ #-}

-- | 'Numeric.Backprop.collectVar', but with explicit 'add' and 'zero'.
--
-- Pure wrapper around 'collectVar_'.  The container is forced to WHNF
-- before the node is created.
collectVar
    :: forall t a s. (Reifies s W, Foldable t, Functor t)
    => AddFunc a
    -> ZeroFunc a
    -> t (BVar s a)
    -> BVar s (t a)
collectVar af z !vs = unsafePerformIO (collectVar_ af z vs)
{-# INLINE collectVar #-}

-- | Shared worker for 'sequenceVar', 'previewVar' and 'toListOfVar'.
--
-- Extracts the targets of @v@'s value with @f@.  It then records one tape
-- node per target.  The node at position @i@ sends its gradient back into
-- position @i@ of the traversal @t@: it either adds into an existing
-- gradient for the whole structure, or writes into a 'zero' copy of it.
traverseVar'
    :: forall b a f s. (Reifies s W, Traversable f)
    => AddFunc a
    -> ZeroFunc b
    -> (b -> f a)
    -> Traversal' b a
    -> BVar s b
    -> IO (f (BVar s a))
traverseVar' af z f t v = forceBVar v `seq` itraverse mkElem (f whole)
  where
    whole :: b
    whole = _bvVal v
    -- Tape node for the target at position @i@, holding value @y@.
    mkElem :: Int -> a -> IO (BVar s a)
    mkElem i y = insertNode node y (reflect (Proxy @s))
      where
        focus :: Lens' b a
        focus = ixt t i
        node :: TapeNode a
        node = TN
          { _tnInputs = IR v (\d -> over focus (runAF af d))
                             (\d -> set focus d (runZF z whole))
                     :& RNil
          , _tnGrad   = \d -> Identity d :& RNil
          }
    {-# INLINE mkElem #-}
{-# INLINE traverseVar' #-}

-- | 'Numeric.Backprop.previewVar', but with explicit 'add' and 'zero'.
--
-- Yields a 'BVar' for the first target of the traversal, or 'Nothing' if
-- the traversal has no targets.
previewVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> Maybe (BVar s a)
previewVar af z t !v = unsafePerformIO (traverseVar' af z firstTarget t v)
  where
    firstTarget :: b -> Maybe a
    firstTarget = preview t
{-# INLINE previewVar #-}

-- | 'Numeric.Backprop.toListOfVar', but with explicit 'add' and 'zero'.
--
-- Yields one 'BVar' per target of the traversal, in traversal order.
toListOfVar
    :: forall b a s. Reifies s W
    => AddFunc a
    -> ZeroFunc b
    -> Traversal' b a
    -> BVar s b
    -> [BVar s a]
toListOfVar af z t !v = unsafePerformIO (traverseVar' af z targets t v)
  where
    targets :: b -> [a]
    targets = toListOf t
{-# INLINE toListOfVar #-}

-- | Coerce a 'BVar' contents.  Useful for things like newtype wrappers.
--
-- The result keeps the same tape reference as the input, so gradients
-- land in the same slot.  Only the stored value is 'coerce'd.
--
-- @since 0.1.5.2
coerceVar
    :: Coercible a b
    => BVar s a
    -> BVar s b
coerceVar v = case v of
    BV ref val -> forceBVar v `seq` BV ref (coerce val)

-- | Mutable state of one backward pass.
--
-- '_rDelta' holds one gradient accumulator per tape node, and '_rInputs'
-- one per input.  Values are stored as @'Maybe' 'Any'@ and are coerced
-- back to their real types by the code that reads them.
data Runner s = R
    { _rDelta  :: !(MV.MVector s (Maybe Any))
    , _rInputs :: !(MV.MVector s (Maybe Any))
    }

-- | Allocate the gradient slots for a backward pass.
--
-- There is one slot per tape node, each set to 'Nothing' at that node's
-- own value type.  There is also one slot per input, filled from the
-- supplied list.
initRunner
    :: (Int, [SomeTapeNode])
    -> (Int, [Maybe Any])
    -> ST s (Runner s)
initRunner (nNodes, nodes) (nInps, inpSlots) = do
    deltas <- MV.new nNodes
    -- The tape list is newest-first, so its head occupies index nNodes - 1.
    zipWithM_ (\i (STN (TN{} :: TapeNode c)) ->
                  MV.write deltas i (unsafeCoerce (Nothing @c)))
              [nNodes - 1, nNodes - 2 ..] nodes
    inps <- MV.new nInps
    zipWithM_ (MV.write inps) [0 ..] inpSlots
    pure (R deltas inps)
{-# INLINE initRunner #-}

-- | Run the backward pass over the tape.
--
-- The newest node (index @n - 1@) is seeded with the given gradient.
-- Nodes are then visited from newest to oldest.  Each visit pushes the
-- node's accumulated gradient into the slots of its inputs.
gradRunner
    :: forall b s. ()
    => b                        -- ^ one
    -> Runner s
    -> (Int, [SomeTapeNode])
    -> ST s ()
gradRunner seed R{..} (n, stns) = do
    when (n > 0) $
      MV.write _rDelta (n - 1) (unsafeCoerce (Just seed))
    zipWithM_ visit [n - 1, n - 2 ..] stns
  where
    -- A node whose slot is still 'Nothing' received no gradient and is skipped.
    visit :: Int -> SomeTapeNode -> ST s ()
    visit i (STN (TN{..} :: TapeNode c)) =
        MV.read _rDelta i >>= traverse_ (\d ->
          rzipWithM_ propagate _tnInputs (_tnGrad (unsafeCoerce d)))
    {-# INLINE visit #-}
    -- Add one contribution into the slot named by the input's reference.
    -- Constants ('BRC') have no slot and absorb it.
    propagate :: forall x. InpRef x -> Identity x -> ST s ()
    propagate (IR v (+*) e) (Identity d) = case _bvRef v of
        BRInp i -> MV.modify _rInputs bump i
        BRIx  i -> MV.modify _rDelta  bump i
        BRC     -> pure ()
      where
        bump :: Maybe Any -> Maybe Any
        bump = unsafeCoerce . bumpMaybe d (+*) e . unsafeCoerce
    {-# INLINE propagate #-}
{-# INLINE gradRunner #-}

-- | Accumulate a gradient contribution into an optional slot.
--
-- An empty slot becomes the embedded value.  A filled slot has the value
-- added into it.  The input 'Maybe' is inspected only when the result is
-- demanded.
bumpMaybe
    :: a                -- ^ val
    -> (a -> b -> b)    -- ^ add
    -> (a -> b)         -- ^ embed
    -> Maybe b
    -> Maybe b
bumpMaybe x (+*) e = maybe (Just (e x)) (Just . (x +*))
{-# INLINE bumpMaybe #-}

-- | Force the result of 'fillWengert' before it is used, returning the
-- argument unchanged.
--
-- In the 'Left' case the input index is forced.  In the 'Right' case the
-- node count is forced, followed by every tape node.
seqEither :: Either a (b, [SomeTapeNode]) -> Either a (b, [SomeTapeNode])
seqEither e@(Left !_)          = e
seqEither e@(Right (!_, stns)) = forceAll stns `seq` e
  where
    -- Walk the whole list, forcing each node in turn.  The previous
    -- @foldMap forceSomeTapeNode@ did not force anything.  @()@'s '<>'
    -- (and 'mconcat') ignore their arguments, so forcing the fold's
    -- result never evaluated the nodes.
    forceAll :: [SomeTapeNode] -> ()
    forceAll = foldr (\stn rest -> forceSomeTapeNode stn `seq` rest) ()
{-# INLINE seqEither #-}

-- | 'Numeric.Backprop.backpropWithN', but with explicit 'zero' and 'one'.
--
-- Note that argument order changed in v0.2.4.
--
-- @since 0.2.0.0
backpropWithN
    :: forall as b. ()
    => Rec ZeroFunc as
    -> (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> (b, b -> Rec Identity as)
backpropWithN zfs f !xs = (y, either setInput runBackward tp0)
  where
    -- The forward pass runs as soon as the result pair is demanded (outer
    -- bang).  The inner bangs force the tape summary and the output value.
    !(seqEither->(!tp0), !y) = unsafePerformIO $ fillWengert f xs

    -- General case: replay the tape backwards from the given output gradient.
    runBackward :: (Int, [SomeTapeNode]) -> b -> Rec Identity as
    runBackward tp o = runST $ do
        let (Sum nInps, Endo prependSlots) = VR.rfoldMap emptySlot xs
        r <- initRunner tp (nInps, prependSlots [])
        gradRunner o r tp
        delts <- toList <$> V.freeze (_rInputs r)
        -- Inputs that received no gradient get 'zero' of their value.
        pure . fromMaybe (internalError "backpropN") $
          fillRec (\z -> maybe z (Identity . unsafeCoerce))
                  (VR.rzipWith (fmap . runZF) zfs xs)
                  delts
      where
        emptySlot :: forall a. Identity a -> (Sum Int, Endo [Maybe Any])
        emptySlot _ = (Sum 1, Endo (unsafeCoerce (Nothing @a) :))
        {-# INLINE emptySlot #-}

    -- Shortcut when the output is literally input @i@.  That input gets
    -- the output gradient and every other input gets 'zero'.
    setInput :: Int -> b -> Rec Identity as
    setInput !i !x = go zfs xs 0
      where
        go :: Rec ZeroFunc bs -> Rec Identity bs -> Int -> Rec Identity bs
        go RNil      _         _  = RNil
        go (z :& zs) (q :& qs) !j
          | j == i    = Identity (unsafeCoerce x) :& VR.rzipWith coerce zs qs
          | otherwise = coerce z q :& go zs qs (j + 1)
    {-# INLINE setInput #-}
{-# INLINE backpropWithN #-}

-- | 'evalBP' generalized to multiple inputs of different types.  See
-- documentation for 'Numeric.Backprop.backpropN' for more details.
--
-- Runs the forward pass only and keeps just the output value.  The
-- recorded tape is discarded.
evalBPN
    :: forall as b. ()
    => (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> b
evalBPN f xs = snd (unsafePerformIO (fillWengert f xs))
{-# INLINE evalBPN #-}

-- | Run the forward pass and describe what the backward pass needs.
--
-- The user's function is applied to fresh input 'BVar's, recording every
-- operation on the Wengert tape.  The output value is returned alongside
-- one of:
--
-- * @'Left' i@: the output is literally the @i@th input, so no tape is
--   needed.
--
-- * @'Right' (n, nodes)@: @nodes@ is the tape (newest first) restricted to
--   the output node and everything recorded before it, and @n@ is its
--   length.  'gradRunner' seeds index @n - 1@, so the tape is trimmed
--   until that index is the output's own.  A constant output gets an
--   empty tape.
fillWengert
    :: forall as b. ()
    => (forall s. Reifies s W => Rec (BVar s) as -> BVar s b)
    -> Rec Identity as
    -> IO (Either Int (Int, [SomeTapeNode]), b)
fillWengert f xs = do
    w <- initWengert
    (outRef, o) <- reify w $ \(Proxy :: Proxy s) -> do
      let oVar = f (inpRec @s)
      evaluate (forceBVar oVar)
      -- 'BRef' mentions @s@, so summarize it as plain data before
      -- leaving 'reify'.
      let ref = case _bvRef oVar of
            BRInp i -> Left i
            BRIx  j -> Right (Just j)
            BRC     -> Right Nothing
      pure (ref, _bvVal oVar)
    t <- case outRef of
      Left i         -> pure (Left i)
      Right Nothing  ->
        -- Constant output: no gradient can reach any input.  An empty tape
        -- stops the seed from landing on an unrelated node.  Such nodes
        -- can exist, e.g. ones created by a comparison that chose the
        -- constant.
        pure (Right (0, []))
      Right (Just j) -> do
        (n, stns) <- readIORef (wRef w)
        -- Nodes recorded after the output cannot feed into it.  They may
        -- have been forced by 'seq' or comparisons without being returned.
        -- Drop them so the seed at index (new length - 1) is the output
        -- node's own slot, rather than a newer node of a possibly
        -- different type.
        pure (Right (j + 1, drop (n - 1 - j) stns))
    pure (t, o)
  where
    -- Inputs are numbered 0, 1, ... in record order.
    inpRec :: forall s. Rec (BVar s) as
    inpRec = evalState (rtraverse (state . go . runIdentity) xs) 0
      where
        go :: a -> Int -> (BVar s a, Int)
        go x i = (BV (BRInp i) x, i + 1)
        {-# INLINE go #-}
    {-# INLINE inpRec #-}
{-# INLINE fillWengert #-}


-- | Arithmetic on 'BVar's.
--
-- Each operation is lifted with 'liftOp1' / 'liftOp2', using 'afNum' to
-- accumulate gradients.  'fromInteger' produces a constant via 'constVar'.
instance (Num a, Reifies s W) => Num (BVar s a) where
    x + y = liftOp2 afNum afNum (+.) x y
    {-# INLINE (+) #-}
    x - y = liftOp2 afNum afNum (-.) x y
    {-# INLINE (-) #-}
    x * y = liftOp2 afNum afNum (*.) x y
    {-# INLINE (*) #-}
    negate x = liftOp1 afNum negateOp x
    {-# INLINE negate #-}
    signum x = liftOp1 afNum signumOp x
    {-# INLINE signum #-}
    abs x = liftOp1 afNum absOp x
    {-# INLINE abs #-}
    fromInteger n = constVar (fromInteger n)
    {-# INLINE fromInteger #-}

-- | Division on 'BVar's, lifted like the 'Num' instance.
--
-- 'fromRational' produces a constant via 'constVar'.
instance (Fractional a, Reifies s W) => Fractional (BVar s a) where
    x / y = liftOp2 afNum afNum (/.) x y
    {-# INLINE (/) #-}
    recip x = liftOp1 afNum recipOp x
    {-# INLINE recip #-}
    fromRational r = constVar (fromRational r)
    {-# INLINE fromRational #-}

-- | Transcendental functions on 'BVar's.
--
-- Each function is lifted with the matching 'Op' and 'afNum'.  'pi' is a
-- constant.
instance (Floating a, Reifies s W) => Floating (BVar s a) where
    pi = constVar pi
    {-# INLINE pi #-}
    exp x = liftOp1 afNum expOp x
    {-# INLINE exp #-}
    log x = liftOp1 afNum logOp x
    {-# INLINE log #-}
    sqrt x = liftOp1 afNum sqrtOp x
    {-# INLINE sqrt #-}
    x ** y = liftOp2 afNum afNum (**.) x y
    {-# INLINE (**) #-}
    logBase b x = liftOp2 afNum afNum logBaseOp b x
    {-# INLINE logBase #-}
    sin x = liftOp1 afNum sinOp x
    {-# INLINE sin #-}
    cos x = liftOp1 afNum cosOp x
    {-# INLINE cos #-}
    tan x = liftOp1 afNum tanOp x
    {-# INLINE tan #-}
    asin x = liftOp1 afNum asinOp x
    {-# INLINE asin #-}
    acos x = liftOp1 afNum acosOp x
    {-# INLINE acos #-}
    atan x = liftOp1 afNum atanOp x
    {-# INLINE atan #-}
    sinh x = liftOp1 afNum sinhOp x
    {-# INLINE sinh #-}
    cosh x = liftOp1 afNum coshOp x
    {-# INLINE cosh #-}
    tanh x = liftOp1 afNum tanhOp x
    {-# INLINE tanh #-}
    asinh x = liftOp1 afNum asinhOp x
    {-# INLINE asinh #-}
    acosh x = liftOp1 afNum acoshOp x
    {-# INLINE acosh #-}
    atanh x = liftOp1 afNum atanhOp x
    {-# INLINE atanh #-}

-- | Compares the values inside the 'BVar'.
--
-- Tape references are ignored, and no node is recorded.
--
-- @since 0.1.5.0
instance Eq a => Eq (BVar s a) where
    x == y = _bvVal x == _bvVal y
    x /= y = _bvVal x /= _bvVal y

-- | Compares the values inside the 'BVar'.
--
-- Tape references are ignored, and no node is recorded.
--
-- @since 0.1.5.0
instance Ord a => Ord (BVar s a) where
    compare x y = compare (_bvVal x) (_bvVal y)
    x <  y = _bvVal x <  _bvVal y
    x <= y = _bvVal x <= _bvVal y
    x >  y = _bvVal x >  _bvVal y
    x >= y = _bvVal x >= _bvVal y

-- Some utility functions to get around a lens dependency

-- | 'traverse' that passes each element's zero-based position to the
-- action.  Effects run in the container's traversal order.
itraverse
    :: forall t a b f. (Traversable t, Monad f)
    => (Int -> a -> f b) -> t a -> f (t b)
itraverse f xs = evalStateT (traverse step xs) 0
  where
    -- Run the action at the current index, then advance the counter.
    step :: a -> StateT Int f b
    step x = StateT $ \i -> (, i + 1) <$> f i x
{-# INLINE itraverse #-}

-- | Lens onto the element at a zero-based index of a list.
--
-- Calls 'internalError' if the index does not name an element.  That
-- includes negative indices, which count down until the list runs out.
ixi :: Int -> Lens' [a] a
ixi i f = \case
    []     -> internalError "ixi"
    y : ys
      | i == 0    -> (: ys) <$> f y
      | otherwise -> (y :) <$> ixi (i - 1) f ys
{-# INLINE ixi #-}

-- | Turn a traversal into a lens onto its @i@th target (zero-based).
--
-- The targets are listed and the @i@th is focused with 'ixi'.  The
-- possibly-updated list is then written back into the structure in
-- traversal order.
ixt :: forall b a. Traversal' b a -> Int -> Lens' b a
ixt t i f xs = refill <$> ixi i f (xs ^.. t)
  where
    -- Replace the targets of @xs@, in order, with the given list.
    refill :: [a] -> b
    refill = evalState (traverseOf t (\_ -> state pop) xs)
    pop :: [a] -> (a, [a])
    pop = \case
      []     -> internalError "ixt"
      y : ys -> (y, ys)
{-# INLINE ixt #-}

-- | A 'BVar' over a 'Backprop' type is itself 'Backprop'.
--
-- Each method lifts the corresponding method on @a@ into the graph with
-- 'liftOp1' / 'liftOp2'. 'addFunc' is passed as the 'AddFunc' for every
-- input.
--
-- @since 0.2.2.0
instance (Backprop a, Reifies s W) => Backprop (BVar s a) where
    -- Forward: 'zero' of the input. Backward: 'zero' of the incoming
    -- gradient.
    zero = liftOp1 addFunc . op1 $ \x -> (zero x, zero)
    {-# INLINE zero #-}
    -- Forward: 'add' of the two inputs. Backward: the incoming gradient is
    -- passed unchanged to both inputs.
    add  = liftOp2 addFunc addFunc . op2 $ \x y ->
        ( add x y
        , \d -> (d, d)
        )
    {-# INLINE add #-}
    -- Forward: 'one' of the input. Backward: 'zero' of the incoming
    -- gradient.
    one  = liftOp1 addFunc . op1 $ \x -> (one x, zero)
    {-# INLINE one #-}

-- | The canonical 'ZeroFunc' for instances of 'Backprop'. It wraps the
-- class method 'zero' in 'ZF'.
--
-- @since 0.2.0.0
zeroFunc :: Backprop a => ZeroFunc a
zeroFunc = ZF zero
{-# INLINE zeroFunc #-}

-- | The canonical 'AddFunc' for instances of 'Backprop'. It wraps the
-- class method 'add' in 'AF'.
--
-- @since 0.2.0.0
addFunc :: Backprop a => AddFunc a
addFunc = AF add
{-# INLINE addFunc #-}

-- | The canonical 'OneFunc' for instances of 'Backprop'. It wraps the
-- class method 'one' in 'OF'.
--
-- @since 0.2.0.0
oneFunc :: Backprop a => OneFunc a
oneFunc = OF one
{-# INLINE oneFunc #-}

-- | Abort with an error message tagged with this module's name and the
-- name @m@ of the failing helper. In this part of the module it is raised
-- by 'ixi' and 'ixt' on an out-of-range index or a target-count mismatch.
--
-- Uses 'errorWithoutStackTrace', so no call stack is attached to the
-- resulting 'ErrorCall'.
internalError :: String -> a
internalError m = errorWithoutStackTrace $
    "Numeric.Backprop.Internal." ++ m
      ++ ": unexpected shape involved in gradient computation"