module TypedFlow.Layers.Core
(
DenseP(..), dense, (#),
DropProb(..), mkDropout, mkDropouts,
EmbeddingP(..), embedding,
ConvP(..), conv, maxPool2D)
where
import Prelude hiding (tanh,Num(..),Floating(..),floor)
import qualified Prelude
import GHC.TypeLits
import TypedFlow.TF
import TypedFlow.Types
import Control.Monad.State (gets)
import Data.Monoid ((<>))
-- | Parameters of a fully-connected (dense) layer: a weight matrix
-- of shape @[a,b]@ together with a bias vector of shape @[b]@.
data DenseP t a b = DenseP
  { denseWeights :: Tensor '[a,b] (Flt t)
  , denseBiases  :: Tensor '[b] (Flt t)
  }
newtype EmbeddingP numObjects embeddingSize t = EmbeddingP (Tensor '[numObjects, embeddingSize] ('Typ 'Float t))
-- | Traverse the single tensor stored inside an 'EmbeddingP'.
instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => KnownTensors (EmbeddingP numObjects embeddingSize b) where
  travTensor f s (EmbeddingP table) = fmap EmbeddingP (travTensor f s table)
-- | Default embedding initialiser: entries drawn uniformly from the
-- symmetric interval [-0.05, 0.05].
instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize b) where
  -- BUG FIX: the bounds were (0.05) 0.05 — a degenerate (constant)
  -- distribution.  The stray parentheses around the lower bound
  -- indicate a dropped minus sign; a symmetric range around zero is
  -- the conventional embedding initialisation.
  defaultInitializer = EmbeddingP (randomUniform (-0.05) 0.05)
-- | Look up embedding vectors for a batch of object indices.  The
-- parameter matrix is transposed so that 'gather' returns, for each
-- index in the batch, a slice of size @embeddingSize@, yielding the
-- result shape @[embeddingSize,batchSize]@ stated in the signature.
embedding :: ∀ embeddingSize numObjects batchSize t.
             EmbeddingP numObjects embeddingSize t -> Tensor '[batchSize] Int32 -> Tensor '[embeddingSize,batchSize] ('Typ 'Float t)
embedding (EmbeddingP param) input = gather @ '[embeddingSize] (transpose param) input
-- | Traverse the weight matrix and the bias vector, suffixing the
-- supplied name with "_w" and "_bias" respectively.
instance (KnownNat a, KnownNat b, KnownBits t) => KnownTensors (DenseP t a b) where
  travTensor f s (DenseP w b') = DenseP <$> weights <*> biases
    where weights = travTensor f (s <> "_w") w
          biases  = travTensor f (s <> "_bias") b'
-- | Default dense-layer initialiser: Glorot (Xavier) uniform weights
-- and truncated-normal biases (parameter 0.1 — presumably the
-- standard deviation; confirm in TypedFlow.TF).
instance (KnownNat n, KnownNat m, KnownBits b) => ParamWithDefault (DenseP b n m) where
  defaultInitializer = DenseP glorotUniform (truncatedNormal 0.1)
-- | Apply a dense layer to a (column-major) batch of input vectors:
-- multiply by the weight matrix and add the bias.  '(#)' is the
-- operator form; 'dense' is its prefix synonym.
(#), dense :: ∀m n batchSize t. DenseP t n m -> Tensor '[n, batchSize] (Flt t) -> Tensor '[m, batchSize] (Flt t)
(DenseP w b) # x = (w ∙ x) + b
dense = (#)
data DropProb = DropProb Float
-- | Generate a dropout function for tensors of shape @s@, using
-- \"inverted dropout\": units are zeroed with probability
-- @dropProb@ during training and survivors are scaled by
-- @1/keepProb@, so no rescaling is needed at inference time.
mkDropout :: forall s t. KnownShape s => KnownBits t => DropProb -> Gen (Tensor s ('Typ 'Float t) -> Tensor s ('Typ 'Float t))
mkDropout (DropProb dropProb) = do
  let keepProb = 1.0 Prelude.- dropProb
  -- The mask depends on the training-mode placeholder: at inference
  -- time it is all ones, i.e. dropout becomes the identity.
  isTraining <- gets genTrainingPlaceholder
  -- floor of Uniform[keepProb, 1+keepProb) is 1 with probability
  -- keepProb and 0 otherwise, giving a Bernoulli(keepProb) mask;
  -- dividing (⊘, presumably elementwise) by keepProb performs the
  -- inverted-dropout rescaling.
  mask <- assign (if_ isTraining
                      (floor (randomUniform keepProb (1 Prelude.+ keepProb)) ⊘ constant keepProb)
                      ones)
  -- Return the section that multiplies (⊙) an input elementwise by
  -- the (shared, assigned-once) mask.
  return (mask ⊙)
newtype EndoTensor t s = EndoTensor (Tensor s t -> Tensor s t)
-- | Generate a dropout function for a heterogeneous vector of tensors
-- (all with the same float representation).  Each component gets an
-- independently generated dropout mask with the same drop
-- probability.
mkDropouts :: KnownBits t => KnownLen shapes => All KnownShape shapes => DropProb -> Gen (HTV ('Typ 'Float t) shapes -> HTV ('Typ 'Float t) shapes)
mkDropouts d = appEndoTensor <$> mkDropouts' shapeSList where
  -- Build one dropout endomorphism per shape by structural recursion
  -- on the singleton shape list.
  mkDropouts' :: forall shapes t. KnownBits t => All KnownShape shapes =>
                 SList shapes -> Gen (NP (EndoTensor ('Typ 'Float t)) shapes)
  mkDropouts' LZ = return Unit
  mkDropouts' (LS _ rest) = do
    x <- mkDropout d
    xs <- mkDropouts' rest
    return (EndoTensor x :* xs)
  -- Zip the list of endomorphisms with the heterogeneous tensor
  -- vector, applying each function to its matching component.
  appEndoTensor :: NP (EndoTensor t) s -> HTV t s -> HTV t s
  appEndoTensor Unit Unit = Unit
  appEndoTensor (EndoTensor f :* fs) (F x :* xs) = F (f x) :* appEndoTensor fs xs
-- | Parameters of a convolution layer: a filter bank of shape
-- @[outChannels,inChannels] ++ filterSpatialShape@ and one bias per
-- output channel.
data ConvP t outChannels inChannels filterSpatialShape
  = ConvP (T ('[outChannels,inChannels] ++ filterSpatialShape) ('Typ 'Float t)) (T '[outChannels] ('Typ 'Float t))
-- | Default convolution initialiser: Glorot-uniform filters and
-- constant 0.1 biases.  The filters are generated as a 2-D matrix
-- (so 'glorotUniform' can see fan-in/fan-out) and then reshaped and
-- transposed into the filter-bank shape.  'prodHomo' and
-- 'knownAppend' appear to be proof combinators discharging the
-- type-level size equalities the reshape needs — confirm in
-- TypedFlow.Types.
instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownBits t) =>
         ParamWithDefault (ConvP t outChannels inChannels filterSpatialShape) where
  defaultInitializer = prodHomo @filterSpatialShape @'[outChannels] $
                       knownAppend @filterSpatialShape @'[outChannels] $
                       ConvP (transposeN' (reshape i)) (constant 0.1)
    -- i is the flat [inChannels, Product spatial * outChannels]
    -- matrix from which the filter bank is carved.
    where i :: T '[inChannels,Product filterSpatialShape* outChannels] (Flt t)
          i = knownProduct @filterSpatialShape glorotUniform
-- | Traverse the filter bank and the biases, suffixing the supplied
-- name with "_filters" and "_biases" respectively.
instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownBits t) =>
         KnownTensors (ConvP t outChannels inChannels filterSpatialShape) where
  travTensor f s (ConvP filters biases) = ConvP <$> filters' <*> biases'
    where filters' = travTensor f (s <> "_filters") filters
          biases'  = travTensor f (s <> "_biases") biases
-- | Apply a convolution layer: convolve the input with the filter
-- bank and add the per-output-channel bias.  The constraints require
-- the input to have exactly one more dimension (the channel axis)
-- than the filter's spatial shape, which may be at most rank 3.
conv :: forall outChannels filterSpatialShape inChannels s t.
        ((1 + Length filterSpatialShape) ~ Length s,
         Length filterSpatialShape <= 3,
         KnownLen filterSpatialShape) =>
        ConvP t outChannels inChannels filterSpatialShape ->
        T ('[inChannels] ++ s) ('Typ 'Float t) -> (T ('[outChannels] ++ s) ('Typ 'Float t))
conv (ConvP kernel biases) x = convolution x kernel + biases
-- | 2-D max pooling with window size equal to the strides and SAME
-- padding, emitted as a raw call to tf.nn.max_pool (positional
-- arguments: value, ksize, strides, then named padding).
--
-- NOTE(review): the input type scales BOTH spatial dimensions by
-- @stridex@ (@width*stridex@, @height*stridex@) while the emitted
-- strides are @[1,stridex,stridey,1]@ — @stridey@ never appears in
-- the type, which looks like a typo for @height*stridey@.  The type
-- and the runtime behaviour only agree when stridex == stridey.
-- TODO confirm against callers before changing the signature.
maxPool2D :: forall stridex (stridey::Nat) batch height width channels t.
             (KnownNat stridex, KnownNat stridey) =>
             T '[channels,width*stridex,height*stridex,batch] (Flt t) -> T '[channels,width,height,batch] (Flt t)
maxPool2D (T value) = T (funcall "tf.nn.max_pool" [value
                                                  -- ksize: pool window equals the stride, so windows don't overlap
                                                  ,showShape @'[1,stridex,stridey,1]
                                                  -- strides
                                                  ,showShape @'[1,stridex,stridey,1]
                                                  ,named "padding" (str "SAME") ])