-- | Applicative interface to define recursive structures and derive Boltzmann
-- samplers.
--
-- Given the recursive structure of the types and how to combine generators,
-- the library takes care of computing the oracles and setting up the right
-- distributions.

{-# LANGUAGE FlexibleContexts, FlexibleInstances, GADTs, RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveFunctor, DeriveGeneric, ImplicitParams #-}
{-# LANGUAGE RecordWildCards, DeriveDataTypeable #-}
{-# LANGUAGE TypeFamilies, MultiParamTypeClasses #-}

module Boltzmann.Species where

import Control.Applicative
import Control.Monad
import Data.Bifunctor
import Data.Coerce
import Data.Function
import Data.Foldable
import Data.List
import Data.Maybe
import Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Numeric.AD as AD

import Boltzmann.Data.Common
import Boltzmann.Data.Types
import Boltzmann.Solver

-- | Lift computations of a base functor @m@ (for instance a random
-- generator) into a structure @f@ used to describe systems.
class Embed f m where
  -- | Embed a computation of @m@ into @f@. The instances in this module
  -- give the embedded computation a neutral weight of 1.
  embed :: m a -> f a
  -- | Transform the @m@-computation represented by an @f@ value. Instances
  -- that carry no computation (such as 'ConstModule') may ignore the
  -- function.
  emap :: (m a -> m b) -> f a -> f b

-- | 'Applicative' defines a product and 'Alternative' defines an addition;
-- adding scalar multiplication turns the structure into a module.
--
-- This typeclass allows tweaking the weights in the oracle directly by
-- chosen factors.
--
-- Each default method is defined in terms of the other, so an instance must
-- override at least one of them. The @MINIMAL@ pragma makes GHC warn about
-- instances that define neither, which would otherwise loop at runtime.
class (Alternative f, Num (Scalar f)) => Module f where
  -- | Type of the scalars that values of @f@ can be multiplied by.
  type Scalar f :: *

  -- | Scalar embedding.
  scalar :: Scalar f -> f ()
  scalar x = x <.> pure ()

  -- | Scalar multiplication.
  (<.>) :: Scalar f -> f a -> f a
  x <.> f = scalar x *> f

  {-# MINIMAL scalar | (<.>) #-}

-- Binds more loosely than '<*>' (infixl 4), so @c <.> f <*> x@ parses as
-- @c <.> (f <*> x)@.
infixr 3 <.>

-- | An endomorphism: a function from a type to itself.
type Endo t = t -> t

-- | A system of mutually recursive equations over @f@, with an extra
-- result of type @c@.
data System f a c = System
  { dim :: Int
    -- ^ Number of unknowns, i.e. the length of the vectors passed to and
    -- returned by 'sys''.
  , sys' :: f () -> Vector (f a) -> (Vector (f a), c)
    -- ^ Given the parameter (embedded as @f ()@) and the current values of
    -- the unknowns, compute their new values together with an extra result.
  }
  deriving Functor

-- | Run one step of a system, discarding the extra result of 'sys''.
sys :: System f a c -> f () -> Vector (f a) -> Vector (f a)
sys s x v = fst (sys' s x v)

-- | A module that only tracks a number of type @r@, ignoring the type of
-- values @a@. Products multiply, sums add, 'pure' is 1 and 'empty' is 0.
-- 'solve' uses it to evaluate a system numerically.
newtype ConstModule r a = ConstModule
  { unConstModule :: r
  }

-- | Mapping only changes the phantom type; the number is kept.
instance Functor (ConstModule r) where
  fmap _ = ConstModule . unConstModule

-- | Embedded computations contribute the multiplicative identity 1, and
-- transforming them does not change the number.
instance Num r => Embed (ConstModule r) m where
  emap _ = ConstModule . unConstModule
  embed = const (ConstModule 1)

-- | Products multiply the numbers; 'pure' is the unit 1.
instance Num r => Applicative (ConstModule r) where
  pure = const (ConstModule 1)
  f <*> x = ConstModule (unConstModule f * unConstModule x)

-- | Sums add the numbers; 'empty' is 0.
instance Num r => Alternative (ConstModule r) where
  empty = ConstModule 0
  l <|> r = ConstModule (unConstModule l + unConstModule r)

-- | Scalars are numbers of the same type; scalar multiplication is plain
-- multiplication.
instance Num r => Module (ConstModule r) where
  type Scalar (ConstModule r) = r
  scalar = ConstModule
  c <.> m = ConstModule (c * unConstModule m)

-- | Numerically solve a system at parameter @x@ using 'fixedPoint' (from
-- "Boltzmann.Solver"), starting from the zero vector. The result is
-- 'Nothing' when 'fixedPoint' finds no solution (see its documentation for
-- the exact conditions).
solve
  :: forall b c
  . (forall a. Num a => System (ConstModule a) b c)
  -> Double -> Maybe (Vector Double)
solve s x = fixedPoint defSolveArgs step (V.replicate n 0)
  where
    -- The dimension does not depend on the numeric type; 'Int' is just an
    -- arbitrary instantiation to pin it down.
    n = dim (s :: System (ConstModule Int) b c)
    -- One iteration of the system, with the 'ConstModule' wrappers erased by
    -- 'coerce'. It is polymorphic in the AD mode, presumably so that
    -- 'fixedPoint' can differentiate it.
    step :: forall a. (AD.Mode a, AD.Scalar a ~ Double) => Endo (Vector a)
    step =
      coerce (sys s (scalar (AD.auto x)) :: Endo (Vector (ConstModule a b)))

-- | Build a generator for type @i@, pointed @k@ times. Its parameter and
-- oracle are computed by 'solveSized' for the requested expected size, or
-- for a singular sampler when the size is 'Nothing'.
sizedGenerator
  :: forall b c m
  . MonadRandomLike m
  => (forall f. (Module f, Embed f m) => System (Pointiful f) b c)
  -> Int  -- ^ Index of type
  -> Int  -- ^ Points
  -> Maybe Double  -- ^ Expected size (or singular sampler)
  -> m b
sizedGenerator s i k size' = generators V.! target
  where
    (x, oracle) = solveSized s i k size'
    generators = fst (sfix (point (k + 1) s) x oracle)
    -- Each type occupies k + 2 consecutive slots after pointing k + 1
    -- times; slot k of type i is the k-pointed generator.
    target = i * (k + 2) + k

-- | Compute the parameter @x@ and the oracle (the numeric solution, as given
-- by 'solve', of the system pointed @k + 1@ times) used by 'sizedGenerator'.
--
-- The parameter is chosen by 'search' (from "Boltzmann.Solver") with the
-- predicate @checkSize@, which rejects a candidate solution when:
--
-- * there is no solution,
-- * any component is negative, NaN or infinite (such values would corrupt
--   the generator weights), or
-- * an expected size is given and it is smaller than the ratio of the
--   @(k + 1)@-th to the @k@-th component of type @i@.
--
-- With 'Nothing' (singular sampler), any solution that passes the first two
-- checks is accepted.
solveSized
  :: forall b c
  . (forall a. Num a => System (Pointiful (ConstModule a)) b c)
  -> Int  -- ^ Index of type
  -> Int  -- ^ Points
  -> Maybe Double  -- ^ Expected size (or singular sampler)
  -> (Double, Vector Double)
solveSized s i k size' =
  -- NOTE(review): assumes 'search' only returns points where the predicate
  -- holds, which implies a 'Just' solution, making 'fromJust' safe.
  fmap fromJust (search (solve s') (checkSize size'))
  where
    s' :: forall a. Num a => System (ConstModule a) b c
    s' = point (k + 1) s
    -- Indices of the k-th and (k + 1)-th components of type i.
    j = i * (k + 2) + k
    j' = i * (k + 2) + k + 1
    checkSize _ (Just ys) | V.any invalid ys = False
    checkSize (Just size) (Just ys) = size >= ys V.! j' / ys V.! j
    checkSize Nothing (Just _) = True
    checkSize _ Nothing = False
    -- @(< 0)@ alone is False for NaN, so non-finite values are checked
    -- explicitly.
    invalid y = y < 0 || isNaN y || isInfinite y

-- | A sum of alternatives: each alternative is a generator in @m@ paired
-- with its weight.
newtype Weighted m a = Weighted [(Double, m a)]

-- | A single generator with the given weight.
weighted :: Double -> m a -> Weighted m a
weighted w gen = Weighted (pure (w, gen))

-- | Collapse a weighted sum into its total weight and a single generator.
--
-- A lone alternative is returned unchanged. Otherwise the weights are summed
-- and the alternatives are combined with @'frequencyWith' 'doubleR'@. The
-- behaviour for an empty sum depends on 'frequencyWith' and is not checked
-- here.
runWeighted :: MonadRandomLike m => Weighted m a -> (Double, m a)
runWeighted (Weighted alts) = case alts of
  [single] -> single
  _ -> (sum (map fst alts), frequencyWith doubleR alts)

-- | Map over every alternative's generator, keeping the weights.
instance Functor m => Functor (Weighted m) where
  fmap f (Weighted alts) = Weighted (map (\(w, gen) -> (w, fmap f gen)) alts)

-- | 'emap' first collapses the alternatives with 'runWeighted', then
-- transforms the resulting generator. 'embed' gives weight 1.
instance MonadRandomLike m => Embed (Weighted m) m where
  emap f w = Weighted [fmap f (runWeighted w)]
  embed gen = weighted 1 gen

-- | Products collapse both operands with 'runWeighted', multiply their
-- weights and combine their generators. 'pure' has weight 1.
instance MonadRandomLike m => Applicative (Weighted m) where
  pure x = Weighted [(1, pure x)]
  fw <*> xw =
    -- Lazy bindings: neither operand is inspected until the result is used.
    let (wf, genF) = runWeighted fw
        (wx, genX) = runWeighted xw
    in Weighted [(wf * wx, genF <*> genX)]

-- | Sums concatenate the alternatives; 'empty' has none.
instance MonadRandomLike m => Alternative (Weighted m) where
  empty = Weighted []
  Weighted l <|> Weighted r = Weighted (l ++ r)

-- | Scalars are 'Double' weights. 'scalar' is a weighted no-op generator,
-- and '<.>' rescales every alternative in place.
instance MonadRandomLike m => Module (Weighted m) where
  type Scalar (Weighted m) = Double
  scalar w = Weighted [(w, pure ())]
  c <.> Weighted alts = Weighted (map (first (c *)) alts)

-- | Tie the knot of a system of weighted generators.
--
-- Each unknown is fed back as its own generator weighted by the matching
-- entry of @oracle@. The system is run at parameter @x@, and every resulting
-- weighted sum is collapsed into a generator with 'runWeighted'.
sfix
  :: MonadRandomLike m
  => System (Weighted m) b c -> Double -> Vector Double -> (Vector (m b), c)
sfix s x oracle = knot
  where
    -- Equivalent to @fix step@. 'step' only reaches its argument through
    -- 'fst' (never by matching the pair), which keeps the recursion lazy.
    knot = step knot
    step prev =
      first (fmap (snd . runWeighted)) $
        sys' s (scalar x) (V.zipWith weighted oracle (fst prev))

-- | A value together with its successive components of higher order. These
-- are combined like derivatives: see the Leibniz-rule product in the
-- 'Applicative' instance.
data Pointiful f a
  = Pointiful [f a]
    -- ^ Components listed explicitly, from order 0 upwards.
  | Zero (f a)
    -- ^ A value whose components of order 1 and above are all 'empty'.

-- | Map over every component.
instance Functor f => Functor (Pointiful f) where
  fmap g p = case p of
    Zero x -> Zero (fmap g x)
    Pointiful xs -> Pointiful (map (fmap g) xs)

-- | 'emap' transforms every component. Embedded computations have no
-- higher-order components.
instance Embed f m => Embed (Pointiful f) m where
  emap g p = case p of
    Zero x -> Zero (emap g x)
    Pointiful xs -> Pointiful (map (emap g) xs)
  embed m = Zero (embed m)

-- | Product of component sequences, following the general Leibniz rule.
--
-- Component @n@ of @gs <*> ys@ is the sum over @i@ of
-- @binomial n i <.> (gs !! i <*> ys !! (n - i))@. Here 'binomial' comes from
-- "Boltzmann.Data.Common" and is presumably the binomial coefficient. The
-- result is as long as the shorter operand. A 'Zero' operand only has a
-- component of order 0, so the product reduces to a componentwise map.
instance Module f => Applicative (Pointiful f) where
  pure = Zero . pure
  Zero g <*> Zero y = Zero (g <*> y)
  Zero g <*> Pointiful ys = Pointiful (map (g <*>) ys)
  Pointiful gs <*> Zero y = Pointiful (map (<*> y) gs)
  Pointiful gs <*> Pointiful ys =
    Pointiful (zipWith3 leibniz [0 ..] (prefixes gs) (prefixes ys))
    where
      -- Non-empty prefixes: the first n + 1 components, for n = 0, 1, ...
      prefixes = tail . inits
      -- Component n of the product, from the first n + 1 components of
      -- each operand.
      leibniz n gPre yPre =
        asum
          [ fromInteger (binomial n i) <.> (gi <*> yj)
          | (i, gi, yj) <- zip3 [0 ..] gPre (reverse yPre)
          ]

-- | Componentwise sum.
--
-- Two explicit lists are zipped, so the result is as long as the shorter
-- one. A 'Zero' operand only adds to the component of order 0. An empty
-- explicit list combined with a 'Zero' operand yields that 'Zero' operand.
instance Module f => Alternative (Pointiful f) where
  empty = Zero empty
  Zero x <|> other = case other of
    Zero y -> Zero (x <|> y)
    Pointiful (y : ys) -> Pointiful ((x <|> y) : ys)
    Pointiful [] -> Zero x
  Pointiful xs <|> other = case other of
    Pointiful ys -> Pointiful (zipWith (<|>) xs ys)
    Zero y -> case xs of
      x : rest -> Pointiful ((x <|> y) : rest)
      [] -> Zero y

-- | Scalars are constants. They embed as 'Zero' (no higher-order
-- components), and multiplying by a scalar scales every component with the
-- underlying module's '<.>'.
--
-- '<.>' is defined directly rather than through the class default
-- @scalar c *> p@. That default goes through the inner '<*>', which for
-- 'Weighted' collapses the alternatives with 'runWeighted' and turns
-- @c <.> empty@ into a zero-weight alternative. Scaling each component
-- directly keeps the inner module's own scalar multiplication.
instance Module f => Module (Pointiful f) where
  type Scalar (Pointiful f) = Scalar f
  scalar = Zero . scalar
  c <.> Zero x = Zero (c <.> x)
  c <.> Pointiful xs = Pointiful (map (c <.>) xs)

-- | The list of all components. For 'Zero' this is the value followed by
-- infinitely many 'empty' components.
unPointiful :: Alternative f => Pointiful f a -> [f a]
unPointiful p = case p of
  Pointiful xs -> xs
  Zero x -> x : repeat empty

-- | Flatten a system over 'Pointiful' values into a plain system of
-- @(k + 1) * dim s@ unknowns, keeping components @0 .. k@ of each unknown.
--
-- Component @j@ of unknown @i@ is stored at index @i * (k + 1) + j@. The
-- parameter @x@ is passed to @s@ as @Pointiful (repeat x)@, so every
-- component of the parameter is @x@ itself.
--
-- Each result is padded with 'empty' up to @k + 1@ components before
-- flattening. Without this, a result with a shorter component list would
-- shift the indices of all following unknowns.
point :: Module f => Int -> System (Pointiful f) b c -> System f b c
point k s = System (width * dim s) $ \x ->
  first flatten . sys' s (Pointiful (repeat x)) . resize
  where
    width = k + 1
    flatten = V.concatMap (V.fromList . take width . padded . unPointiful)
    padded xs = xs ++ repeat empty
    -- The list is built from indices only, so its spine never forces @v@.
    -- Keep it that way: 'sfix' ties a knot through this vector.
    resize v = V.generate (dim s) $ \i ->
      Pointiful [v V.! j | j <- [i * width .. i * width + k]]