{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-| This module declares the classic LSTM layer data type. -}
module TensorSafe.Layers.LSTM where

import           Data.Kind               (Type)
import           Data.Map
import           Data.Proxy
import           GHC.TypeLits

import           TensorSafe.Compile.Expr
import           TensorSafe.Layer


-- | An LSTM layer, indexed at the type level by its number of units and by
-- a @returnSequences@ flag. Both indices are phantom: the single 'LSTM'
-- constructor carries no runtime data.
data LSTM (units :: Nat) (returnSequences :: Bool) :: Type where
    LSTM :: LSTM units returnSequences
    deriving Show

-- | Reflects the type-level @returnSequences@ flag of an 'LSTM' to a
-- term-level 'Bool', in the same way 'natVal' reflects a type-level 'Nat'.
class KnownReturnSequences (flag :: Bool) where
    -- | The 'Bool' value denoted by @flag@.
    returnSequencesVal :: Proxy flag -> Bool

instance KnownReturnSequences 'True where
    returnSequencesVal _ = True

instance KnownReturnSequences 'False where
    returnSequencesVal _ = False

-- | Compiles an 'LSTM' into a 'CNLayer' tagged 'DLSTM' whose parameter map
-- holds the @"units"@ count and the @"returnSequences"@ flag, both rendered
-- with 'show'.
--
-- Both type indices are bound in the instance head so that
-- @ScopedTypeVariables@ brings them into scope in the method bodies.
-- Previously the flag was named @b@ in the head while the body mentioned an
-- unbound @returnSequences@, so the emitted value was always the literal
-- string @"Proxy"@ instead of the flag.
--
-- NOTE(review): the flag is emitted as @"True"@ / @"False"@ (Haskell
-- 'show'); confirm that the downstream code generators expect this casing.
instance (KnownNat units, KnownReturnSequences returnSequences)
        => Layer (LSTM units returnSequences) where
    layer = LSTM
    compile _ _ =
        let units = show $ natVal (Proxy :: Proxy units)
            -- Reflect the type-level flag to a value before rendering it;
            -- showing the Proxy itself would only ever yield "Proxy".
            returnSequences =
                show $ returnSequencesVal (Proxy :: Proxy returnSequences)
        in
            CNLayer DLSTM (fromList [
                ("units", units),
                ("returnSequences", returnSequences)
            ])