{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
-- Since GHC 8.6 StarIsType is enabled by default, which makes a bare star in
-- a type mean the kind of ordinary types. The instances in this module use
-- the star from GHC.TypeLits as type-level multiplication (e.g. rows * m),
-- so StarIsType has to be switched off on those compilers.
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType          #-}
#endif
{-|
Module      : Grenade.Layers.Concat
Description : Concatenation layer
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module provides the concatenation layer, which runs two child layers in
parallel and combines their outputs.
-}
module Grenade.Layers.Concat (
    Concat (..)
  ) where

import           Data.Serialize

import           Data.Singletons
import           GHC.TypeLits

import           Grenade.Core

import           Numeric.LinearAlgebra.Static ( row, (===), splitRows, unrow, (#), split, R )

-- | A concatenating layer.
--
-- This layer shares its input between two sublayers, and concatenates their
-- outputs.
--
-- The type parameters are, in order: the output shape of the first layer,
-- the first layer, the output shape of the second layer, and the second
-- layer.
--
-- With Networks able to be Layers, this allows for very expressive
-- composition of complex Networks.
--
-- The Concat layer has a few instances, which allow one to flexibly "bash"
-- together the outputs.
--
-- Two 1D vectors can go to a 2D shape with 2 rows if their lengths are
-- identical. Any two 1D vectors can also become a longer 1D vector.
--
-- 3D images become 3D images with more channels. The sizes must be the same;
-- one can use Pad and Crop layers to ensure this is the case.
--
-- The kinds are given through annotated header variables rather than a
-- star-based kind signature, so the declaration is valid both with and
-- without StarIsType.
data Concat (m :: Shape) x (n :: Shape) y where
  -- | Pair the first child layer with the second child layer.
  Concat :: x -> y -> Concat m x n y

-- | Shows the word @Concat@ followed by each child on its own line.
instance (Show x, Show y) => Show (Concat m x n y) where
  show (Concat x y) = "Concat\n" ++ show x ++ "\n" ++ show y

-- | Run two layers in parallel, combining their outputs. Updating a 'Concat'
-- applies each component of the gradient pair to the matching child layer.
instance (UpdateLayer x, UpdateLayer y) => UpdateLayer (Concat m x n y) where
  -- One gradient per child, kept side by side in a tuple.
  type Gradient (Concat m x n y) = (Gradient x, Gradient y)

  -- Each child is updated independently with its own component of the gradient.
  runUpdate lp (Concat layerX layerY) (gradX, gradY) =
    Concat (runUpdate lp layerX gradX) (runUpdate lp layerY gradY)

  -- Both children are initialised independently: the first child, then the second.
  createRandom = pure Concat <*> createRandom <*> createRandom

-- | Two equal-length 1D outputs are stacked into a 2D output with two rows:
-- the first row is the output of @x@, the second row the output of @y@.
instance ( SingI i
         , Layer x i ('D1 o)
         , Layer y i ('D1 o)
         ) => Layer (Concat ('D1 o) x ('D1 o) y) i ('D2 2 o) where
  type Tape (Concat ('D1 o) x ('D1 o) y) i ('D2 2 o) = (Tape x i ('D1 o), Tape y i ('D1 o))

  runForwards (Concat layerX layerY) input =
    case (outX, outY) of
      (S1D vecX, S1D vecY) -> ((tapeX, tapeY), S2D (row vecX === row vecY))
    where
      -- Both children are run on the very same input.
      (tapeX, outX :: S ('D1 o)) = runForwards layerX input
      (tapeY, outY :: S ('D1 o)) = runForwards layerY input

  runBackwards (Concat layerX layerY) (tapeX, tapeY) (S2D dOut) =
    ((gradX, gradY), dInX + dInY)
    where
      -- The first row of the incoming gradient goes to x, the second to y.
      -- The input was shared, so the two input gradients are summed.
      (dRowX, dRowY)       = splitRows dOut
      (gradX, dInX :: S i) = runBackwards layerX tapeX (S1D $ unrow dRowX)
      (gradY, dInY :: S i) = runBackwards layerY tapeY (S1D $ unrow dRowY)

-- | Any two 1D outputs are joined end to end into one longer 1D output, with
-- the output of @x@ first.
instance ( SingI i
         , Layer x i ('D1 m)
         , Layer y i ('D1 n)
         , KnownNat o
         , KnownNat m
         , KnownNat n
         , o ~ (m + n)
         , n ~ (o - m)
         , (m <=? o) ~ 'True
         ) => Layer (Concat ('D1 m) x ('D1 n) y) i ('D1 o) where
  type Tape (Concat ('D1 m) x ('D1 n) y) i ('D1 o) = (Tape x i ('D1 m), Tape y i ('D1 n))

  runForwards (Concat layerX layerY) input =
    case (outX, outY) of
      (S1D vecX, S1D vecY) -> ((tapeX, tapeY), S1D (vecX # vecY))
    where
      (tapeX, outX :: S ('D1 m)) = runForwards layerX input
      (tapeY, outY :: S ('D1 n)) = runForwards layerY input

  runBackwards (Concat layerX layerY) (tapeX, tapeY) (S1D dOut) =
    ((gradX, gradY), dInX + dInY)
    where
      -- The leading m entries of the gradient belong to x, the remaining n to y.
      (dSegX :: R m, dSegY :: R n) = split dOut
      (gradX, dInX :: S i)         = runBackwards layerX tapeX (S1D dSegX)
      (gradY, dInY :: S i)         = runBackwards layerY tapeY (S1D dSegY)

-- | Concat 3D shapes, increasing the number of channels: the channels of @x@
-- come first, followed by the channels of @y@.
instance ( SingI i
         , Layer x i ('D3 rows cols m)
         , Layer y i ('D3 rows cols n)
         , KnownNat (rows * n)
         , KnownNat (rows * m)
         , KnownNat (rows * o)
         , KnownNat o
         , KnownNat m
         , KnownNat n
         , ((rows * m) + (rows * n)) ~ (rows * o)
         , ((rows * o) - (rows * m)) ~ (rows * n)
         , ((rows * m) <=? (rows * o)) ~ 'True
         ) => Layer (Concat ('D3 rows cols m) x ('D3 rows cols n) y) i ('D3 rows cols o) where
  type Tape (Concat ('D3 rows cols m) x ('D3 rows cols n) y) i ('D3 rows cols o) = (Tape x i ('D3 rows cols m), Tape y i ('D3 rows cols n))

  runForwards (Concat layerX layerY) input =
    case (outX, outY) of
      -- The two underlying matrices are stacked vertically.
      (S3D matX, S3D matY) -> ((tapeX, tapeY), S3D (matX === matY))
    where
      (tapeX, outX :: S ('D3 rows cols m)) = runForwards layerX input
      (tapeY, outY :: S ('D3 rows cols n)) = runForwards layerY input

  runBackwards (Concat layerX layerY) (tapeX, tapeY) (S3D dOut) =
    ((gradX, gradY), dInX + dInY)
    where
      -- splitRows hands the leading rows to x and the remaining rows to y.
      (dTopX, dBottomY)    = splitRows dOut
      (gradX, dInX :: S i) = runBackwards layerX tapeX (S3D dTopX :: S ('D3 rows cols m))
      (gradY, dInY :: S i) = runBackwards layerY tapeY (S3D dBottomY :: S ('D3 rows cols n))

-- Serialisation writes the first child, then the second, and reads them back
-- in the same order.
instance (Serialize a, Serialize b) => Serialize (Concat sa a sb b) where
  put (Concat a b) = do
    put a
    put b
  get = do
    a <- get
    b <- get
    return (Concat a b)