module Numeric.Neural.Model
( ParamFun(..)
, Component(..)
, weightsLens
, activate
, Model(..)
, model
, modelR
, modelError
, descent
, StdModel
, mkStdModel
) where
import Control.Arrow
import Control.Category
import Data.Profunctor
import Data.MyPrelude
import Prelude hiding (id, (.))
import Data.Utils.Analytic
import Data.Utils.Arrow
import Data.Utils.Statistics (mean)
import Data.Utils.Traversable
-- | A function from @a@ to @b@ that additionally reads a container @t@ of
-- 'Analytic' parameters.  'runPF' takes the input first and the parameter
-- container second.
newtype ParamFun t a b = ParamFun
    { runPF :: a -> t Analytic -> b
    }
-- | The identity ignores the parameters.  Composition runs the right-hand
-- function first and feeds its result to the left-hand one; both see the
-- very same parameter container.
instance Category (ParamFun t) where
    id = ParamFun const
    ParamFun outer . ParamFun inner = ParamFun (\x ps -> outer (inner x ps) ps)
-- | Lifting a plain function ignores the parameters; 'first' runs the
-- wrapped function on the first component of a pair (strict in the pair).
instance Arrow (ParamFun t) where
    arr f = ParamFun (const . f)
    first (ParamFun g) = ParamFun $ \p ps -> case p of
        (x, y) -> (g x ps, y)
-- | 'left' applies the wrapped function to 'Left' values and passes
-- 'Right' values through unchanged.
instance ArrowChoice (ParamFun t) where
    left (ParamFun g) = ParamFun $ \e ps -> either (Left . flip g ps) Right e
-- | Maps the wrapped function over every element of a functorial input,
-- sharing the same parameters for each element.
instance ArrowConvolve (ParamFun t) where
    convolve (ParamFun g) = ParamFun $ \xs ps -> fmap (\x -> g x ps) xs
-- Functor / Applicative / Profunctor structure for 'ParamFun', delegated to
-- the generic fmapArr / pureArr / apArr / dimapArr helpers.
instance Functor (ParamFun t a) where
    fmap = fmapArr

instance Applicative (ParamFun t a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor (ParamFun t) where
    dimap = dimapArr
-- | A parameterised computation from @a@ to @b@ bundled with its current
-- weights.  The shape @t@ of the weight container is hidden
-- (existentially quantified); it only has to be 'Applicative' and
-- 'Traversable'.
data Component a b = forall t . (Applicative t, Traversable t) => Component
    { weights :: t Double
      -- current weight values
    , compute :: ParamFun t a b
      -- the computation; it receives the weights lifted to 'Analytic'
    , initR   :: forall m . MonadRandom m => m (t Double)
      -- random generator for a fresh set of weights
    }
-- | Views a component's weights as a flat list of 'Double's (in 'toList'
-- order).  Setting rebuilds the weight container from the list with
-- 'fromList'.
--
-- If 'fromList' cannot rebuild the container (presumably because the list
-- has the wrong number of elements — TODO confirm against 'fromList'), the
-- new weights are an explicit 'error' naming this lens.  As before, that
-- error fires lazily, i.e. only when the weights are first forced.
weightsLens :: Lens' (Component a b) [Double]
weightsLens = lens getWeights setWeights
  where
    getWeights :: Component a b -> [Double]
    getWeights (Component ws _ _) = toList ws

    setWeights :: Component a b -> [Double] -> Component a b
    setWeights (Component _ c i) ws =
        -- Bind lazily (no strict case on the whole Component) so the
        -- strictness matches the previous irrefutable-pattern version.
        let ws' = case fromList ws of
                Just rebuilt -> rebuilt
                Nothing      -> error "Numeric.Neural.Model.weightsLens: fromList could not rebuild the weights from the given list"
        in  Component ws' c i
-- | Runs a component on an input, using its current weights converted to
-- 'Analytic' with 'fromDouble'.
activate :: Component a b -> a -> b
activate (Component ws f _) x = runPF f x lifted
  where
    lifted = fmap fromDouble ws
-- | A container holding no elements at all; used as the weight container
-- of weight-free components (see 'arr' for 'Component').
data Empty a = Empty
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance Applicative Empty where
    pure _ = Empty
    -- Matches both arguments, so it is strict in both.
    Empty <*> Empty = Empty
-- | Product of two containers.  Used to hold the weights of both halves of
-- a composed 'Component'.
data Pair s t a = Pair (s a) (t a)
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

-- | Component-wise 'Applicative' structure.
instance (Applicative s, Applicative t) => Applicative (Pair s t) where
    pure x = Pair (pure x) (pure x)
    Pair sf tf <*> Pair sx tx = Pair (sf <*> sx) (tf <*> tx)
-- | Composition stores the weights as a 'Pair' (outer component first,
-- inner second).  It runs the inner computation on its half of the
-- weights, then the outer one on the other half, and initialises both
-- halves.
instance Category Component where
    id = arr id
    Component wsOuter cOuter iOuter . Component wsInner cInner iInner =
        Component
            { weights = Pair wsOuter wsInner
            , compute = ParamFun $ \x (Pair psOuter psInner) ->
                  runPF cOuter (runPF cInner x psInner) psOuter
            , initR   = Pair <$> iOuter <*> iInner
            }
-- | A lifted plain function has no weights ('Empty').  'first' keeps the
-- weights and initialiser and lifts only the computation.
instance Arrow Component where
    arr f = Component Empty (arr f) (pure Empty)
    first (Component ws c i) = Component ws (first c) i
-- | Keeps the weights and initialiser; only the computation is lifted.
instance ArrowChoice Component where
    left (Component ws c i) = Component { weights = ws, compute = left c, initR = i }
-- | Keeps the weights and initialiser; only the computation is convolved,
-- so every element shares the same weights.
instance ArrowConvolve Component where
    convolve (Component ws c i) = Component { weights = ws, compute = convolve c, initR = i }
-- Functor / Applicative / Profunctor structure for 'Component', delegated to
-- the generic fmapArr / pureArr / apArr / dimapArr helpers.
instance Functor (Component a) where
    fmap = fmapArr

instance Applicative (Component a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor Component where
    dimap = dimapArr
-- | A trainable model.  It consists of:
--
-- * a 'Component' from @f Analytic@ to @g Analytic@;
-- * a function splitting a training sample of type @a@ into a component
--   input and an error function on the component's output;
-- * an input encoder from @b@ to @f Double@;
-- * an output decoder from @g Double@ to @c@.
data Model (f :: * -> *) (g :: * -> *) a b c where
    Model :: (Functor f, Functor g)
          => Component (f Analytic) (g Analytic)
          -> (a -> (f Double, g Analytic -> Analytic))
          -> (b -> f Double)
          -> (g Double -> c)
          -> Model f g a b c
-- | Pre-composes the input encoder and post-composes the output decoder.
-- The component and the sample/error function are untouched.
instance Profunctor (Model f g a) where
    dimap pre post (Model c e i o) =
        Model c e (\x -> i (pre x)) (\y -> post (o y))
-- | Runs a model on an input.  The input is encoded, lifted to 'Analytic',
-- passed through the component, converted back to 'Double' and decoded.
--
-- Note: uses 'fromJust' on the result of 'fromAnalytic', so it fails if
-- that conversion yields 'Nothing'.
model :: Model f g a b c -> b -> c
model (Model c _ i o) =
    activate (arr i >>> arr (fmap fromDouble) >>> c >>> arr (o . fmap (fromJust . fromAnalytic)))
-- | Re-initialises a model's weights using the component's random
-- initialiser.  Everything else in the model is kept.
modelR :: MonadRandom m => Model f g a b c -> m (Model f g a b c)
modelR (Model (Component _ f r) e i o) = do
    fresh <- r
    return (Model (Component fresh f r) e i o)
-- | Builds the mean error over a collection of samples, as a function of
-- the 'Analytic' parameters of @f@.
--
-- Each sample is split by @e@ into an input and an error function.  The
-- input is lifted to 'Analytic' and fed through @f@, and the error
-- function is applied to the output.  The per-sample errors are combined
-- with 'mean'.
errFun :: (Functor f, Foldable h, Traversable t)
       => (a -> (f Double, g Analytic -> Analytic))
       -> h a
       -> ParamFun t (f Analytic) (g Analytic)
       -> (t Analytic -> Analytic)
errFun e xs f = runPF overAll xs
  where
    overAll = arr toList >>> convolve perSample >>> arr mean
    perSample = ParamFun $ \x ps ->
        let (input, errOf) = e x
        in  errOf (runPF f (fromDouble <$> input) ps)
-- | Mean error of a model's current weights over a collection of samples.
--
-- Note: uses 'fromJust' on the result of 'fromAnalytic', so it fails if
-- that conversion yields 'Nothing'.
modelError :: Foldable h => Model f g a b c -> h a -> Double
modelError (Model (Component ws f _) e _ _) xs =
    fromJust (fromAnalytic (errFun e xs f (fromDouble <$> ws)))
-- | Performs one gradient-descent step on a model's weights.
--
-- The error over the samples @xs@ is built with 'errFun' and differentiated
-- with 'gradient'.  Each weight @w@ with partial derivative @dw@ is
-- replaced by @w - eta * dw@, i.e. it moves against the gradient with
-- learning rate @eta@.
--
-- Returns the error reported by 'gradient' (presumably the error at the
-- weights before the step — TODO confirm against 'gradient') together with
-- the model carrying the updated weights.
descent :: (Foldable h)
        => Model f g a b c
        -> Double
        -> h a
        -> (Double, Model f g a b c)
descent (Model c e i o) eta xs = case c of
    Component ws f r ->
        let f'         = errFun e xs f
            -- Subtract the scaled gradient: the previous `w eta * dw`
            -- had lost its minus sign and applied the Double `w` to `eta`.
            step w dw  = w - eta * dw
            (err, ws') = gradient step f' ws
        in  (err, Model (Component ws' f r) e i o)
-- | A model whose training samples are @(input, target)@ pairs.
type StdModel f g b c = Model f g (b, c) b c

-- | Builds a 'StdModel' from a component, an error function comparing a
-- target with the component's output, an input encoder and an output
-- decoder.  A training sample @(x, y)@ is split into the encoded input
-- @x@ and the error function for target @y@.
mkStdModel :: (Functor f, Functor g)
           => Component (f Analytic) (g Analytic)
           -> (c -> g Analytic -> Analytic)
           -> (b -> f Double)
           -> (g Double -> c)
           -> StdModel f g b c
mkStdModel c e i o = Model c (\(x, y) -> (i x, e y)) i o