{-# OPTIONS_HADDOCK show-extensions #-}

{-# LANGUAGE Arrows #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-|
Module      : Numeric.Neural.Model
Description : "neural" components and models
Copyright   : (c) Lars Brünjes, 2016
License     : MIT
Maintainer  : brunjlar@gmail.com
Stability   : experimental
Portability : portable

This module defines /parameterized functions/, /components/ and /models/. 
The parameterized functions are instances of the 'Arrow' and 'ArrowChoice' typeclasses, whereas
'Component's behave like 'Arrow's with choice over a different base category 
(the category 'Diff' of differentiable functions).
Both parameterized functions and components can be combined easily and flexibly. 

/Models/ contain a component, can measure their error with regard to samples and can be trained by gradient descent/
backpropagation.
-}

module Numeric.Neural.Model
    ( ParamFun(..)
    , Component(..)
    , _weights
    , activate
    , _component
    , Pair(..)
    , FEither(..)
    , Convolve(..)
    , cArr
    , cFirst
    , cLeft
    , cConvolve
    , Model(..)
    , model
    , modelR
    , modelError
    , descent
    , StdModel
    , mkStdModel
    ) where

import Control.Applicative    
import Control.Arrow
import Control.Category
import Control.Monad.Par            (runPar)
import Control.Monad.Par.Combinator (parMapReduceRange, InclusiveRange(..))
import Data.Profunctor
import Data.MyPrelude
import Prelude                      hiding (id, (.))
import Data.Utils.Analytic
import Data.Utils.Arrow
import Data.Utils.Statistics        (mean)
import Data.Utils.Traversable

-- | The type @'ParamFun' s t a b@ describes parameterized functions from @a@ to @b@,
--   where the parameters are of type @t s@.
--   When such parameterized functions are composed (via the 'Category' and 'Arrow'
--   instances), they all share the /same/ parameters.
--
newtype ParamFun s t a b = ParamFun
    { runPF :: a -> t s -> b -- ^ applies the function to an input and a parameter collection
    }

instance Category (ParamFun s t) where

    -- The identity ignores the parameters and hands back its input.
    id = ParamFun const

    -- Both the inner and the outer function receive the very same parameters.
    (.) (ParamFun outer) (ParamFun inner) = ParamFun $ \x ps ->
        let y = inner x ps
        in  outer y ps

instance Arrow (ParamFun s t) where

    -- A lifted pure function simply ignores the parameters.
    arr g = ParamFun (const . g)

    -- Runs the wrapped function on the first half of the pair and
    -- passes the second half through unchanged.
    first (ParamFun g) = ParamFun step
      where
        step (x, y) ps = (g x ps, y)

instance ArrowChoice (ParamFun s t) where

    -- Only 'Left' inputs go through the parameterized function;
    -- 'Right' inputs are passed along untouched.
    left (ParamFun g) = ParamFun $ \e ps -> either (\x -> Left (g x ps)) Right e

instance ArrowConvolve (ParamFun s t) where

    -- Maps the parameterized function over every element of the container;
    -- all elements share one and the same parameter collection.
    convolve (ParamFun g) = ParamFun $ \xs ps -> fmap (\x -> g x ps) xs

-- The 'Functor', 'Applicative' and 'Profunctor' instances all delegate to the
-- generic arrow-based helpers 'fmapArr', 'pureArr', 'apArr' and 'dimapArr'.

instance Functor (ParamFun s t a) where
    fmap = fmapArr

instance Applicative (ParamFun s t a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor (ParamFun s t) where
    dimap = dimapArr

-- | A @'Component' f g@ is a parameterized differentiable function @f Double -> g Double@.
--   Unlike 'ParamFun', composing components does /not/ share parameters:
--   every component carries its own parameter collection, whose concrete
--   shape @w@ is hidden by existential quantification.
--
data Component f g = forall w. (Traversable w, Applicative w, NFData (w Double)) => Component
    { weights :: w Double
      -- ^ the current parameter values
    , compute :: forall x. Analytic x => ParamFun x w (f x) (g x)
      -- ^ the encapsulated parameterized function, usable at any 'Analytic' scalar type
    , initR   :: forall r. MonadRandom r => r (w Double)
      -- ^ produces a randomly initialized parameter collection
    }

-- | A 'Lens'' to get or set the weights of a component.
--   The shape of the parameter collection is hidden by existential quantification,
--   so this lens has to use simple generic lists.
--
--   When setting, the list must be convertible back into the component's
--   parameter collection (via 'fromList'), which presumably requires the right
--   number of elements. Otherwise a descriptive 'error' is raised as soon as
--   the updated component is demanded. Previously this was an anonymous,
--   deferred irrefutable-pattern failure.
--
_weights :: Lens' (Component f g) [Double]
_weights = lens (\(Component ws _ _)    -> toList ws)
                (\(Component _  c i) ws -> case fromList ws of
                    Just ws' -> Component ws' c i
                    Nothing  -> error "Numeric.Neural.Model._weights: could not rebuild the parameter collection from the given list (wrong number of weights?)")

-- | Runs a component on the given input, using the parameter values that are
--   currently stored in the component.
--
activate :: Component f g -> f Double -> g Double
activate c xs = case c of
    Component ps fun _ -> runPF fun xs ps

-- | A functor without any positions; it serves as the parameter collection of
--   parameterless components (see 'cArr').
--
data Empty x = Empty deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance Applicative Empty where

    pure _ = Empty

    (<*>) Empty Empty = Empty

instance NFData (Empty x) where

    -- 'Empty' has a single nullary constructor, so WHNF is already normal form.
    rnf e = e `seq` ()

-- | The analogue of pairs in the category of functors: a value of type
--   @'Pair' s t a@ holds an @s a@ next to a @t a@.
--
data Pair s t a = Pair (s a) (t a) deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance (NFData (f x), NFData (g x)) => NFData (Pair f g x) where

    -- Fully evaluates both halves.
    rnf (Pair l r) = rnf l `seq` rnf r

instance (Applicative f, Applicative g) => Applicative (Pair f g) where

    -- Both halves are handled independently, each by its own 'Applicative'.
    pure v = Pair (pure v) (pure v)

    (<*>) (Pair ff gf) (Pair fx gx) = Pair (ff <*> fx) (gf <*> gx)

instance Category Component where

    -- The identity component is the lifted identity function (no parameters).
    id = cArr id

    -- Composition pairs up the two parameter collections: the left half
    -- belongs to the outer (applied second) component, the right half to
    -- the inner (applied first) one. Initialization initializes both halves.
    (.) (Component outerW outerF outerR) (Component innerW innerF innerR) =
        Component
            (Pair outerW innerW)
            (ParamFun $ \x (Pair pOut pIn) -> runPF outerF (runPF innerF x pIn) pOut)
            (Pair <$> outerR <*> innerR)

-- | The analogue of 'Control.Arrow.arr' for 'Component's: lifts a
--   differentiable function into a component without any parameters.
--
cArr :: Diff f g -> Component f g
cArr (Diff h) = Component Empty (arr h) (return Empty)

-- | The analogue of 'Control.Arrow.first' for 'Component's: applies the
--   component to the first half of a 'Pair' and passes the second half
--   through. The parameters are exactly those of the wrapped component.
--
cFirst :: Component f g -> Component (Pair f h) (Pair g h)
cFirst (Component ps fun rnd) = Component
    ps
    (ParamFun $ \p qs -> case p of
        Pair xs ys -> Pair (runPF fun xs qs) ys)
    rnd

-- | The analogue of 'Either' in the category of functors:
--   either an @f a@ or a @g a@.
--
data FEither f g a
    = FLeft (f a)   -- ^ the left alternative
    | FRight (g a)  -- ^ the right alternative
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

-- | The analogue of 'Control.Arrow.left' for 'Component's: only 'FLeft'
--   inputs are fed through the component, 'FRight' inputs pass unchanged.
--   The parameters are exactly those of the wrapped component.
--
cLeft :: Component f g -> Component (FEither f h) (FEither g h)
cLeft (Component ps fun rnd) = Component
    ps
    (ParamFun $ \e qs -> case e of
        FRight ys -> FRight ys
        FLeft xs  -> FLeft (runPF fun xs qs))
    rnd

-- | Composition of functors: a @'Convolve' f g x@ wraps an @f (g x)@.
--
data Convolve f g x = Convolve (f (g x)) deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

-- | The analogue of 'convolve' for 'Component's: applies the component to
--   every inner value of a 'Convolve', with all of them sharing the
--   component's single parameter collection.
--
cConvolve :: Functor h => Component f g -> Component (Convolve h f) (Convolve h g)
cConvolve (Component ps fun rnd) = Component
    ps
    (ParamFun $ \cv qs -> case cv of
        Convolve xss -> Convolve (fmap (\xs -> runPF fun xs qs) xss))
    rnd

instance NFData (Component f g) where

    -- Only the parameter values are forced; the stored functions are left alone.
    rnf c = case c of
        Component ps _ _ -> rnf ps


-- | A @'Model' f g a b c@ wraps a @'Component' f g@ and models functions
--   @b -> c@, using "samples" of type @a@ to determine the model error.
--
data Model :: (* -> *) -> (* -> *) -> * -> * -> * -> * where

    Model :: (Functor f, Functor g)
          => Component f g                      -- the wrapped component
          -> (a -> (f Double, Diff g Identity)) -- splits a sample into component input and error function
          -> (b -> f Double)                    -- encodes a model input for the component
          -> (g Double -> c)                    -- decodes the component output
          -> Model f g a b c

instance Profunctor (Model f g a) where

    -- Pre-composes the input encoder and post-composes the output decoder;
    -- the component and the sample/error function stay untouched.
    dimap pre post (Model c e enc dec) = Model c e (\x -> enc (pre x)) (\y -> post (dec y))

instance NFData (Model f g a b c) where

    -- Forces the embedded component; the encoder/decoder functions are not forced.
    rnf m = case m of
        Model c _ _ _ -> rnf c

-- | A 'Lens'' for reading or replacing the component embedded in a model.
--   Replacing keeps the sample/error function, encoder and decoder.
--
_component :: Lens' (Model f g a b c) (Component f g)
_component = lens getComponent setComponent
  where
    getComponent :: Model f' g' a' b' c' -> Component f' g'
    getComponent (Model c _ _ _) = c

    setComponent :: Model f' g' a' b' c' -> Component f' g' -> Model f' g' a' b' c'
    setComponent (Model _ e enc dec) c = Model c e enc dec

-- | Computes the modelled function: encodes the input, activates the
--   component with its current weights and decodes the result.
model :: Model f g a b c -> b -> c
model m = case m of
    Model c _ enc dec -> \x -> dec (activate c (enc x))

-- | Generates a model with randomly initialized weights, drawn from the
--   component's own initializer. All other properties are copied from the
--   provided model.
modelR :: MonadRandom m => Model f g a b c -> m (Model f g a b c)
modelR (Model (Component _ fun rnd) e i o) =
    fmap (\ws -> Model (Component ws fun rnd) e i o) rnd

-- Builds the per-sample error as a differentiable function of the parameters.
--
-- Arguments: the sample-splitting function @e@ (same shape as a 'Model''s
-- second field), a sample @x@, and the component's parameterized function @f@.
-- The result maps a parameter collection to the error for that sample,
-- wrapped in 'Identity'. Because it is a 'Diff', 'descent' can take its
-- gradient with respect to the parameters (via gradWith').
errFun :: forall f t a g. (Functor f, Traversable t)
          => (a -> (f Double, Diff g Identity)) 
          -> a
          -> (forall s. Analytic s => ParamFun s t (f s) (g s))
          -> Diff t Identity
errFun e x f = Diff $ runPF f' x where

    -- The whole per-sample pipeline as a parameterized function of the sample:
    -- split the sample into component input and error function, lift the
    -- input from 'Double' to the analytic scalar type @s@, run the component
    -- and apply the error function to its output.
    f' :: forall s. Analytic s => ParamFun s t a (Identity s)
    f' = proc z -> do
        let (x', Diff h) = e z
            x''          = fromDouble <$> x'
        y <- f -< x''
        returnA -< h y

-- Error of the model on a single sample, evaluated at the current weights.
modelError' :: Model f g a b c -> a -> Double
modelError' (Model (Component ps fun _) e _ _) x =
    runIdentity (runDiff (errFun e x fun) ps)

-- | Calculates the average model error over a "mini-batch" of samples,
--   i.e. the 'mean' of the individual per-sample errors.
--
modelError :: Foldable h => Model f g a b c -> h a -> Double
modelError m xs = mean (fmap (modelError' m) (toList xs))

-- | Performs one step of gradient descent (backpropagation) on the model,
--   using the given learning rate and mini-batch.
--
--   Per-sample errors and gradients are computed in parallel (via 'runPar'
--   and 'parMapReduceRange') and summed. Each gradient is pre-scaled by
--   @eta / n@, where @n@ is the batch size, so the weights move by @eta@ times
--   the average gradient. The reported error is the average of the per-sample
--   errors returned by gradWith' — presumably evaluated at the weights
--   /before/ the update.
descent :: (Foldable h)
           => Model f g a b c           -- ^ the model whose error should be decreased 
           -> Double                    -- ^ the learning rate
           -> h a                       -- ^ a mini-batch of samples
           -> (Double, Model f g a b c) -- ^ returns the average sample error and the improved model
descent (Model c e i o) eta xs = case c of
    Component ws f r ->
        let xs'                       = toList xs
            l                         = length xs'
            l'                        = fromIntegral l
            scale                     = eta / l'
            -- Map step for sample j: its error (already divided by the batch size)
            -- and its gradient, each partial derivative multiplied by eta / l.
            -- NOTE(review): '!!' on a list costs O(j) per lookup, so the lookups
            -- are O(n^2) in the batch size; acceptable for small mini-batches.
            q j                       = do
                                            let x          = xs' !! j
                                                (err', g') = gradWith' (\_ dw -> scale * dw) (errFun e x f) ws
                                            return (err' / l', g')
            -- Reduce step: add errors and add gradients position by position.
            s (err', g') (err'', g'') = return (err' + err'', (+) <$> g' <*> g'')
            -- (0, pure 0) is the neutral element of 's'. For an empty batch the
            -- range is 0..-1, which presumably yields (0, pure 0) and leaves the
            -- weights unchanged — NOTE(review): confirm against monad-par.
            (err, ws')                = runPar $ parMapReduceRange (InclusiveRange 0 $ pred l) q s (0, pure 0)
            -- Step against the (already scaled) gradient.
            ws''                      = (-) <$> ws <*> ws'
            c'                        = Component ws'' f r
            m                         = Model c' e i o
        in  (err, m)

-- | A type abbreviation for the most common kind of model, whose samples are
--   simply pairs of an input and its expected output.
type StdModel f g inp out = Model f g (inp, out) inp out

-- | Creates a 'StdModel', under the simplifying assumption that the error can
--   be computed from the expected output alone.
--
--   A sample @(x, y)@ is turned into the encoded input @x@ together with the
--   error function derived from the expected output @y@.
--
mkStdModel :: (Functor f, Functor g) 
              => Component f g          -- ^ the component to wrap
              -> (c -> Diff g Identity) -- ^ error function for a given expected output
              -> (b -> f Double)        -- ^ input encoder
              -> (g Double -> c)        -- ^ output decoder
              -> StdModel f g b c
mkStdModel c e i o = Model c (\(x, y) -> (i x, e y)) i o