-- | This module implements IIR filters.
--
-- See: http://shepazu.github.io/Audio-EQ-Cookbook/audio-eq-cookbook.html
module LambdaSound.Filter
  ( -- * Usage
    IIRParams (..),
    applyIIRFilter,

    -- * Design
    lowPassFilter,
    highPassFilter,
    bandPassFilter,
  )
where

import Control.Monad (forM_)
import Data.Coerce (coerce)
import Data.Massiv.Array qualified as M
import Data.Massiv.Array.Unsafe qualified as MU
import Data.Maybe (fromMaybe)
import LambdaSound.Sound
import DSP.Filter.Analog.Prototype

-- | Filter coefficients of an IIR filter.
--
-- 'feedforward' holds the coefficients @b0, b1, ...@ that weight the current
-- and previous input samples. 'feedback' holds @a0, a1, ...@: @a0@ divides
-- every output sample, and @a1, a2, ...@ weight the previous output samples
-- (see 'applyIIRFilter').
data IIRParams = IIRParams
  { -- | Input coefficients; index @i@ weights the input sample @i@ steps back.
    feedforward :: !(M.Vector M.S Float),
    -- | Output coefficients; index @0@ is the normalization coefficient and
    -- index @j > 0@ weights the output sample @j@ steps back.
    feedback :: !(M.Vector M.S Float)
  }
  deriving (Show)

-- | A second-order low-pass filter using the Audio EQ Cookbook formulas.
--
-- The first argument is the cutoff frequency and the second is the quality
-- factor @Q@ (resonance). A @Q@ of exactly @0@ is mapped by 'calcAQ' to an
-- @alpha@ of @0@ instead of dividing by zero.
lowPassFilter :: Hz -> Float -> SamplingInfo -> IIRParams
lowPassFilter freq q si =
  IIRParams
    (M.fromList M.Seq [b0, 1 - cos w0, b0])
    (M.fromList M.Seq [1 + a, -2 * cos w0, 1 - a])
  where
    -- b0 == b2 == (1 - cos w0) / 2 and b1 == 1 - cos w0.
    b0 :: Float
    b0 = (1 - cos w0) / 2
    w0 :: Float
    w0 = calcW0 si.sampleRate scaledFreq
    a :: Float
    a = calcAQ w0 q
    -- NOTE(review): the cutoff is divided by (sampleRate * period); this
    -- presumably compensates for time-scaled sounds — confirm against the
    -- definition of 'SamplingInfo'.
    scaledFreq :: Hz
    scaledFreq = freq / (si.sampleRate * coerce (si.period))

-- | A second-order high-pass filter using the Audio EQ Cookbook formulas.
--
-- The first argument is the cutoff frequency and the second is the quality
-- factor @Q@ (resonance). A @Q@ of exactly @0@ is mapped by 'calcAQ' to an
-- @alpha@ of @0@ instead of dividing by zero.
highPassFilter :: Hz -> Float -> SamplingInfo -> IIRParams
highPassFilter freq q si =
  IIRParams
    (M.fromList M.Seq [b0, -(1 + cos w0), b0])
    (M.fromList M.Seq [1 + a, -2 * cos w0, 1 - a])
  where
    -- b0 == b2 == (1 + cos w0) / 2 and b1 == -(1 + cos w0).
    b0 :: Float
    b0 = (1 + cos w0) / 2
    w0 :: Float
    w0 = calcW0 si.sampleRate scaledFreq
    a :: Float
    a = calcAQ w0 q
    -- NOTE(review): the cutoff is divided by (sampleRate * period); this
    -- presumably compensates for time-scaled sounds — confirm against the
    -- definition of 'SamplingInfo'.
    scaledFreq :: Hz
    scaledFreq = freq / (si.sampleRate * coerce (si.period))

-- | A second-order band-pass filter using the Audio EQ Cookbook formulas
-- (the constant 0 dB peak gain variant, i.e. @b0 = alpha@).
--
-- The first argument is the centre frequency and the second is the quality
-- factor @Q@, which controls the bandwidth. A @Q@ of exactly @0@ is mapped
-- by 'calcAQ' to an @alpha@ of @0@, which zeroes all feedforward
-- coefficients.
bandPassFilter :: Hz -> Float -> SamplingInfo -> IIRParams
bandPassFilter freq q si =
  IIRParams
    (M.fromList M.Seq [a, 0, negate a])
    (M.fromList M.Seq [1 + a, -2 * cos w0, 1 - a])
  where
    w0 :: Float
    w0 = calcW0 si.sampleRate scaledFreq
    a :: Float
    a = calcAQ w0 q
    -- NOTE(review): the centre frequency is divided by
    -- (sampleRate * period); this presumably compensates for time-scaled
    -- sounds — confirm against the definition of 'SamplingInfo'.
    scaledFreq :: Hz
    scaledFreq = freq / (si.sampleRate * coerce (si.period))

-- | Normalized angular frequency @w0 = 2 * pi * freq / sampleRate@
-- (radians per sample), as used by the Audio EQ Cookbook formulas.
--
-- Note the argument order: the sample rate comes first, then the frequency.
calcW0 :: Hz -> Hz -> Float
calcW0 sampleRate freq = coerce $ 2 * pi * freq / sampleRate

-- | The cookbook's @alpha@ term for a quality factor @Q@:
-- @alpha = sin w0 / (2 * Q)@.
--
-- A @Q@ of exactly @0@ returns @0@ rather than dividing by zero.
-- NOTE(review): mathematically @alpha@ grows without bound as @Q -> 0@, so
-- this guard yields a filter whose poles sit on the unit circle; consider
-- rejecting non-positive @Q@ instead.
calcAQ :: Float -> Float -> Float
calcAQ _ 0 = 0
calcAQ w0 q = sin w0 / (2 * q)

-- | Applies the IIR filter defined by the 'IIRParams' to the sound.
--
-- The parameters are built from the 'SamplingInfo' the sound is sampled
-- with. Each output sample is computed with the difference equation
--
-- @y[n] = (sum_i b_i * x[n - i] - sum_(j>=1) a_j * y[n - j]) / a_0@
--
-- where @b@ is 'feedforward' and @a@ is 'feedback'. Input and output samples
-- before the start of the sound count as @0@. If 'feedback' is empty, @a_0@
-- defaults to @1@ and there are no recursive terms.
applyIIRFilter :: (SamplingInfo -> IIRParams) -> Sound d Pulse -> Sound d Pulse
applyIIRFilter makeParams sound =
  adoptDuration sound $
    withSamplingInfo $ \si ->
      applyFilter (makeParams si) sound
  where
    applyFilter :: IIRParams -> Sound d Pulse -> Sound d Pulse
    applyFilter (IIRParams feedforward feedback') =
      let -- a0 normalizes every output sample; the remaining coefficients
          -- a1, a2, ... weight the previous outputs.
          (currentCoefficient, feedback) =
            (coerce $ M.defaultIndex 1 feedback' 0, M.tail feedback')
       in modifyWholeSoundST $ \source dest -> do
            forM_ [0 .. pred (M.unSz $ M.sizeOfMArray dest)] $ \index -> do
              -- b_i * x[index - i]; out-of-range source indices read as 0.
              let sourceValues =
                    M.imap
                      (\i v -> coerce v * M.defaultIndex 0 source (index - i))
                      feedforward
              -- a_(i+1) * y[index - i - 1]. dest is written in ascending
              -- order, so these reads only see outputs already computed;
              -- reads before index 0 give Nothing and count as 0.
              recursiveValues <-
                M.itraversePrim @M.S
                  ( \i v ->
                      (coerce v *) . fromMaybe 0
                        <$> M.read dest (index - succ i)
                  )
                  feedback
              let currentValue =
                    (M.sum sourceValues - M.sum recursiveValues)
                      / currentCoefficient
              MU.unsafeWrite dest index currentValue