{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.EMD.Internal.Spline (
Spline, SplineEnd(..)
, makeSpline
, sampleSpline
) where
import Data.Finite
import Data.Maybe
import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.TypeLits.Compare
import GHC.TypeNats
import Numeric.EMD.Internal.Tridiagonal
import qualified Data.Binary as Bi
import qualified Data.Map as M
import qualified Data.Vector.Sized as SV
data SplineEnd a
= SENotAKnot
| SENatural
| SEClamped a a
deriving (Show, Eq, Ord, Generic)
instance Bi.Binary a => Bi.Binary (SplineEnd a)
data SplineCoef a = SC { _scAlpha :: !a
, _scBeta :: !a
, _scGamma0 :: !a
, _scGamma1 :: !a
, _scDelta :: !a
}
deriving Show
data Spline a = Spline { splineHead :: !(a, SplineCoef a)
, splineTail :: !(M.Map a (SplineCoef a))
}
runSplineCoef
:: Fractional a
=> a
-> SplineCoef a
-> a
-> a
runSplineCoef x0 (SC α β γ0 γ1 δ) x = q * γ0
+ t * γ1
+ t * q * (q * α + t * β)
where
t = (x - x0) / δ
q = 1 - t
sampleSpline
:: (Fractional a, Ord a)
=> Spline a
-> a
-> a
sampleSpline Spline{..} x = case x `M.lookupLE` splineTail of
Nothing ->
let (x0, sc) = splineHead
in runSplineCoef x0 sc x
Just (x0, sc) -> runSplineCoef x0 sc x
makeSpline
:: forall a. (Ord a, Fractional a)
=> SplineEnd a
-> M.Map a a
-> Maybe (Spline a)
makeSpline se ps = SV.withSizedList (M.toList ps) $ \(xsys :: SV.Vector n (a, a)) -> do
Refl <- Proxy @1 `isLE` Proxy @n
Refl <- Proxy @2 `isLE` Proxy @n
let xs, ys :: SV.Vector n a
(xs, ys) = SV.unzip xsys
dxs, dys :: SV.Vector (n - 1) a
dxs = SV.tail xs - SV.init xs
rdxs :: SV.Vector (n - 1) a
rdxs = recip dxs
rdxssq :: SV.Vector (n - 1) a
rdxssq = rdxs * rdxs
dys = SV.tail ys - SV.init ys
dydxssq = dys * rdxssq
mainDiag :: SV.Vector (n - 2) a
mainDiag = SV.zipWith (\rdx0 rdx1 -> 2 * ( rdx0 + rdx1 ))
(SV.init rdxs)
(SV.tail rdxs)
lowerDiag :: SV.Vector (n - 2) a
lowerDiag = SV.take rdxs
upperDiag :: SV.Vector (n - 2) a
upperDiag = SV.tail rdxs
rhs :: SV.Vector (n - 2) a
rhs = SV.zipWith (\dydxsq0 dydxsq1 -> 3 * (dydxsq0 + dydxsq1))
(SV.init dydxssq)
(SV.tail dydxssq)
EE{..} = case se of
SENotAKnot -> notAKnot rdxs rdxssq dydxssq
SENatural -> natural rdxs dydxssq
SEClamped c0 c1 -> clamped c0 c1
solution = fromMaybe (errorWithoutStackTrace "Numeric.EMD.Internal.Spline.makeSpline: Splining coefficient matrix is singular") $
solveTridiagonal ( lowerDiag `SV.snoc` eeLower1)
(eeMain0 `SV.cons` mainDiag `SV.snoc` eeMain1 )
(eeUpper0 `SV.cons` upperDiag )
(eeRhs0 `SV.cons` rhs `SV.snoc` eeRhs1 )
as :: SV.Vector (n - 1) a
as = SV.zipWith3 (\k dx dy -> k * dx - dy) (SV.init solution) dxs dys
bs :: SV.Vector (n - 1) a
bs = SV.zipWith3 (\k dx dy -> - k * dx + dy) (SV.tail solution) dxs dys
coefs :: SV.Vector (n - 1) (a, SplineCoef a)
coefs = SV.zipWith6 (\x α β γ0 γ1 δ -> (x, SC α β γ0 γ1 δ))
(SV.init xs) as bs (SV.init ys) (SV.tail ys) dxs
pure Spline
{ splineHead = SV.head coefs
, splineTail = M.fromAscList . SV.toList . SV.tail $ coefs
}
data EndEqn a = EE { eeMain0 :: !a
, eeUpper0 :: !a
, eeLower1 :: !a
, eeMain1 :: !a
, eeRhs0 :: !a
, eeRhs1 :: !a
}
natural
:: (KnownNat n, Num a)
=> SV.Vector (n + 1) a
-> SV.Vector (n + 1) a
-> EndEqn a
natural rdxs dydxssq = EE
{ eeMain0 = 2 * (rdxs `SV.index` minBound)
, eeUpper0 = rdxs `SV.index` minBound
, eeLower1 = rdxs `SV.index` maxBound
, eeMain1 = 2 * (rdxs `SV.index` maxBound)
, eeRhs0 = 3 * (dydxssq `SV.index` minBound)
, eeRhs1 = 3 * (dydxssq `SV.index` maxBound)
}
notAKnot
:: (KnownNat n, Num a)
=> SV.Vector (n + 1) a
-> SV.Vector (n + 1) a
-> SV.Vector (n + 1) a
-> EndEqn a
notAKnot rdxs rdxssq dydxssq = EE
{ eeMain0 = rdxssq `SV.index` minBound + rdx12Upper
, eeUpper0 = rdxssq `SV.index` minBound
+ rdxssq `SV.index` shift minBound
+ 2 * rdx12Upper
, eeLower1 = - (rdxssq `SV.index` weaken maxBound)
- (rdxssq `SV.index` maxBound)
- 2 * rdx12Lower
, eeMain1 = - rdxssq `SV.index` maxBound - rdx12Lower
, eeRhs0 = 2 * (dydxssq `SV.index` minBound) * (rdxs `SV.index` minBound)
+ 3 * (dydxssq `SV.index` minBound) * (rdxs `SV.index` shift minBound)
+ (dydxssq `SV.index` shift minBound) * (rdxs `SV.index` shift minBound)
, eeRhs1 = - (dydxssq `SV.index` weaken maxBound) * (rdxs `SV.index` weaken maxBound)
- 3 * (dydxssq `SV.index` maxBound) * (rdxs `SV.index` weaken maxBound)
- 2 * (dydxssq `SV.index` maxBound) * (rdxs `SV.index` maxBound)
}
where
rdx12Upper = rdxs `SV.index` minBound * rdxs `SV.index` shift minBound
rdx12Lower = rdxs `SV.index` maxBound * rdxs `SV.index` weaken maxBound
clamped
:: Num a
=> a
-> a
-> EndEqn a
clamped c0 c1 = EE
{ eeMain0 = 1
, eeUpper0 = 0
, eeLower1 = 0
, eeMain1 = 1
, eeRhs0 = c0
, eeRhs1 = c1
}