-- | Convolution of 'Pulse' sounds with user-supplied kernels whose extent is
-- given in samples ('convolveSamples'), as a 'Percentage' of the sound
-- ('convolvePercentage'), or as a 'Duration' ('convolveDuration').
module LambdaSound.Convolution
  ( Kernel (..)
  , convolveSamples
  , convolvePercentage
  , convolveDuration
  )
where

import LambdaSound.Sound

import Data.Massiv.Array qualified as M
import Data.Coerce (coerce)

-- | A kernel for convolution, parameterised by the unit @p@ in which its
-- extent is measured (sample count, 'Percentage' or 'Duration').
data Kernel p = Kernel
  { -- | Coefficient of the kernel at a position inside the kernel.
    coefficients :: p -> Float,
    -- | Extent of the kernel.
    size :: p,
    -- | Position inside the kernel that is aligned with the output sample
    -- being computed (measured from the kernel's first element).
    offset :: p
  }

-- | Slide a kernel over every sample of the sound.
--
-- @makeKernel@ receives the total number of samples of the sound and builds
-- the 'Kernel' (measured in samples) to apply. For the output sample at
-- index @j@ the result is
--
-- > sum [ c i * x (j + i - offset) | i <- [0 .. size - 1] ]
--
-- where @c i@ is the kernel coefficient at index @i@ and @x@ is the input.
-- The kernel is not reversed, so in signal-processing terms this is a
-- correlation. Reads past either end of the sound are served by reflecting
-- the sound at its borders ('M.Reflect').
--
-- A kernel with @size <= 1@ degenerates to the single coefficient @0.5@; the
-- coefficient function is not consulted in that case.
convolve :: (Int -> Kernel Int) -> Sound d Pulse -> Sound d Pulse
convolve makeKernel = modifyWholeSound $ \wholeSound ->
  let n :: Int
      n = M.unSz $ M.size wholeSound
      Kernel coeffAt kernelSize center = makeKernel n
      -- Sample the coefficient function once per sound, not once per
      -- output sample.
      computedCoefficients :: M.Vector M.S Pulse
      computedCoefficients =
        M.compute @M.S $
          if kernelSize <= 1
            -- NOTE(review): constant carried over unchanged; confirm why a
            -- one-sample kernel uses 0.5 rather than its coefficient.
            then M.singleton 0.5
            else M.generate M.Seq (M.Sz1 kernelSize) (coerce @_ @Pulse . coeffAt)
      -- Size the stencil from the coefficient vector so the declared stencil
      -- box always covers every index read below, even for kernelSize <= 0.
      stencil = M.makeStencil (M.size computedCoefficients) center $ \getV ->
        M.sum $ M.imap (\i c -> getV (i - center) * c) computedCoefficients
   in M.mapStencil M.Reflect stencil wholeSound

-- | Convolve a 'Sound' where the 'Kernel' size is
-- determined by 'Percentage's of the sound.
--
-- For a sound of @n@ samples the kernel spans @ceiling (size * n)@ samples
-- and its offset is @round (offset * n)@ samples. The coefficient function is
-- sampled at evenly spaced, kernel-relative positions: @0@ for the first
-- kernel sample up to @1@ for the last one.
convolvePercentage :: Kernel Percentage -> Sound d Pulse -> Sound d Pulse
convolvePercentage (Kernel coefficientsAt sizeP offsetP) = convolve $ \n ->
  let samples :: Percentage
      samples = fromIntegral n
      kernelSize :: Int
      kernelSize = ceiling (sizeP * samples)
   in Kernel
        { -- Only evaluated when kernelSize > 1 (see 'convolve'), so the
          -- divisor is never zero.
          coefficients = \i -> coefficientsAt (fromIntegral i / fromIntegral (kernelSize - 1)),
          size = kernelSize,
          offset = round (offsetP * samples)
        }

-- | Convolve a 'Sound' where the 'Kernel' size is
-- determined by a 'Duration'.
--
-- The kernel's size and offset are converted to fractions of the sound's
-- total duration and handed to 'convolvePercentage'. The coefficient function
-- receives the time elapsed since the start of the kernel, ranging from @0@
-- to the kernel's size.
convolveDuration :: Kernel Duration -> Sound T Pulse -> Sound T Pulse
convolveDuration (Kernel coefficients sizeD offsetD) sound@(TimedSound d _) =
  convolvePercentage
    ( Kernel
        -- convolvePercentage passes kernel-relative fractions in [0, 1];
        -- scale them by the kernel's length (not the sound's length) so the
        -- user's function is sampled over [0, sizeD].
        (coefficients . (* sizeD) . coerce)
        (coerce $ sizeD / d)
        (coerce $ offsetD / d)
    )
    sound

-- | Convolve a 'Sound' where the 'Kernel' size is
-- determined by the amount of samples. You have to keep in mind
-- that different sample rates will result in a different number of samples
-- for the same sound.
--
-- The coefficient function receives the sample index inside the kernel,
-- from @0@ to @size - 1@. A kernel with @size <= 1@ uses the single
-- coefficient @0.5@ instead of consulting the coefficient function.
convolveSamples :: Kernel Int -> Sound T Pulse -> Sound T Pulse
convolveSamples kernel = convolve (const kernel)