module Simulation.Aivika.IO.Dynamics.Memo () where
import Control.Monad
import Control.Monad.Trans
import Data.Array.IO.Safe
import Data.Array.MArray.Safe
import Data.IORef
import Simulation.Aivika.Trans.Internal.Specs
import Simulation.Aivika.Trans.Internal.Parameter
import Simulation.Aivika.Trans.Internal.Simulation
import Simulation.Aivika.Trans.Internal.Dynamics
import Simulation.Aivika.Trans.Dynamics.Memo
import Simulation.Aivika.Trans.Dynamics.Extra
import Simulation.Aivika.Trans.Array
-- | 'MonadMemo' instance for 'IO'. Each method keeps its progress in
-- 'IORef' cursors created once per simulation run; the two memoizing
-- methods additionally store computed values in an array created with
-- 'newIOArray_'.
instance MonadMemo IO where

  -- Memoizes the computation at every (phase, iteration) node.
  --
  -- The cache array is indexed by (phase, iteration), with bounds taken
  -- from 'integPhaseBnds' and 'integIterationBnds' of the run specs. Two
  -- cursors, both starting at 0, hold the iteration and phase of the next
  -- node that has not been computed yet. A request walks the nodes from
  -- the cursor up to and including the requested node (phase varies
  -- fastest and wraps to the next iteration after the upper phase bound),
  -- stores each value, and finally reads the requested cell. Nodes behind
  -- the cursor are therefore not recomputed on later requests.
  --
  -- The resulting computation is passed through 'interpolateDynamics'.
  -- NOTE(review): the final lookup indexes the array with the requested
  -- phase and iteration as-is, so it assumes the requested point lies
  -- within the array bounds -- confirm against 'interpolateDynamics'.
  memoDynamics (Dynamics compute) =
    Simulation $ \run -> do
      let specs = runSpecs run
          (phaseLo, phaseHi) = integPhaseBnds specs
          (iterLo, iterHi) = integIterationBnds specs
      cache <- liftIO $ newIOArray_ ((phaseLo, iterLo), (phaseHi, iterHi))
      nextIter <- liftIO $ newIORef 0
      nextPhase <- liftIO $ newIORef 0
      let memoized point =
            do startIter <- liftIO $ readIORef nextIter
               startPhase <- liftIO $ readIORef nextPhase
               fill startIter startPhase
            where
              wantedIter = pointIteration point
              wantedPhase = pointPhase point
              -- The requested point relocated to node (i, k).
              nodeAt i k =
                point { pointIteration = i
                      , pointPhase = k
                      , pointTime = basicTime specs i k }
              fill i k
                -- Lexicographic test: the cursor is already past the
                -- requested (iteration, phase), so the value is cached.
                | (i, k) > (wantedIter, wantedPhase) =
                    liftIO $ readArray cache (wantedPhase, wantedIter)
                | otherwise =
                    do value <- compute (nodeAt i k)
                       -- Force the value before it is stored.
                       liftIO . writeArray cache (k, i) $! value
                       if k < phaseHi
                         then do liftIO $ writeIORef nextPhase (k + 1)
                                 fill i (k + 1)
                         else do liftIO $ writeIORef nextPhase 0
                                 liftIO $ writeIORef nextIter (i + 1)
                                 fill (i + 1) 0
      return $ interpolateDynamics $ Dynamics memoized

  -- Memoizes the computation at phase 0 of every iteration.
  --
  -- The cache array is indexed by iteration only, with bounds from
  -- 'integIterationBnds' of the run specs. A single cursor, starting at
  -- 0, holds the next iteration to compute. A request computes and stores
  -- every iteration from the cursor up to and including the requested
  -- one, then reads the requested cell.
  --
  -- The resulting computation is passed through 'discreteDynamics'.
  memo0Dynamics (Dynamics compute) =
    Simulation $ \run -> do
      cache <- liftIO $ newIOArray_ (integIterationBnds (runSpecs run))
      nextIter <- liftIO $ newIORef 0
      let memoized point =
            do startIter <- liftIO $ readIORef nextIter
               fill startIter
            where
              pointSc = pointSpecs point
              wantedIter = pointIteration point
              -- The requested point relocated to phase 0 of iteration i.
              nodeAt i =
                point { pointIteration = i
                      , pointPhase = 0
                      , pointTime = basicTime pointSc i 0 }
              fill i
                | i > wantedIter =
                    liftIO $ readArray cache wantedIter
                | otherwise =
                    do value <- compute (nodeAt i)
                       -- Force the value before it is stored.
                       liftIO . writeArray cache i $! value
                       liftIO $ writeIORef nextIter (i + 1)
                       fill (i + 1)
      return $ discreteDynamics $ Dynamics memoized

  -- Runs the computation once at phase 0 of every iteration, in order,
  -- without storing the results.
  --
  -- A single cursor, starting at 0, holds the next iteration to run. A
  -- request runs every iteration from the cursor up to and including the
  -- requested one, forcing each result before advancing the cursor, and
  -- does nothing when the cursor is already past the requested iteration.
  --
  -- The resulting computation is passed through 'discreteDynamics'.
  iterateDynamics (Dynamics step) =
    Simulation $ \_ -> do
      nextIter <- liftIO $ newIORef 0
      let advance point =
            do startIter <- liftIO $ readIORef nextIter
               go startIter
            where
              pointSc = pointSpecs point
              wantedIter = pointIteration point
              -- The requested point relocated to phase 0 of iteration i.
              nodeAt i =
                point { pointIteration = i
                      , pointPhase = 0
                      , pointTime = basicTime pointSc i 0 }
              go i
                | i > wantedIter = return ()
                | otherwise =
                    do result <- step (nodeAt i)
                       -- Force the step's result before moving the cursor.
                       result `seq` liftIO (writeIORef nextIter (i + 1))
                       go (i + 1)
      return $ discreteDynamics $ Dynamics advance