{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Numeric.EMD.Internal.Tridiagonal (
solveTridiagonal
) where
import Control.Applicative.Backwards
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.Finite
import Data.Foldable
import GHC.TypeNats
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable.Sized as SMVG
import qualified Data.Vector.Generic.Sized as SVG
solveTridiagonal
:: forall v n a. (VG.Vector v a, KnownNat n, Fractional a, Eq a)
=> SVG.Vector v (n + 1) a
-> SVG.Vector v (n + 2) a
-> SVG.Vector v (n + 1) a
-> SVG.Vector v (n + 2) a
-> Maybe (SVG.Vector v (n + 2) a)
solveTridiagonal as bs cs ds = runST $ runMaybeT $ do
guard $ SVG.head bs /= 0
cs' <- makeCs
mxs <- lift $ SVG.thaw ds
makeDs cs' mxs
forwards . for_ (consecFinites @(n + 1)) $ \(i0, i1) -> Backwards $ do
x1 <- lift $ SMVG.read mxs i1
let sbr = cs' `SVG.index` i0 * x1
lift $ SMVG.modify mxs (subtract sbr) (weaken i0)
lift $ SVG.freeze mxs
where
makeCs :: MaybeT (ST s) (SVG.Vector v (n + 1) a)
makeCs = do
mcs <- lift $ SVG.thaw cs
lift $ SMVG.modify mcs (/ SVG.head bs) minBound
for_ (consecFinites @n) $ \(i0, i1) -> do
c0 <- lift $ SMVG.read mcs (weaken i0)
let dvr = bs `SVG.index` weaken i1
- as `SVG.index` weaken i0 * c0
guard $ dvr /= 0
lift $ SMVG.modify mcs (/ dvr) i1
lift $ SVG.freeze mcs
makeDs
:: SVG.Vector v (n + 1) a
-> SVG.MVector (VG.Mutable v) (n + 2) s a
-> MaybeT (ST s) ()
makeDs cs' mds = do
lift $ SMVG.modify mds (/ SVG.head bs) minBound
for_ (consecFinites @(n + 1)) $ \(i0, i1) -> do
let c0 = cs' `SVG.index` i0
d0 <- lift $ SMVG.read mds (weaken i0)
let sbr = as `SVG.index` i0 * d0
dvr = bs `SVG.index` i1
- as `SVG.index` i0 * c0
guard $ dvr /= 0
lift $ SMVG.modify mds ((/ dvr) . subtract sbr) i1
consecFinites :: KnownNat n => [(Finite n, Finite (n + 1))]
consecFinites = zip finites (tail finites)