module NN.DSL(module NN.DSL) where
import Gen.Caffe.AccuracyParameter as AP
import Gen.Caffe.ConvolutionParameter as CP
import Gen.Caffe.DataParameter as DP
import Gen.Caffe.DropoutParameter as DP
import Gen.Caffe.FillerParameter as FP
import Gen.Caffe.InnerProductParameter as IP
import Gen.Caffe.LayerParameter as LP
import Gen.Caffe.LRNParameter as LRN
import Gen.Caffe.NetStateRule as NS
import Gen.Caffe.ParamSpec as PS
import Gen.Caffe.Phase as P
import Gen.Caffe.PoolingParameter as PP
import Gen.Caffe.PoolingParameter.PoolMethod as PP
import Gen.Caffe.TransformationParameter as TP
import Control.Lens
import Data.Maybe
import Data.Sequence
import Text.ProtocolBuffers as P
import NN.Graph
-- | A network: the 'Gr' type instantiated with 'LayerParameter' and @()@.
-- NOTE(review): presumably 'LayerParameter' is the node label and @()@ the
-- edge label -- confirm against the definition of 'Gr'.
type Net = Gr LayerParameter ()
-- | Same shape as @Gr LayerParameter ()@, except that every 'LayerParameter'
-- is paired with an extra value of type @a@.
type AnnotatedNet a = Gr (LayerParameter, a) ()
-- | The 'G' type instantiated with 'LayerParameter' and @()@.
-- NOTE(review): presumably the graph-construction monad -- confirm against
-- the definition of 'G'.
type NetBuilder = G LayerParameter ()
-- | The layer kinds this DSL knows how to name (see 'asCaffe').
--
-- The constructor order determines the derived 'Enum' instance, so new
-- constructors should be appended rather than inserted.
data LayerTy
  = Data
  | Pool
  | Concat
  | Conv
  | IP
  | LRN
  | ReLU
  | Dropout
  | Accuracy
  | SoftmaxWithLoss
  deriving (Eq, Show, Enum)
-- | The string written into a layer's @type@ field for a given 'LayerTy'
-- (see 'ty', which stores the result in @LP._type'@).
asCaffe :: LayerTy -> String
asCaffe layer = case layer of
  Data            -> "Data"
  Pool            -> "Pooling"
  Concat          -> "Concat"
  Conv            -> "Convolution"
  IP              -> "InnerProduct"
  LRN             -> "LRN"
  ReLU            -> "ReLU"
  Dropout         -> "Dropout"
  Accuracy        -> "Accuracy"
  SoftmaxWithLoss -> "SoftmaxWithLoss"
-- | Parse a Caffe layer type name back into a 'LayerTy'.
--
-- Despite its name this is the inverse of 'asCaffe': it returns @Just t@
-- exactly when @asCaffe t@ equals the given string, and 'Nothing' for any
-- other input.
--
-- The result is derived from 'asCaffe' rather than from a second
-- hand-written table, so the two directions cannot drift apart when a
-- constructor is added to 'LayerTy'.
toCaffe :: String -> Maybe LayerTy
toCaffe caffeName =
  -- @[Data ..]@ enumerates every constructor because 'Data' is the first
  -- constructor of 'LayerTy' and the 'Enum' instance is derived.
  listToMaybe [candidate | candidate <- [Data ..], asCaffe candidate == caffeName]
-- | Short alias for 'P.fromString'; the builders below apply it to 'String'
-- values before storing them in message fields.
s = P.fromString
-- | Short alias for 'P.defaultValue': the default value of any type that has
-- a 'Default' instance.
def :: Default a => a
def = P.defaultValue
-- | Setter that stores the Caffe name of the given 'LayerTy' (as produced by
-- 'asCaffe') in the @LP._type'@ field.
ty layerKind = set LP._type' (Just (s (asCaffe layerKind)))
-- | Recover the 'LayerTy' of a layer from its @type@ field.
--
-- This function is partial, as its signature requires: it calls 'error' when
-- the field is unset or when the stored name is not recognised by 'toCaffe'.
-- The messages name the failing condition (and the offending type string)
-- instead of the anonymous crash a bare 'fromJust' produces.
layerTy :: LayerParameter -> LayerTy
layerTy l =
  case LP.type' l of
    Nothing ->
      error "NN.DSL.layerTy: layer has no type set"
    Just rawTy ->
      let caffeName = toString rawTy
      in case toCaffe caffeName of
           Just known -> known
           Nothing ->
             error ("NN.DSL.layerTy: unsupported Caffe layer type " ++ show caffeName)
-- | Append to the layer's @include@ field a default rule whose @phase@ field
-- is set to the given phase. Existing entries are kept; the new rule goes
-- last.
phase' targetPhase =
  over LP._include (`mappend` singleton (set _phase (Just targetPhase) def))
-- | Replace the @_param@ field with the given list (converted with
-- 'fromList').
param' specs = set _param (fromList specs)
-- | Set an optional field of an optional sub-message.
--
-- @outer@ focuses the sub-message, @f@ the field inside it, and @n@ is the
-- new value, stored as @Just n@. Because the path goes through '_Just', the
-- update does nothing when the sub-message itself is 'Nothing'.
setF outer f n = outer . _Just . f .~ Just n
-- Setters for the data and transform sub-messages. All of them go through
-- 'setF', so they only take effect when the corresponding sub-message is
-- already present on the layer.

-- | Set @DP._source@ inside @_data_param@ from a 'String'.
source' path = setF _data_param DP._source $ s path

-- | Set @TP._crop_size@ inside @_transform_param@.
cropSize' size = setF _transform_param TP._crop_size size

-- | Set @TP._mean_file@ inside @_transform_param@ from a 'String'.
meanFile' path = setF _transform_param TP._mean_file $ s path

-- | Set @TP._mirror@ inside @_transform_param@.
mirror' flag = setF _transform_param TP._mirror flag

-- | Set @DP._batch_size@ inside @_data_param@.
batchSize' size = setF _data_param DP._batch_size size

-- | Set @DP._backend@ inside @_data_param@.
backend' db = setF _data_param DP._backend db
-- | 'setF' specialised to the @_convolution_param@ sub-message; a no-op
-- unless that sub-message is already present on the layer.
setConv field value = setF _convolution_param field value

-- | Set @CP._num_output@.
numOutputC' n = setConv CP._num_output n

-- | Set @CP._kernel_size@.
kernelSizeC' k = setConv CP._kernel_size k

-- | Set @CP._pad@.
padC' p = setConv CP._pad p

-- | Set @CP._group@.
groupC' g = setConv CP._group g

-- | Set @CP._stride@.
strideC' st = setConv CP._stride st

-- | Set @CP._bias_filler@.
biasFillerC' filler = setConv CP._bias_filler filler

-- | Set @CP._weight_filler@.
weightFillerC' filler = setConv CP._weight_filler filler
-- | 'setF' specialised to the @_pooling_param@ sub-message; a no-op unless
-- that sub-message is already present on the layer.
setPool field value = setF _pooling_param field value

-- | Set @PP._pool@.
pool' method = setPool PP._pool method

-- | Set @PP._kernel_size@.
sizeP' k = setPool PP._kernel_size k

-- | Set @PP._stride@.
strideP' st = setPool PP._stride st

-- | Set @PP._pad@.
padP' p = setPool PP._pad p
-- | 'setF' specialised to the @_inner_product_param@ sub-message; a no-op
-- unless that sub-message is already present on the layer.
setIP field value = setF _inner_product_param field value

-- | Set @IP._weight_filler@.
weightFillerIP' filler = setIP IP._weight_filler filler

-- | Set @IP._num_output@.
numOutputIP' n = setIP IP._num_output n

-- | Set @IP._bias_filler@.
biasFillerIP' filler = setIP IP._bias_filler filler
-- | 'setF' specialised to the @_lrn_param@ sub-message; a no-op unless that
-- sub-message is already present on the layer.
setLRN field value = setF _lrn_param field value

-- | Set @LRN._local_size@.
localSize' n = setLRN LRN._local_size n

-- | Set @LRN._alpha@.
alphaLRN' a = setLRN LRN._alpha a

-- | Set @LRN._beta@.
betaLRN' b = setLRN LRN._beta b
-- Filler builders: each starts from 'def', sets @FP._type'@ to a fixed name,
-- then sets one numeric field.

-- | Filler with type @"constant"@ and its @_value@ field set.
constant fill =
  set _value (Just fill) (set FP._type' (Just (s "constant")) def)

-- | Filler with type @"gaussian"@ and its @_std@ field set.
gaussian sigma =
  set _std (Just sigma) (set FP._type' (Just (s "gaussian")) def)

-- | Filler with type @"xavier"@ and its @_std@ field set.
-- NOTE(review): this populates @_std@ exactly like 'gaussian' -- confirm
-- that a xavier filler actually consumes that field.
xavier sigma =
  set _std (Just sigma) (set FP._type' (Just (s "xavier")) def)

-- | A 'constant' filler with value @0.0@.
zero = constant 0.0
-- | Set the @_lr_mult@ field to the given value.
lrMult' mult = set _lr_mult (Just mult)

-- | Set the @_decay_mult@ field to the given value.
decayMult' mult = set _decay_mult (Just mult)
-- Layer templates. Each starts from 'def', stamps the layer type with 'ty',
-- and, where a sub-message is needed, populates it so that the 'setF'-based
-- setters have something to update. Composition chains read right to left.

-- | Accuracy layer with a 'phase'' rule for @TEST@ and @AP._top_k@ set.
accuracy topK =
  (_accuracy_param ?~ set AP._top_k (Just topK) def) . phase' TEST . ty Accuracy $ def

-- | SoftmaxWithLoss layer with no further fields set.
softmax = ty SoftmaxWithLoss def

-- | Dropout layer with @_dropout_ratio@ set.
dropout ratio =
  (_dropout_param ?~ set _dropout_ratio (Just ratio) def) . ty Dropout $ def

-- | ReLU layer with no further fields set.
relu = ty ReLU def

-- | Convolution layer with a default @_convolution_param@.
conv = (_convolution_param ?~ def) . ty Conv $ def

-- | InnerProduct layer with a default @_inner_product_param@ whose
-- @num_output@ is then set via 'numOutputIP''.
ip n = numOutputIP' n . (_inner_product_param ?~ def) . ty IP $ def

-- | Data layer with default @_transform_param@ and @_data_param@.
data' = (_data_param ?~ def) . (_transform_param ?~ def) . ty Data $ def

-- | Pooling layer with a default @_pooling_param@ and pool method @MAX@.
maxPool = pool' MAX . (_pooling_param ?~ def) . ty Pool $ def

-- | Pooling layer with a default @_pooling_param@ and pool method @AVE@.
avgPool = pool' AVE . (_pooling_param ?~ def) . ty Pool $ def

-- | LRN layer with a default @_lrn_param@.
lrn = (_lrn_param ?~ def) . ty LRN $ def

-- | Concat layer with no further fields set.
concat' = ty Concat def