{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-|
Module      : Grenade.Layers.Merge
Description : Merging layer for parallel network composition
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Merge (
    Merge (..)
  ) where

import           Data.Serialize
import           Data.Singletons

import           Grenade.Core

-- | A Merging layer.
--
--   Similar to the Concat layer, except it sums the activations of its two
--   child layers instead of creating a larger shape.
data Merge :: * -> * -> * where
  Merge :: x -> y -> Merge x y

instance (Show x, Show y) => Show (Merge x y) where
  show (Merge x y) = "Merge\n" ++ show x ++ "\n" ++ show y

-- | Run two layers in parallel, combining their outputs.
--   The gradient is simply the pair of the child gradients, and an update
--   is applied to each child independently.
instance (UpdateLayer x, UpdateLayer y) => UpdateLayer (Merge x y) where
  type Gradient (Merge x y) = (Gradient x, Gradient y)

  runUpdate lr (Merge x y) (x', y') = Merge (runUpdate lr x x') (runUpdate lr y y')

  createRandom = Merge <$> createRandom <*> createRandom

-- | Run both layers over the same input and sum their outputs. On the
--   backwards pass, the incoming gradient flows to both children, and the
--   resulting input gradients are summed as well.
instance (SingI i, SingI o, Layer x i o, Layer y i o) => Layer (Merge x y) i o where
  type Tape (Merge x y) i o = (Tape x i o, Tape y i o)

  runForwards (Merge x y) input =
    let (xT, xOut) = runForwards x input
        (yT, yOut) = runForwards y input
    in  ((xT, yT), xOut + yOut)

  runBackwards (Merge x y) (xTape, yTape) o =
    let (x', xB) = runBackwards x xTape o
        (y', yB) = runBackwards y yTape o
    in  ((x', y'), xB + yB)

instance (Serialize a, Serialize b) => Serialize (Merge a b) where
  put (Merge a b) = put a *> put b
  get = Merge <$> get <*> get
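
-- A minimal usage sketch (an illustration, not part of the exported API):
-- because 'Merge' sums the two activations, merging a trainable layer with
-- 'Trivial' (the identity layer) gives a residual-style connection. The
-- concrete network type below is hypothetical; it assumes 'Network',
-- 'FullyConnected', 'Tanh' and 'Trivial' from the rest of Grenade.
--
-- > type Residual =
-- >   Network
-- >     '[ Merge Trivial (FullyConnected 20 20), Tanh ]
-- >     '[ 'D1 20, 'D1 20, 'D1 20 ]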