module Numeric.SGD
( SgdArgs (..)
, sgdArgsDefault
, Para
, sgd
, module Numeric.SGD.Grad
, module Numeric.SGD.Dataset
) where
import Control.Monad (forM_)
import qualified System.Random as R
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Control.Monad.Primitive as Prim
import Numeric.SGD.Grad
import Numeric.SGD.Dataset
-- | Tuning parameters for 'sgd'.
data SgdArgs = SgdArgs
  { batchSize :: Int
    -- ^ Number of dataset elements sampled for each update step.
  , regVar :: Double
    -- ^ NOTE(review): presumably a regularisation variance, judging by the
    -- name, but 'sgd' in this module never reads it — confirm intent.
  , iterNum :: Double
    -- ^ Number of (possibly fractional) passes over the dataset to perform:
    -- 'sgd' stops once @steps * batchSize / size dataset@ exceeds this.
  , gain0 :: Double
    -- ^ Initial gain (step size), i.e. the gain used at the first step.
  , tau :: Double
    -- ^ Gain decay parameter: the gain is @gain0 * tau / (tau + passes)@,
    -- so it has halved once @tau@ passes over the dataset are complete.
  }

-- | Default 'SgdArgs': batches of 30 elements, 10 passes over the data,
-- an initial gain of 1 that halves after 5 passes, and 'regVar' set to 10.
sgdArgsDefault :: SgdArgs
sgdArgsDefault =
  SgdArgs
    { batchSize = 30
    , regVar = 10
    , iterNum = 10
    , gain0 = 1
    , tau = 5
    }
-- | Model parameters: an unboxed vector of 'Double's, taken as the initial
-- point and returned as the result by 'sgd'.
type Para = U.Vector Double
-- | Mutable counterpart of 'Para' (not exported), used for the in-place
-- parameter and gradient updates inside this module.
type MVect = UM.MVector (Prim.PrimState IO) Double
-- | Run stochastic gradient descent starting from the parameters @x0@.
--
-- Each step draws 'batchSize' elements from @dataset@ (via 'sample', with a
-- generator seeded deterministically by @'R.mkStdGen' 0@). It then sums the
-- per-element gradients that @mkGrad@ computes at the current parameters,
-- scales the sum by the current gain, and adds it to the parameters.
--
-- The gain at step @k@ is @gain0 * tau / (tau + done k)@, where
-- @done k = k * batchSize / size dataset@ is the number of (possibly
-- fractional) passes over the dataset completed so far. Iteration stops as
-- soon as @done k@ exceeds 'iterNum'.
--
-- @notify@ is called with the current parameters and the step number
-- before every step, and once more with the final parameters. The vector it
-- receives shares its buffer with the working parameters and is mutated in
-- place after @notify@ returns. @notify@ must therefore not retain it; it
-- should take a copy if the vector has to be kept.
--
-- Edge cases:
--
-- * A non-positive 'batchSize' raises an 'IOError': no progress could ever
--   be made, and the loop would otherwise never terminate.
-- * An empty @dataset@ yields @x0@ unchanged, after a single @notify x0 0@.
--
-- NOTE(review): 'regVar' is not read anywhere in this function.
sgd
  :: SgdArgs
  -> (Para -> Int -> IO ())
  -> (Para -> x -> Grad)
  -> Dataset x
  -> Para
  -> IO Para
sgd SgdArgs{..} notify mkGrad dataset x0
  -- With batchSize <= 0, 'done' never increases and the loop never ends.
  | batchSize <= 0 =
      ioError . userError $
        "Numeric.SGD.sgd: batchSize must be positive, got " ++ show batchSize
  -- With an empty dataset, 'done 0' is 0/0 = NaN. The gain would then be
  -- NaN and every parameter would be overwritten with NaN. There is nothing
  -- to learn from, so return the initial parameters unchanged.
  | size dataset <= 0 = do
      notify x0 0
      return x0
  | otherwise = do
      u <- UM.new (U.length x0)
      doIt u 0 (R.mkStdGen 0) =<< U.thaw x0
  where
    -- Step size at step @k@: equals 'gain0' at the start and decays as
    -- passes over the dataset accumulate; 'tau' controls the decay rate.
    gain :: Int -> Double
    gain k = (gain0 * tau) / (tau + done k)

    -- Passes over the dataset (possibly fractional) completed after @k@
    -- steps of 'batchSize' elements each.
    done :: Int -> Double
    done k
      = fromIntegral (k * batchSize)
      / fromIntegral (size dataset)

    -- Main loop. @u@ is a scratch buffer for the scaled gradient; @x@ holds
    -- the current parameters and is updated in place.
    doIt u k stdGen x
      | done k > iterNum = do
          frozen <- U.unsafeFreeze x
          notify frozen k
          return frozen
      | otherwise = do
          (batch, stdGen') <- sample stdGen batchSize dataset
          -- Zero-copy view of the current parameters. It is thawed again
          -- below and mutated by 'apply', so nothing may hold on to it.
          frozen <- U.unsafeFreeze x
          notify frozen k
          let grad = parUnions (map (mkGrad frozen) batch)
          -- 'addUp' writes each gradient value into the unboxed buffer,
          -- which forces it before 'apply' mutates the shared buffer.
          addUp grad u
          scale (gain k) u
          x' <- U.unsafeThaw frozen
          apply u x'
          doIt u (k + 1) stdGen' x'
-- | Overwrite @v@ with the sparse gradient @grad@. Every entry of @v@ is
-- first reset to zero; then each @(i, x)@ pair produced by 'toList' is
-- added at index @i@, so repeated indices accumulate.
--
-- The indices come from the user-supplied gradient function, so they are
-- bounds-checked. An index outside @v@ raises an error instead of silently
-- corrupting memory.
addUp :: Grad -> MVect -> IO ()
addUp grad v = do
  UM.set v 0
  forM_ (toList grad) $ \(i, x) -> do
    y <- UM.read v i
    UM.write v i (x + y)
-- | Multiply every element of @v@ by @c@, in place.
--
-- The unchecked reads and writes are safe here because the indices range
-- exactly over @v@'s own length.
scale :: Double -> MVect -> IO ()
scale c v =
  forM_ [0 .. UM.length v - 1] $ \i -> do
    y <- UM.unsafeRead v i
    UM.unsafeWrite v i (c * y)
-- | Add @w@ to @v@ element-wise, in place: @v[i] := v[i] + w[i]@ for every
-- index of @v@.
--
-- @w@ is read without bounds checks over the indices of @v@, so callers
-- must ensure @w@ is at least as long as @v@.
apply :: MVect -> MVect -> IO ()
apply w v =
  forM_ [0 .. UM.length v - 1] $ \i -> do
    x <- UM.unsafeRead v i
    y <- UM.unsafeRead w i
    UM.unsafeWrite v i (x + y)