{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module TensorSafe.Layers.Conv2D where
import Data.Kind (Type)
import Data.Map
import Data.Proxy
import GHC.TypeLits
import TensorSafe.Compile.Expr
import TensorSafe.Layer
data Conv2D :: Nat -> Nat -> Nat -> Nat -> Nat -> Nat -> Type where
Conv2D :: Conv2D channels filters kernelRows kernelColumns strideRows strideColumns
deriving Show
instance ( KnownNat channels
, KnownNat filters
, KnownNat kernelRows
, KnownNat kernelColumns
, KnownNat strideRows
, KnownNat strideColumns
) => Layer (Conv2D channels filters kernelRows kernelColumns strideRows strideColumns) where
layer = Conv2D
compile _ inputShape =
let filters = natVal (Proxy :: Proxy filters)
kernelRows = natVal (Proxy :: Proxy kernelRows)
kernelColumns = natVal (Proxy :: Proxy kernelColumns)
strideRows = natVal (Proxy :: Proxy strideRows)
strideColumns = natVal (Proxy :: Proxy strideColumns)
initialParams = case inputShape of
Just shape -> fromList [("inputShape", shape)]
Nothing -> empty
params = union initialParams (fromList [
("kernelSize", show [kernelRows, kernelColumns]),
("filters", show filters),
("strides", show [strideRows, strideColumns])
])
in
CNLayer DConv2D params