module NN.Backend.Torch.Codegen where
import Control.Applicative
import Control.Lens hiding (assign)
import Control.Monad.State.Strict
import Gen.Caffe.LayerParameter as LP
import Language.Lua.PrettyPrinter
import Language.Lua.Syntax
import Text.Printf
import NN.Backend.Torch.Lua
import NN.Backend.Torch.Torch
data TorchState = TorchState {
_statements :: [Stat],
_sequential :: Maybe String,
_criteria :: [String],
_count :: Int
}
makeLenses ''TorchState
newtype Torch a = Torch { _unTorch :: State TorchState a }
deriving (Functor, Applicative, Monad, MonadState TorchState)
initialize :: Torch ()
initialize = do
seq' <- fresh "seq"
sequential ?= seq'
statements <>= [require "nn"]
statements <>= [assign seq' $ torchExp (TorchModule "nn" "Sequential" [])]
where
require module' = funCall "require" [toLua $ L module']
fresh :: String -> Torch String
fresh prefix = do
c <- use count
count += 1
return $ printf "%s%d" prefix c
insertModule :: Module Exp -> Torch ()
insertModule (Criterion exp') = do
name' <- fresh "criterion"
criteria <>= [name']
statements <>= [assign name' exp']
insertModule (Inner exp') = do
Just seq' <- use sequential
statements <>= [methCall seq' "add" [exp']]
finalize :: Torch Block
finalize = do
Just seq' <- use sequential
criteria' <- use criteria
statements' <- use statements
return $ Block statements' (Just $ return' <$> seq':criteria')
runTorch :: [LayerParameter] -> Torch Block
runTorch layers = do
initialize
forM_ exps insertModule
finalize
where
exps = concatMap torchExps layers
torchExps lp = (torchExp <$>) <$> torchModules lp
lower :: [LayerParameter] -> Block
lower layers = (evalState . _unTorch) (runTorch layers) emptyTorch
where
emptyTorch = TorchState [] Nothing [] 0
codegen :: Block -> String
codegen block = pprint block & renderPretty 0.4 150 & displayS & \f -> f ""