{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module NN.DSL(module NN.DSL) where

import           Gen.Caffe.AccuracyParameter           as AP
import           Gen.Caffe.ConvolutionParameter        as CP
import           Gen.Caffe.DataParameter               as DP
import           Gen.Caffe.DropoutParameter            as DP
import           Gen.Caffe.FillerParameter             as FP
import           Gen.Caffe.InnerProductParameter       as IP
import           Gen.Caffe.LayerParameter              as LP
import           Gen.Caffe.LRNParameter                as LRN
import           Gen.Caffe.NetStateRule                as NS
import           Gen.Caffe.ParamSpec                   as PS
import           Gen.Caffe.Phase                       as P
import           Gen.Caffe.PoolingParameter            as PP
import           Gen.Caffe.PoolingParameter.PoolMethod as PP
import           Gen.Caffe.TransformationParameter     as TP

import           Control.Lens
import           Data.Maybe
import           Data.Sequence
import           Text.ProtocolBuffers                  as P

import           NN.Graph

-- | A network: a graph (the 'Gr' type from "NN.Graph") labelled with Caffe
-- layers. NOTE(review): presumably node label = 'LayerParameter' and edge
-- label = @()@ — confirm against NN.Graph.
type Net = Gr LayerParameter ()
-- | Like 'Net', but every layer is paired with an extra annotation of type @a@.
type AnnotatedNet a = Gr (LayerParameter, a) ()
-- | The 'G' type from "NN.Graph" specialised to Caffe layers.
-- NOTE(review): presumably the builder that produces a 'Net' — confirm.
type NetBuilder = G LayerParameter ()

-- | The Caffe layer kinds this DSL knows how to build.
--
-- 'Enum' and 'Bounded' are derived so that every constructor can be
-- enumerated with @[minBound .. maxBound]@; 'toCaffe' relies on this.
data LayerTy = Data
             | Pool
             | Concat
             | Conv
             | IP
             | LRN
             | ReLU
             | Dropout
             | Accuracy
             | SoftmaxWithLoss
               deriving (Show, Eq, Enum, Bounded)

-- | The layer-type string Caffe expects for each 'LayerTy'.
--
-- Implemented by hand (rather than via 'show') so that GHC's exhaustiveness
-- checking covers every constructor, and so Caffe's idiosyncratic names
-- (e.g. @Pool@ -> @"Pooling"@, @IP@ -> @"InnerProduct"@) are explicit.
asCaffe :: LayerTy -> String
asCaffe Data = "Data"
asCaffe Concat = "Concat"
asCaffe Pool = "Pooling"
asCaffe Conv = "Convolution"
asCaffe IP = "InnerProduct"
asCaffe LRN = "LRN"
asCaffe ReLU = "ReLU"
asCaffe Dropout = "Dropout"
asCaffe Accuracy = "Accuracy"
asCaffe SoftmaxWithLoss = "SoftmaxWithLoss"

-- | Inverse of 'asCaffe': parse a Caffe layer-type string.
--
-- Defined by enumerating every constructor and keeping the one whose
-- 'asCaffe' name matches, so it can never drift out of sync with 'asCaffe'.
-- A separate hand-written table would need a catch-all case, which hides
-- newly added constructors from the exhaustiveness checker.
--
-- Returns 'Nothing' for any string 'asCaffe' does not produce.
toCaffe :: String -> Maybe LayerTy
toCaffe str = listToMaybe [ t | t <- [minBound .. maxBound], asCaffe t == str ]

-- | Shorthand for 'P.fromString': turns a Haskell 'String' literal into the
-- string type stored in the generated protobuf fields.
-- NOTE(review): presumably 'Utf8' from Text.ProtocolBuffers — confirm.
-- Deliberately left point-free and unsignatured; its type is fixed by its
-- uses in this module.
s = P.fromString

-- | The 'Default' value of a generated message type (via 'P.defaultValue');
-- the starting point every layer/parameter constructor below builds on.
def :: Default a => a
def = P.defaultValue

-- | Setter that stores the Caffe name of the given 'LayerTy' (see 'asCaffe')
-- in the layer's @type@ field.
ty layerKind = set LP._type' (Just (s (asCaffe layerKind)))

-- | Recover the 'LayerTy' of a layer from its Caffe @type@ field.
--
-- Partial: calls 'error' when the layer has no @type@ set, or when the
-- type string is not one 'toCaffe' recognises. The messages name the
-- failing case, instead of the opaque failure a bare 'fromJust' gives.
layerTy :: LayerParameter -> LayerTy
layerTy l =
  case LP.type' l of
    Nothing -> error "NN.DSL.layerTy: layer has no type field set"
    Just rawTy ->
      let str = toString rawTy
      in fromMaybe
           (error ("NN.DSL.layerTy: unrecognised Caffe layer type " ++ show str))
           (toCaffe str)

-- | Append one @include@ rule restricting the layer to the given phase.
-- Rules already present in @include@ are kept; the new rule goes last.
phase' wanted = LP._include %~ (`mappend` singleton rule)
  where
    rule = def & _phase ?~ wanted

-- | Replace the layer's @param@ entries with the given list, in order.
param' specs = set _param (fromList specs)

-- Data
-- | @setF outer f n@ sets field @f@ of the optional sub-message reached via
-- @outer@ to @Just n@.
--
-- Because the path goes through '_Just', this is a no-op when the
-- sub-message is still 'Nothing'; constructors such as 'data'' initialise
-- the sub-messages with 'def' first.
setF outer f n = (outer . _Just . f) .~ Just n

-- | Set @data_param.source@ (the argument is converted with 's').
source' path = setF _data_param DP._source (s path)
-- | Set @transform_param.crop_size@.
cropSize' n = setF _transform_param TP._crop_size n
-- | Set @transform_param.mean_file@ (the argument is converted with 's').
meanFile' path = setF _transform_param TP._mean_file (s path)
-- | Set @transform_param.mirror@.
mirror' flag = setF _transform_param TP._mirror flag
-- | Set @data_param.batch_size@.
batchSize' n = setF _data_param DP._batch_size n
-- | Set @data_param.backend@.
backend' b = setF _data_param DP._backend b

-- Convolution
-- | Set one field of the layer's @convolution_param@ via 'setF'
-- (no effect if @convolution_param@ is unset).
setConv field newVal = setF _convolution_param field newVal
numOutputC' n = setConv CP._num_output n
kernelSizeC' n = setConv CP._kernel_size n
padC' n = setConv CP._pad n
groupC' n = setConv CP._group n
strideC' n = setConv CP._stride n
biasFillerC' filler = setConv CP._bias_filler filler
weightFillerC' filler = setConv CP._weight_filler filler

-- Pooling
-- | Set one field of the layer's @pooling_param@ via 'setF'
-- (no effect if @pooling_param@ is unset).
setPool field newVal = setF _pooling_param field newVal
pool' method = setPool PP._pool method
sizeP' n = setPool PP._kernel_size n
strideP' n = setPool PP._stride n
padP' n = setPool PP._pad n

-- Inner Product
-- | Set one field of the layer's @inner_product_param@ via 'setF'
-- (no effect if @inner_product_param@ is unset).
setIP field newVal = setF _inner_product_param field newVal
weightFillerIP' filler = setIP IP._weight_filler filler
numOutputIP' n = setIP IP._num_output n
biasFillerIP' filler = setIP IP._bias_filler filler

-- LRN
-- | Set one field of the layer's @lrn_param@ via 'setF'
-- (no effect if @lrn_param@ is unset).
setLRN field newVal = setF _lrn_param field newVal
localSize' n = setLRN LRN._local_size n
alphaLRN' a = setLRN LRN._alpha a
betaLRN' b = setLRN LRN._beta b

-- Fillers
-- | Filler of type @"constant"@ with its @value@ field set to the argument.
constant v = (_value ?~ v) . (FP._type' ?~ s "constant") $ def
-- | Filler of type @"gaussian"@ with its @std@ field set to the argument.
gaussian sd = (_std ?~ sd) . (FP._type' ?~ s "gaussian") $ def
-- | Filler of type @"xavier"@ with its @std@ field set to the argument.
-- NOTE(review): confirm Caffe's xavier filler actually reads @std@.
xavier sd = (_std ?~ sd) . (FP._type' ?~ s "xavier") $ def
-- | Constant filler with value 0.
zero = constant 0.0

-- Multipliers
-- | Setter for the @lr_mult@ field.
lrMult' m = set _lr_mult (Just m)
-- | Setter for the @decay_mult@ field.
decayMult' m = set _decay_mult (Just m)

-- Simple Layers
-- | Accuracy layer with @accuracy_param.top_k@ set to @k'@, included only
-- in the TEST phase.
accuracy k' = (_accuracy_param ?~ topK) . phase' TEST . ty Accuracy $ def
  where
    topK = def & AP._top_k ?~ k'

-- | SoftmaxWithLoss layer with no parameter sub-message.
softmax = ty SoftmaxWithLoss def

-- | Dropout layer with @dropout_param.dropout_ratio@ set to @ratio@.
dropout ratio = (_dropout_param ?~ params) . ty Dropout $ def
  where
    params = def & _dropout_ratio ?~ ratio

-- | ReLU layer with no parameter sub-message.
relu = ty ReLU def

-- | Convolution layer with a default @convolution_param@ (so the @*C'@
-- setters take effect).
conv = (_convolution_param ?~ def) . ty Conv $ def

-- | InnerProduct layer with a default @inner_product_param@ whose
-- @num_output@ is set to @n@.
ip n = numOutputIP' n . (_inner_product_param ?~ def) . ty IP $ def

-- | Data layer with default @transform_param@ and @data_param@.
data' = (_data_param ?~ def) . (_transform_param ?~ def) . ty Data $ def

-- | Pooling layer whose pool method is MAX.
maxPool = pool' MAX . (_pooling_param ?~ def) . ty Pool $ def

-- | Pooling layer whose pool method is AVE.
avgPool = pool' AVE . (_pooling_param ?~ def) . ty Pool $ def

-- | LRN layer with a default @lrn_param@.
lrn = (_lrn_param ?~ def) . ty LRN $ def

-- | Concat layer with no parameter sub-message.
concat' = ty Concat def