module Grenade.Layers.Merge (
Merge (..)
) where
import Data.Serialize
import Data.Singletons
import Grenade.Core
-- | A layer built from two sub-layers of types @x@ and @y@.
--
-- The 'Layer' instance in this module runs both sub-layers on the same
-- input and adds their outputs together.
data Merge x y = Merge x y
-- | Renders the literal header @Merge@ followed by each sub-layer's own
-- 'show' output, one per line (no trailing newline).
instance (Show x, Show y) => Show (Merge x y) where
  show (Merge first second) =
    concat
      [ "Merge\n"
      , show first
      , "\n"
      , show second
      ]
-- | A 'Merge' is updatable whenever both of its sub-layers are.
instance (UpdateLayer x, UpdateLayer y) => UpdateLayer (Merge x y) where
  -- | The gradient is simply the pair of the sub-layers' gradients.
  type Gradient (Merge x y) = (Gradient x, Gradient y)

  -- Each sub-layer is updated independently with its own half of the
  -- gradient pair; both receive the same first argument.
  runUpdate opts (Merge layerX layerY) (gradX, gradY) =
    Merge updatedX updatedY
    where
      updatedX = runUpdate opts layerX gradX
      updatedY = runUpdate opts layerY gradY

  -- Builds the @x@ sub-layer first, then the @y@ sub-layer.
  createRandom =
    Merge
      <$> createRandom
      <*> createRandom
-- | Both sub-layers must map the same input shape @i@ to the same output
-- shape @o@; the merged layer then has that same signature.
instance (SingI i, SingI o, Layer x i o, Layer y i o) => Layer (Merge x y) i o where
  -- | The tape is the pair of the sub-layers' tapes.
  type Tape (Merge x y) i o = (Tape x i o, Tape y i o)

  -- Run both sub-layers on the same input and add their outputs.
  runForwards (Merge layerX layerY) input = ((tapeX, tapeY), outX + outY)
    where
      (tapeX, outX) = runForwards layerX input
      (tapeY, outY) = runForwards layerY input

  -- Hand the same incoming value to both sub-layers (each with its own
  -- tape), pair up their gradients, and add the values they pass back.
  runBackwards (Merge layerX layerY) (tapeX, tapeY) incoming =
    ((gradX, gradY), backX + backY)
    where
      (gradX, backX) = runBackwards layerX tapeX incoming
      (gradY, backY) = runBackwards layerY tapeY incoming
-- | Serialises the first sub-layer followed by the second, and reads them
-- back in that same order.
instance (Serialize a, Serialize b) => Serialize (Merge a b) where
  put (Merge first second) = do
    put first
    put second

  get = do
    first <- get
    second <- get
    return (Merge first second)