module Boltzmann.Species where
import Control.Applicative
import Control.Monad
import Data.Bifunctor
import Data.Coerce
import Data.Function
import Data.Foldable
import Data.List
import Data.Maybe
import Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Numeric.AD as AD
import Boltzmann.Data.Common
import Boltzmann.Data.Types
import Boltzmann.Solver
-- | Lifting of actions of an underlying functor @m@ into @f@.
--
-- 'embed' injects a single @m@-action; 'emap' lifts a transformation of
-- @m@-actions to a transformation of @f@-values.
class Embed f m where
  -- | Apply a transformation of @m@-actions inside an @f@-value.
  emap :: (m x -> m y) -> f x -> f y
  -- | Inject an @m@-action as an @f@-value.
  embed :: m x -> f x
-- | 'Alternative' functors whose values can be scaled by elements of the
-- numeric type @'Scalar' f@.
--
-- Each of 'scalar' and '(<.>)' has a default written in terms of the other,
-- so an instance must define at least one of them. The MINIMAL pragma makes
-- GHC report an instance that defines neither; without it, such an instance
-- would compile and then loop at runtime.
class (Alternative f, Num (Scalar f)) => Module f where
  -- | The numeric type whose elements act on @f@.
  type Scalar f :: *

  -- | A scalar as an @f ()@ value. Defaults to scaling @pure ()@.
  scalar :: Scalar f -> f ()
  scalar x = x <.> pure ()

  -- | Scale a value. Defaults to sequencing after @'scalar' x@.
  (<.>) :: Scalar f -> f a -> f a
  x <.> f = scalar x *> f

  infixr 3 <.>

  {-# MINIMAL scalar | (<.>) #-}
-- | Functions from a type to itself.
type Endo a = a -> a

-- | A system of equations whose values live in @f@.
data System f a c = System
  { dim :: Int
    -- ^ Length of the vectors that 'sys'' consumes and produces.
  , sys' :: f () -> Vector (f a) -> (Vector (f a), c)
    -- ^ Given the parameter value (e.g. @scalar x@) and a vector of
    -- current values, produce a new vector together with an extra
    -- result of type @c@.
  }
  deriving (Functor)
-- | Run a system and keep only its output vector, discarding the extra
-- @c@ component returned by 'sys''.
sys :: System f a c -> f () -> Vector (f a) -> Vector (f a)
sys s x = fst . sys' s x
-- | A functor that discards its payload and carries only a number of
-- type @r@.
--
-- Its instances read a system as arithmetic on @r@: 'pure' and 'embed'
-- give 1, '<*>' multiplies, 'empty' is 0, '<|>' adds, and scalars act by
-- multiplication. 'solve' relies on this to turn a system into a numeric
-- map on vectors.
newtype ConstModule r a = ConstModule { unConstModule :: r }

instance Functor (ConstModule r) where
  fmap _ = ConstModule . unConstModule

instance Num r => Embed (ConstModule r) m where
  emap _ = ConstModule . unConstModule
  embed = const (ConstModule 1)

instance Num r => Applicative (ConstModule r) where
  pure = const (ConstModule 1)
  f <*> a = ConstModule (unConstModule f * unConstModule a)

instance Num r => Alternative (ConstModule r) where
  empty = ConstModule 0
  a <|> b = ConstModule (unConstModule a + unConstModule b)

instance Num r => Module (ConstModule r) where
  type Scalar (ConstModule r) = r
  scalar = ConstModule
  x <.> m = ConstModule (x * unConstModule m)
-- | Solve the numeric system obtained by reading @s@ in 'ConstModule' at
-- parameter @x@.
--
-- The system is turned into a map on vectors of automatic-differentiation
-- numbers and handed to 'fixedPoint', starting from the all-zero vector of
-- length @dim s@.
solve
  :: forall b c
   . (forall a. Num a => System (ConstModule a) b c)
  -> Double
  -> Maybe (Vector Double)
solve s x = fixedPoint defSolveArgs iteration (V.replicate (dim atInt) 0)
  where
    -- Any instantiation has the same dimension; 'Int' is used only to read it.
    atInt :: System (ConstModule Int) b c
    atInt = s

    -- One step of the iteration. The 'ConstModule' wrappers are removed
    -- with a zero-cost 'coerce'.
    iteration :: forall a. (AD.Mode a, AD.Scalar a ~ Double) => Endo (Vector a)
    iteration = coerce onModules
      where
        onModules :: Endo (Vector (ConstModule a b))
        onModules = sys s (scalar (AD.auto x))
-- | A random generator for species @i@ pointed @k@ times.
--
-- 'solveSized' chooses the parameter and the solution vector, using
-- @size'@ as the optional target size. The system, pointed @k + 1@ times,
-- is then turned into generators by 'sfix' using that vector as weights.
-- The generator at index @i * (k + 2) + k@ (coefficient @k@ of species
-- @i@) is returned.
sizedGenerator
  :: forall b c m
   . MonadRandomLike m
  => (forall f. (Module f, Embed f m) => System (Pointiful f) b c)
  -> Int
  -> Int
  -> Maybe Double
  -> m b
sizedGenerator s i k size' = generators V.! target
  where
    (param, weights) = solveSized s i k size'
    (generators, _) = sfix (point (k + 1) s) param weights
    -- 'point (k + 1)' lays out k + 2 consecutive coefficients per species.
    target = i * (k + 2) + k
-- | Choose a parameter for species @i@ pointed @k@ times, and return it
-- together with the solution vector at that parameter.
--
-- The system is pointed @k + 1@ times, so every species occupies @k + 2@
-- consecutive entries. @j@ is coefficient @k@ of species @i@ and @j'@ is
-- coefficient @k + 1@. 'search' runs 'solve' with 'checkSize' as its
-- acceptance test. A candidate is accepted when a solution exists and has
-- no negative entries. When a target size is given, it must also satisfy
-- @ys ! j' / ys ! j <= size@.
solveSized
  :: forall b c
   . (forall a. Num a => System (Pointiful (ConstModule a)) b c)
  -> Int
  -> Int
  -> Maybe Double
  -> (Double, Vector Double)
solveSized s i k size' =
  fmap (fromMaybe noSolution) (search (solve s') (checkSize size'))
  where
    -- 'checkSize' rejects missing solutions, so this is presumably
    -- unreachable. It replaces a bare 'fromJust', so that if that
    -- assumption is ever violated the crash names its origin.
    noSolution :: Vector Double
    noSolution =
      error "Boltzmann.Species.solveSized: no solution at the point chosen by search"
    s' :: forall a. Num a => System (ConstModule a) b c
    s' = point (k + 1) s
    j = i * (k + 2) + k
    j' = i * (k + 2) + k + 1
    checkSize _ (Just ys) | V.any (< 0) ys = False
    checkSize (Just size) (Just ys) = size >= ys V.! j' / ys V.! j
    checkSize Nothing (Just _) = True
    checkSize _ Nothing = False
-- | Alternatives for a generator in @m@, each tagged with a 'Double' weight.
newtype Weighted m a = Weighted [(Double, m a)]

-- | A single alternative with weight @w@.
weighted :: Double -> m a -> Weighted m a
weighted w gen = Weighted (pure (w, gen))
-- | Collapse weighted alternatives into a total weight and one generator.
--
-- A single alternative is returned unchanged. Otherwise the weight is the
-- sum of all weights (0 when there are none), and the generator is
-- @frequencyWith doubleR@ over the alternatives. That presumably picks one
-- with probability proportional to its weight; confirm against
-- 'frequencyWith'.
runWeighted :: MonadRandomLike m => Weighted m a -> (Double, m a)
runWeighted (Weighted alternatives) = case alternatives of
  [only] -> only
  _ -> (sum (map fst alternatives), frequencyWith doubleR alternatives)
-- | Map over every alternative's generator, keeping the weights.
instance Functor m => Functor (Weighted m) where
  fmap f (Weighted alternatives) = Weighted (map (fmap (fmap f)) alternatives)

-- | 'emap' first collapses the alternatives with 'runWeighted'.
-- 'embed' gives weight 1.
instance MonadRandomLike m => Embed (Weighted m) m where
  emap f w = Weighted [fmap f (runWeighted w)]
  embed = weighted 1

-- | Both sides are collapsed with 'runWeighted'. The weights multiply and
-- the generators are combined with @m@'s '<*>'.
instance MonadRandomLike m => Applicative (Weighted m) where
  pure = weighted 1 . pure
  wf <*> wx = weighted (u * v) (g <*> y)
    where
      (u, g) = runWeighted wf
      (v, y) = runWeighted wx

-- | Choice concatenates the lists of alternatives.
instance MonadRandomLike m => Alternative (Weighted m) where
  empty = Weighted []
  Weighted lefts <|> Weighted rights = Weighted (lefts <> rights)

-- | Scalars multiply every alternative's weight.
instance MonadRandomLike m => Module (Weighted m) where
  type Scalar (Weighted m) = Double
  scalar x = weighted x (pure ())
  x <.> Weighted alternatives = Weighted [first (x *) alt | alt <- alternatives]
-- | Build the generators of a weighted system by tying a recursive knot.
--
-- The system's own output generators are fed back as its inputs. Each is
-- paired with the matching entry of @oracle@ as its weight, and the system
-- is evaluated at parameter @scalar x@. Every output is collapsed to one
-- generator with 'runWeighted'. The extra @c@ result is passed through
-- unchanged.
sfix
  :: MonadRandomLike m
  => System (Weighted m) b c -> Double -> Vector Double -> (Vector (m b), c)
sfix s x oracle = knot
  where
    knot = first (fmap (snd . runWeighted)) (sys' s (scalar x) inputs)
    -- Refers back to 'knot'. This presumably relies on 'sys'' not forcing
    -- the generators it receives before producing its output vector.
    inputs = V.zipWith weighted oracle (fst knot)
-- | A sequence of coefficients with values in @f@.
data Pointiful f a
  = Pointiful [f a]
    -- ^ An explicit list of coefficients, first coefficient first.
  | Zero (f a)
    -- ^ A first coefficient followed only by 'empty' ones ('unPointiful'
    -- expands it that way).

-- | Map over every coefficient.
instance Functor f => Functor (Pointiful f) where
  fmap g p = case p of
    Pointiful coeffs -> Pointiful (map (fmap g) coeffs)
    Zero c -> Zero (fmap g c)

-- | 'emap' acts on every coefficient; 'embed' produces a 'Zero' value.
instance Embed f m => Embed (Pointiful f) m where
  emap g p = case p of
    Pointiful coeffs -> Pointiful (map (emap g) coeffs)
    Zero c -> Zero (emap g c)
  embed act = Zero (embed act)
-- | A 'Zero' operand is applied to each coefficient of the other side.
--
-- Two coefficient lists are combined by a binomial convolution.
-- Coefficient @n@ of the result is the 'asum' over @m = 0 .. n@ of
-- @binomial n m <.> (gs !! m <*> ys !! (n - m))@. That is the Leibniz
-- product rule, assuming 'binomial' computes binomial coefficients. The
-- result is as long as the shorter operand.
instance Module f => Applicative (Pointiful f) where
  pure = Zero . pure

  Zero g <*> Zero y = Zero (g <*> y)
  Zero g <*> Pointiful ys = Pointiful (map (g <*>) ys)
  Pointiful gs <*> Zero y = Pointiful (map (<*> y) gs)
  Pointiful gs <*> Pointiful ys = Pointiful (leibniz gs ys)
    where
      -- Coefficient n uses the first n + 1 entries of each operand.
      leibniz us vs = zipWith3 coefficient [0 ..] (prefixes us) (prefixes vs)
      -- Non-empty prefixes, shortest first.
      prefixes = drop 1 . inits
      -- Pair entry m of one prefix with entry (n - m) of the other.
      coefficient n us vs = asum (zipWith3 (term n) [0 ..] us (reverse vs))
      term n m u v = fromInteger (binomial n m) <.> (u <*> v)
-- | Two 'Zero' values combine directly. Two coefficient lists combine
-- entrywise, truncated to the shorter list. A 'Zero' combined with a
-- non-empty list affects only the first entry. @Pointiful []@ combined with
-- a 'Zero' yields that 'Zero' unchanged.
--
-- NOTE(review): @Pointiful []@ therefore absorbs another list (by
-- truncation) but is neutral against 'Zero'. Confirm that asymmetry is
-- intended.
instance Module f => Alternative (Pointiful f) where
  empty = Zero empty

  Zero x <|> Zero y = Zero (x <|> y)
  Pointiful xs <|> Pointiful ys = Pointiful (zipWith (<|>) xs ys)
  Pointiful [] <|> z@(Zero _) = z
  Pointiful (x : xs) <|> Zero y = Pointiful ((x <|> y) : xs)
  z@(Zero _) <|> Pointiful [] = z
  Zero x <|> Pointiful (y : ys) = Pointiful ((x <|> y) : ys)
-- | Scalars come from the underlying module. 'scalar' produces a 'Zero'
-- value, and '(<.>)' is the class default, @scalar x *> p@.
instance Module f => Module (Pointiful f) where
  type Scalar (Pointiful f) = Scalar f
  scalar x = Zero (scalar x)
-- | The coefficients of a 'Pointiful' value as a list. A 'Zero' value
-- becomes its first coefficient followed by infinitely many 'empty' ones.
unPointiful :: Alternative f => Pointiful f a -> [f a]
unPointiful p = case p of
  Pointiful coeffs -> coeffs
  Zero c -> c : repeat empty
-- | Expand a system over 'Pointiful' values into an ordinary system.
--
-- Species @i@ of @s@ becomes @k + 1@ consecutive entries (indices
-- @i * (k + 1)@ to @i * (k + 1) + k@), holding coefficients @0 .. k@ of its
-- 'Pointiful' value. The resulting system has dimension
-- @(k + 1) * dim s@. The parameter is passed to @s@ as
-- @Pointiful (repeat x)@, i.e. with every coefficient equal to @x@.
point :: Module f => Int -> System (Pointiful f) b c -> System f b c
point k s = System ((k + 1) * dim s) $ \x ->
  first flatten . sys' s (Pointiful (repeat x)) . resize
  where
    -- Emit exactly k + 1 coefficients per species. An output with a shorter
    -- coefficient list is padded with 'empty', as 'unPointiful' does for
    -- 'Zero'. Otherwise later species would shift into the wrong slots and
    -- the vector would fall short of (k + 1) * dim s entries.
    flatten = join . fmap (V.fromList . take (k + 1) . padded)
    padded p = unPointiful p ++ repeat empty
    -- Regroup the flat vector into one 'Pointiful' per species. The list
    -- spine depends only on indices, not on @v@. Keep it that way: 'sfix'
    -- feeds a system's own output back in as its input.
    resize v = V.generate (dim s) $ \i ->
      Pointiful [v V.! j | j <- [i * (k + 1) .. i * (k + 1) + k]]