{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-| This module declares the 2D convolutional layer data type and its 'Layer' instance. -}
module TensorSafe.Layers.Conv2D where

import           Data.Kind               (Type)
import           Data.Map
import           Data.Proxy
import           GHC.TypeLits

import           TensorSafe.Compile.Expr
import           TensorSafe.Layer


-- | A 2D convolutional layer whose hyperparameters are carried entirely at
-- the type level. The single constructor has no fields, so every index is
-- phantom. In order, the indices are: channels, filters, kernel rows,
-- kernel columns, stride rows and stride columns.
data Conv2D (channels      :: Nat)
            (filters       :: Nat)
            (kernelRows    :: Nat)
            (kernelColumns :: Nat)
            (strideRows    :: Nat)
            (strideColumns :: Nat) :: Type where
    Conv2D :: Conv2D channels filters kernelRows kernelColumns strideRows strideColumns
    deriving Show

-- | Compiles a 'Conv2D' into a 'DConv2D' node. The type-level filter count,
-- kernel size and strides are reflected to runtime 'Integer's with 'natVal'
-- and rendered with 'show'. An @inputShape@ entry is included only when the
-- caller passes one in.
instance ( KnownNat channels
         , KnownNat filters
         , KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         ) => Layer (Conv2D channels filters kernelRows kernelColumns strideRows strideColumns) where
    layer = Conv2D

    compile _ inputShape = CNLayer DConv2D (shapeParam `union` convParams)
      where
        -- Empty map when no shape is supplied, otherwise a single entry.
        shapeParam = maybe empty (\shape -> fromList [("inputShape", shape)]) inputShape

        -- Reflect the type-level hyperparameters down to runtime values.
        -- The scoped type variables come from the instance head.
        kernelSize = [natVal (Proxy :: Proxy kernelRows), natVal (Proxy :: Proxy kernelColumns)]
        strides    = [natVal (Proxy :: Proxy strideRows), natVal (Proxy :: Proxy strideColumns)]
        nFilters   = natVal (Proxy :: Proxy filters)

        convParams = fromList
            [ ("kernelSize", show kernelSize)
            , ("filters",    show nFilters)
            , ("strides",    show strides)
            ]