module Math.Integrators.StormerVerlet
  ( integrateV,
    stormerVerlet2H,
    Integrator,
  )
where

import Control.Lens
import Control.Monad.Primitive
import Data.Vector (Vector, (!))
import Data.Vector qualified as V
import Data.Vector.Mutable
import Linear (V2 (..))

-- | A one-step integrator \(\Phi_h : y_0 \mapsto y_1\).
--
-- The first argument is the step size \(h\), the second is the current
-- state \(y_0\), and the result is the state \(y_1\) one step later.
type Integrator a = Double -> a -> a

-- | One step of the Störmer–Verlet scheme for separable Hamiltonian
-- systems of the form \(\mathbb{H}(p,q) = T(p) + V(q)\).
--
-- With step size \(h\), the state is advanced by a momentum half-step, a
-- full position step, and a second momentum half-step:
--
-- \[
-- \begin{aligned}
-- p_{n+1/2} &= p_n - \tfrac{h}{2}\,\frac{\partial H}{\partial q}(q_n) \\
-- q_{n+1}   &= q_n + h\,\frac{\partial H}{\partial p}(p_{n+1/2}) \\
-- p_{n+1}   &= p_{n+1/2} - \tfrac{h}{2}\,\frac{\partial H}{\partial q}(q_{n+1})
-- \end{aligned}
-- \]
--
-- Arithmetic on @f a@ goes through its 'Num' instance. The scalar step
-- sizes are lifted into @f@ with 'pure' so they combine with the vector
-- quantities.
--
-- The state vector is ordered \((q, p)\): the first component is the
-- position and the second is the momentum, on input and on output.
stormerVerlet2H ::
  (Applicative f, Num (f a), Fractional a) =>
  -- | Step size \(h\)
  a ->
  -- | \(\frac{\partial H}{\partial q}\), evaluated at a position
  (f a -> f a) ->
  -- | \(\frac{\partial H}{\partial p}\), evaluated at a momentum
  (f a -> f a) ->
  -- | Current \((q, p)\) as a 2-dimensional vector
  V2 (f a) ->
  -- | New \((q, p)\) as a 2-dimensional vector
  V2 (f a)
stormerVerlet2H hh nablaQ nablaP prev = V2 qNew pNew
  where
    -- Scalar step sizes lifted into @f@ so they can multiply @f a@ values.
    hhs = pure hh
    hh2s = pure (hh / 2)
    -- Component 1 is the position, component 2 the momentum.
    qsPrev = prev ^. _1
    psPrev = prev ^. _2
    -- Momentum half-step using dH/dq at the old position.
    pp2 = psPrev - hh2s * nablaQ qsPrev
    -- Position full step using dH/dp at the half-step momentum.
    qNew = qsPrev + hhs * nablaP pp2
    -- Momentum half-step using dH/dq at the new position.
    pNew = pp2 - hh2s * nablaQ qNew

-- | Integrate an ODE over a given sequence of time points with a one-step
-- integrator, returning the solution at every requested time.
--
-- The result has the same length as @times@:
--
-- * Element 0 is @initial@, taken to be the state at @times ! 0@.
-- * Element @i@ (for @i >= 1@) is the integrator applied to element
--   @i - 1@ with step @times ! i - times ! (i - 1)@.
--
-- The time points are not checked for monotonicity, so a decreasing pair
-- produces a negative step. An empty @times@ vector yields an empty result.
integrateV ::
  PrimMonad m =>
  -- | Internal integrator
  Integrator a ->
  -- | Initial value (the state at the first time point)
  a ->
  -- | Vector of time points
  Vector Double ->
  -- | Vector of solutions, one per time point
  m (Vector a)
integrateV integrator initial times
  -- A zero-length buffer has no index 0 to hold the initial value.
  | n == 0 = return V.empty
  | otherwise = do
      out <- new n
      write out 0 initial
      compute initial 1 out
      V.unsafeFreeze out
  where
    n = V.length times
    -- Fill indices i .. n-1, each from the state at the previous index.
    compute y i out
      | i == n = return ()
      | otherwise = do
          let h = (times ! i) - (times ! (i - 1))
              y' = integrator h y
          -- Force each state to WHNF before storing it, so the boxed buffer
          -- does not accumulate a chain of unevaluated integrator steps.
          y' `seq` write out i y'
          compute y' (i + 1) out