module NN.Examples.MLPSweep where import Control.Applicative import Control.Concurrent import Control.Lens import Control.Monad import Data.Function import Data.List import Data.Word import GHC.IO.Handle import System.Exit import System.IO.Temp import System.Process import Text.Read import NN.Backend.Torch as Torch import NN.DSL import NN.Graph import NN.Passes -- A simple example of performing a parameter sweep over the number of -- hidden units in an MLP. parameterSweepMLP :: Int -> IO ([Word32], Maybe Float) parameterSweepMLP numWorkers = maximumBy (compare `on` snd) <$> parMapIO numWorkers candidates assess where mlp hiddenUnits = do _ <- sequential (concatMap (\n -> [ip n, relu]) hiddenUnits ++ [softmax]) return () candidates = [[i, j, k] | let xs = [10..15], i <- xs, j <- xs, k <- xs] assess experiment = do let Just torchCode = mlp experiment & parse & Torch.backend (file, handle) <- openTempFile "/tmp" "mlp.lua" hPutStr handle torchCode hClose handle (rc, stdout, _) <- readProcessWithExitCode "NN/Examples/scripts/run_mlp.lua" [file] "" return $ case rc of ExitSuccess -> readMaybe stdout _ -> Nothing parMapIO :: Int -> [a] -> (a -> IO b) -> IO [(a, b)] parMapIO n xs f = do jobs <- newChan results <- newChan forM_ [1..n] $ \_ -> forkIO $ worker jobs results forM_ xs (writeChan jobs) forM xs $ \_ -> readChan results where worker jobs results = forever $ do job <- readChan jobs result <- f job writeChan results (job, result)