{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Numeric.EMD (
emd
, emdTrace
, emd'
, iemd
, EMD(..)
, EMDOpts(..), defaultEO
, BoundaryHandler(..)
, Sifter
, defaultSifter
, SplineEnd(..)
, sift, SiftResult(..)
, envelopes
) where
import Control.DeepSeq
import Control.Monad.IO.Class
import Data.Default.Class
import Data.Functor.Identity
import Data.List
import GHC.Generics (Generic)
import GHC.TypeNats
import Numeric.EMD.Internal
import Numeric.EMD.Internal.Spline
import Numeric.EMD.Sift
import Text.Printf
import qualified Data.Binary as Bi
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
-- | Baseline decomposition options.
--
-- Uses 'defaultSifter' to decide when an IMF is done, natural spline
-- end conditions ('SENatural'), and symmetric boundary handling
-- (@'Just' 'BHSymmetric'@).
defaultEO :: (VG.Vector v a, Fractional a, Ord a) => EMDOpts v n a
defaultEO = EO
    { eoBoundaryHandler = Just BHSymmetric
    , eoSplineEnd       = SENatural
    , eoSifter          = defaultSifter
    }

-- | 'def' is exactly 'defaultEO'.
instance (VG.Vector v a, Fractional a, Ord a) => Default (EMDOpts v n a) where
    def = defaultEO
-- | The outcome of an empirical mode decomposition of a length-@n@ signal.
--
-- Both fields are strict.
data EMD v n a = EMD
    { emdIMFs     :: ![SVG.Vector v n a]
      -- ^ IMFs, listed in the order in which the sifting loop found them
    , emdResidual :: !(SVG.Vector v n a)
      -- ^ what the final sift reported as the residual
    }
  deriving (Show, Generic, Eq, Ord)

-- | Full evaluation via the 'Generic'-based default 'rnf'.
instance NFData (v a) => NFData (EMD v n a)
-- | Serialises the IMFs and then the residual, each as an unsized vector.
--
-- When decoding, every vector is re-checked against the type-level length
-- @n@ (via 'SVG.toSized'). On a mismatch, decoding fails with a message
-- saying which component was wrong. The old refutable-pattern version
-- reported only an opaque pattern-match failure.
instance (VG.Vector v a, KnownNat n, Bi.Binary (v a)) => Bi.Binary (EMD v n a) where
    put EMD{..} = do
      Bi.put (SVG.fromSized <$> emdIMFs)
      Bi.put (SVG.fromSized emdResidual)
    get = do
      rawIMFs <- Bi.get
      -- Every IMF must have exactly n elements; otherwise reject the input.
      emdIMFs <- maybe (fail "Numeric.EMD: decoded IMF has the wrong length") pure $
                   traverse SVG.toSized rawIMFs
      rawResidual <- Bi.get
      emdResidual <- maybe (fail "Numeric.EMD: decoded residual has the wrong length") pure $
                       SVG.toSized rawResidual
      pure EMD{..}
-- | Decomposes a non-empty signal into IMFs and a residual, without
-- tracing.
--
-- This is 'emd'' run in 'Identity' with a callback that does nothing.
emd :: (VG.Vector v a, KnownNat n, Floating a, Ord a)
    => EMDOpts v (n + 1) a
    -> SVG.Vector v (n + 1) a
    -> EMD v (n + 1) a
emd eo signal = runIdentity (emd' ignore eo signal)
  where
    -- Callback that ignores each sift result.
    ignore _ = Identity ()
-- | Like 'emd', but writes a line to stdout for every sift result.
--
-- A residual produces @"Residual found."@. An IMF produces
-- @"IMF found (k sifts)"@, where @k@ is the count carried by 'SRIMF'.
emdTrace
    :: (VG.Vector v a, KnownNat n, Floating a, Ord a, MonadIO m)
    => EMDOpts v (n + 1) a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emdTrace = emd' report
  where
    report (SRResidual _) = liftIO (putStrLn "Residual found.")
    report (SRIMF _ i)    = liftIO (printf "IMF found (%d sifts)\n" i)
-- | Core decomposition loop, with a callback on every sift step.
--
-- It calls 'sift' repeatedly on a working signal:
--
-- * An 'SRIMF' result is recorded as the next IMF and subtracted from the
--   working signal. The loop then continues.
-- * The first 'SRResidual' result stops the loop and becomes 'emdResidual'.
--
-- The callback runs on every sift result before that result is acted on.
-- Its return value is discarded. The loop ends only when 'sift' returns
-- 'SRResidual'.
emd'
    :: (VG.Vector v a, KnownNat n, Floating a, Ord a, Applicative m)
    => (SiftResult v (n + 1) a -> m r)
    -> EMDOpts v (n + 1) a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emd' cb eo = loop id
  where
    -- 'found' is a difference list of IMFs so far, kept in discovery
    -- order; 'current' is the signal that remains to be decomposed.
    loop !found !current =
      let step = sift eo current
          next = \case
            SRResidual r   -> pure (EMD (found []) r)
            SRIMF imf _    -> loop (found . (imf :)) (current - imf)
      in  cb step *> next step
-- | Recombines a decomposition by adding the residual and every IMF
-- together elementwise.
--
-- A strict left fold starts from the residual and adds each IMF in list
-- order.
iemd
    :: (VG.Vector v a, Num a)
    => EMD v n a
    -> SVG.Vector v n a
iemd EMD{ emdIMFs = imfs, emdResidual = residual } =
    foldl' plus residual imfs
  where
    plus acc imf = SVG.zipWith (+) acc imf