{-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE FlexibleContexts #-} module Synthesizer.LLVM.Filter.NonRecursive ( convolve, convolvePacked, ) where import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP import qualified Synthesizer.LLVM.CausalParameterized.RingBuffer as RingBuffer import qualified Synthesizer.LLVM.Frame.SerialVector as Serial import qualified Synthesizer.LLVM.Storable.Vector as SVU import qualified Data.StorableVector as SV import qualified LLVM.DSL.Parameter as Param import qualified LLVM.Extra.Storable as Storable import qualified LLVM.Extra.Memory as Memory import qualified LLVM.Extra.Control as C import qualified LLVM.Extra.Arithmetic as A import qualified LLVM.Extra.Tuple as Tuple import qualified LLVM.Core as LLVM import LLVM.Core (Value, valueOf, CodeGenFunction, IsSized, SizeOf) import qualified Type.Data.Num.Decimal as TypeNum import Type.Data.Num.Decimal.Number ((:*:)) import Foreign.ForeignPtr (touchForeignPtr) import Foreign.Ptr (Ptr) import Data.Word (Word) import Control.Arrow ((<<<), (&&&)) import Control.Monad (liftM2) import qualified Algebra.IntegralDomain as Integral import NumericPrelude.Numeric import NumericPrelude.Base {- This is a brute-force implementation. No Karatsuba, No Toom-Cook, No Fourier. 
-}

-- | Convolve a stream of scalar frames with the constant mask @mask@.
--
-- Recent inputs are tracked in a ring buffer created by
-- 'RingBuffer.trackConst' (initial value 'A.zero', size parameter
-- @length mask@).  Every output frame is the product of that buffer with
-- the mask, as computed by 'scalarProduct'.
convolve ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Param.T p (SV.Vector a) -> CausalP.T p al al
convolve mask =
   let len = fmap SV.length mask
   in  CausalP.zipWith scalarProduct (fmap (fromIntegral :: Int -> Word) len)
       <<<
       RingBuffer.trackConst A.zero len &&& provideMask mask

-- | Like 'convolve', but for streams of serial vectors of @n@ samples each.
--
-- The ring buffer stores whole vectors.  Its size parameter is
-- @divUp (length mask) vectorSize@, i.e. the number of vectors needed to
-- cover the mask.  The per-vector product is computed by
-- 'scalarProductPacked'.
convolvePacked ::
   (TypeNum.Positive n, TypeNum.Positive (n :*: asize),
    Storable.C a, Tuple.ValueOf a ~ Value al,
    LLVM.IsArithmetic al, LLVM.IsPrimitive al,
    IsSized al, SizeOf al ~ asize) =>
   Param.T p (SV.Vector a) ->
   CausalP.T p (Serial.Value n al) (Serial.Value n al)
convolvePacked mask =
   Serial.withSize $ \vectorSize ->
   let len = fmap SV.length mask
   in  CausalP.zipWith scalarProductPacked
          (fmap (fromIntegral :: Int -> Word) len)
       <<<
       RingBuffer.trackConst A.zero
          (fmap (flip Integral.divUp vectorSize) len)
       &&&
       provideMask mask

-- | Output a pointer to the elements of @mask@ for every frame; the input
-- frame itself is ignored.
--
-- The pointer is obtained with 'SVU.unsafeToPointers'.  The accompanying
-- foreign pointer is passed to 'touchForeignPtr' (the last 'CausalP.Cons'
-- field) in order to keep the mask memory alive.
-- NOTE(review): presumably that field runs when the process is torn down;
-- confirm against the definition of 'CausalP.Cons'.
provideMask ::
   (Storable.C a) =>
   Param.T p (SV.Vector a) -> CausalP.T p x (Value (Ptr a))
provideMask mask =
   CausalP.Cons
      (\p () _x () -> return (p, ()))
      (return ())
      return
      (const $ const $ return ())
      (\p ->
         let (fp, ptr, _l) = SVU.unsafeToPointers $ Param.get mask p
         in  return (fp, (ptr, ())))
      -- keep the foreign ptr alive
      touchForeignPtr

-- | Sum of @mask[k] * rb[k]@ for @k@ from 0 to @n-1@.
--
-- The loop state is the ring-buffer index @k@ and the running sum @s@,
-- both starting at 'A.zero'.  Only the sum is returned.
-- NOTE(review): for @n == 0@ this yields 'A.zero', assuming
-- 'Storable.arrayLoop' performs no iteration for a zero length.
scalarProduct ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Value Word -> RingBuffer.T al -> Value (Ptr a) ->
   CodeGenFunction r al
scalarProduct n rb mask =
   fmap snd $
   Storable.arrayLoop n mask (A.zero, A.zero) $ \ptr (k, s) -> do
      a <- RingBuffer.index k rb
      b <- Storable.load ptr
      liftM2 (,) (A.inc k) (A.add s =<< A.mul a b)

-- | 'scalarProduct' specialised to plain LLVM values.
--
-- It is not exported.  The leading underscore silences the unused-binding
-- warning, so it presumably only serves as a type-check witness.
_scalarProduct ::
   (Storable.C a, IsSized a, Tuple.ValueOf a ~ Value a, LLVM.IsArithmetic a) =>
   Value Word -> RingBuffer.T (Value a) -> Value (Ptr a) ->
   CodeGenFunction r (Value a)
_scalarProduct = scalarProduct

-- | Packed counterpart of 'scalarProduct'.
--
-- It accumulates @mask[k]@ times the input window shifted by @k@ samples.
-- The shifted windows are produced by 'readSerialStart' and
-- 'readSerialNext'.
--
-- The first mask element is processed before the loop, so a mask length
-- of zero has to be caught explicitly.  Otherwise @mask0@ would be loaded
-- out of bounds, and @A.dec n0@ would wrap around to the maximal 'Word',
-- letting the loop run (almost) forever.  For an empty mask the result is
-- 'A.zero', in agreement with 'scalarProduct'.
scalarProductPacked ::
   (Storable.C a, Tuple.ValueOf a ~ Value al,
    LLVM.IsArithmetic al, LLVM.IsPrimitive al,
    IsSized al, SizeOf al ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   Value Word -> RingBuffer.T (Serial.Value n al) -> Value (Ptr a) ->
   CodeGenFunction r (Serial.Value n al)
scalarProductPacked n0 rb mask0 = do
   nonEmpty <- A.cmp LLVM.CmpNE n0 A.zero
   C.ifThen nonEmpty A.zero $ do
      -- first mask element: uses the unshifted newest vector
      (ax, rx) <- readSerialStart rb
      bx <- Storable.load mask0
      sx <- A.scale bx ax
      -- remaining n0-1 elements; n0 >= 1 here, so this cannot wrap
      n1 <- A.dec n0
      mask1 <- Storable.incrementPtr mask0
      fmap snd $
         Storable.arrayLoop n1 mask1 (rx, sx) $ \ptr (r1, s1) -> do
            (a, r2) <- readSerialNext rb r1
            b <- Storable.load ptr
            fmap ((,) r2) (A.add s1 =<< A.scale b a)

-- | Read state used by 'readSerialNext':
-- @((window, pending, remaining), vectorIndex)@.
--
-- * @window@: the shifted input window returned by the last read
-- * @pending@: the vector loaded from @vectorIndex@, whose not yet
--   consumed elements are pulled out with 'Serial.shiftUp'
-- * @remaining@: how many elements of @pending@ are still unconsumed
-- * @vectorIndex@: ring-buffer index of the vector in @pending@
--
-- I would like to use Serial.Iterator, but we need to read in reversed
-- order, that is, from high to low indices.
type Iterator n a =
   ((Serial.Value n a, Serial.Value n a, Value Word), Value Word)

-- | Load the vector at ring-buffer index 0 and return it as the initial
-- window.
--
-- The iterator starts with an undefined @pending@ vector and
-- @remaining = 0@, so the first 'readSerialNext' loads ring-buffer
-- index 1.
readSerialStart ::
   (LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   RingBuffer.T (Serial.Value n a) ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb = do
   a <- RingBuffer.index A.zero rb
   return (a, ((a, Tuple.undef, A.zero), A.zero))

-- | Shift the window by one more sample.
--
-- The next element is taken from @pending@.  When @pending@ is exhausted
-- (@remaining == 0@), the vector at the next ring-buffer index is loaded
-- first, and @remaining@ is reset to the vector size.
readSerialNext ::
   (LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   RingBuffer.T (Serial.Value n a) ->
   Iterator n a ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((a0, r0, j0), k0) = do
   vectorEnd <- A.cmp LLVM.CmpEQ j0 A.zero
   ((r1, j1), k1) <-
      C.ifThen vectorEnd ((r0, j0), k0) $ do
         k <- A.inc k0
         r <- RingBuffer.index k rb
         return ((r, valueOf (fromIntegral $ Serial.size r :: Word)), k)
   j2 <- A.dec j1
   -- pull the next element out of the pending vector ...
   (ai, r2) <- Serial.shiftUp Tuple.undef r1
   -- ... and push it into the window, dropping the element shifted out
   (_, a1) <- Serial.shiftUp ai a0
   return (a1, ((a1, r2, j2), k1))