{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Synthesizer.LLVM.Filter.NonRecursive (
   convolve,
   convolvePacked,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Parameter as Param
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial

import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector as SV

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import LLVM.Extra.Class (undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Core (Value, valueOf, CodeGenFunction, IsSized, SizeOf, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:), )

import Foreign.ForeignPtr (touchForeignPtr, )
import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )

import Control.Arrow ((<<<), (&&&), )
import Control.Monad (liftM2, )

import qualified Algebra.IntegralDomain as Integral

import NumericPrelude.Numeric
import NumericPrelude.Base



{- |
Convolve the input stream with a fixed mask.

This is a brute-force implementation:
no Karatsuba, no Toom-Cook, no Fourier transform.
Every output sample costs one multiplication and one addition
per mask element (see 'scalarProduct').

The ring buffer holds as many samples as the mask has elements
and is initialised with zero.
Presumably this means that the first outputs treat
not-yet-received input as zero - confirm against 'RingBuffer.trackConst'.
-}
convolve ::
   (Storable a,
    Class.MakeValueTuple a, Class.ValueTuple a ~ al,
    Memory.C al, A.PseudoRing al) =>
   Param.T p (SV.Vector a) -> CausalP.T p al al
convolve mask =
   CausalP.zipWith scalarProduct (fmap toWord32 maskLength)
   <<<
   (RingBuffer.trackConst A.zero maskLength &&& provideMask mask)
   where
      maskLength = fmap SV.length mask
      toWord32 :: Int -> Word32
      toWord32 = fromIntegral

{- |
Vectorised variant of 'convolve'.
Each step consumes and produces a serial vector of @n@ samples.

The ring buffer stores whole vectors, so it holds
@divUp (length mask) vectorSize@ of them.
That is enough to cover the full mask for every lane.
-}
convolvePacked ::
   (LLVM.IsPrimitive a, Memory.FirstClass a, Memory.Stored a ~ am,
    LLVM.IsPrimitive am, IsSized am, SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize),
    Class.MakeValueTuple a, Class.ValueTuple a ~ al, Memory.Struct al ~ am,
    Storable a, Memory.C al,
    LLVM.IsArithmetic a) =>
   Param.T p (SV.Vector a) ->
   CausalP.T p (Serial.Value n a) (Serial.Value n a)
convolvePacked mask =
   Serial.withSize $ \vectorSize ->
      CausalP.zipWith scalarProductPacked (fmap toWord32 maskLength)
      <<<
      (RingBuffer.trackConst A.zero
          (fmap (`Integral.divUp` vectorSize) maskLength)
       &&&
       provideMask mask)
   where
      maskLength = fmap SV.length mask
      toWord32 :: Int -> Word32
      toWord32 = fromIntegral

{- |
A causal process that ignores its input and emits a pointer
into the storage of the mask vector.

The pointer is obtained via 'SVU.unsafeToPointers' from
@Param.get mask p@.
The 'ForeignPtr' returned alongside it is passed to 'touchForeignPtr'
(the last 'CausalP.Cons' argument), so that the vector is kept alive
while the raw pointer is used.

NOTE(review): the roles of the six positional 'CausalP.Cons' arguments
(step, start, ..., create, delete) are inferred from their shapes here.
Confirm them against "Synthesizer.LLVM.CausalParameterized.ProcessPrivate".
-}
provideMask ::
   (Storable a,
    Class.MakeValueTuple a, Class.ValueTuple a ~ al,
    Memory.C al, Memory.Struct al ~ am) =>
   Param.T p (SV.Vector a) -> CausalP.T p x (Value (Ptr am))
provideMask mask =
   CausalP.Cons
      -- presumably the step function: the input @_x@ is ignored
      -- and the first argument (the pointer) is emitted unchanged
      (\p () _x () -> return (p,()))
      (return ())
      return
      (const $ const $ return ())
      -- extract the raw pointer; the length @_l@ is not needed here
      -- because the loop count is passed separately to the scalar product
      (\p ->
         let (fp,ptr,_l) = SVU.unsafeToPointers $ Param.get mask p
         in  return (fp, (ptr, ())))
      -- keep the foreign ptr alive
      touchForeignPtr


{- |
@scalarProduct n rb mask@ sums @index k rb * mask[k]@ for @k = 0 .. n-1@.

The mask elements are read consecutively from @mask@ by 'C.arrayLoop'.
The ring buffer is read at the matching offset @k@.
The accumulator starts at zero.
-}
scalarProduct ::
   (Memory.C a, Memory.Struct a ~ am,
    A.PseudoRing a) =>
   Value Word32 ->
   RingBuffer.T a -> Value (Ptr am) ->
   CodeGenFunction r a
scalarProduct n rb mask = do
   (_, total) <-
      C.arrayLoop n mask (A.zero, A.zero) $ \maskPtr (offset, acc) -> do
         sample <- RingBuffer.index offset rb
         weight <- Memory.load maskPtr
         nextOffset <- A.inc offset
         term <- A.mul sample weight
         newAcc <- A.add acc term
         return (nextOffset, newAcc)
   return total

{- |
'scalarProduct' restricted to plain LLVM values.
It is not referenced in this module; the leading underscore keeps it
free of unused-binding warnings.
Presumably it exists only to check that the generic 'scalarProduct'
still type-checks for this case - confirm before removing.
-}
_scalarProduct ::
   (Memory.FirstClass a, Memory.Stored a ~ am, IsSized am,
    LLVM.IsArithmetic a) =>
   Value Word32 ->
   RingBuffer.T (Value a) -> Value (Ptr am) ->
   CodeGenFunction r (Value a)
_scalarProduct n rb mask = scalarProduct n rb mask


{- |
Serial-vector counterpart of 'scalarProduct'.

@n0@ is the number of mask elements and @mask0@ points to the first one.
Each scalar mask element is multiplied (via 'A.scale') with a vector of
input samples. The first such vector comes from 'readSerialStart'; each
following one is produced by 'readSerialNext', which advances the vector
by one sample per mask element.

An empty mask (@n0 == 0@) yields zero, just like 'scalarProduct'.
Without this guard @A.dec n0@ would wrap around to @maxBound :: Word32@.
The loop would then read far past the end of the mask, and the first
mask element and ring-buffer vector would be loaded even though
they do not exist.
-}
scalarProductPacked ::
   (LLVM.IsPrimitive a, Memory.FirstClass a, Memory.Stored a ~ am,
    LLVM.IsPrimitive am, IsSized am, SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize),
    LLVM.IsArithmetic a) =>
   Value Word32 ->
   RingBuffer.T (Serial.Value n a) -> Value (Ptr am) ->
   CodeGenFunction r (Serial.Value n a)
scalarProductPacked n0 rb mask0 = do
   nonEmpty <- A.cmp LLVM.CmpNE n0 A.zero
   C.ifThen nonEmpty A.zero $ do
      -- the first mask element initialises the accumulator,
      -- so the remaining loop runs over n0-1 elements
      (ax, rx) <- readSerialStart rb
      bx <- Memory.load mask0
      sx <- A.scale bx ax
      n1 <- A.dec n0
      mask1 <- A.advanceArrayElementPtr mask0
      fmap snd $ C.arrayLoop n1 mask1 (rx, sx) $
            \ptr (r1, s1) -> do
         (a,r2) <- readSerialNext rb r1
         b <- Memory.load ptr
         fmap ((,) r2) (A.add s1 =<< A.scale b a)


{- |
State of the reverse-order serial reader used by 'readSerialStart'
and 'readSerialNext'. It consists of a triple:

* the vector most recently returned to the caller,
* the ring-buffer vector whose elements are currently being fed
  into that output vector, one element per step,
* the number of elements still left in that ring-buffer vector
  (a new vector is fetched when this reaches zero),

paired with the ring-buffer index of the vector currently being consumed.
-}
type
   Iterator n a =
      ((Serial.Value n a,
        {-
        I would like to use Serial.Iterator,
        but we need to read in reversed order,
        that is, from high to low indices.
        -}
        Serial.Value n a,
        Value Word32),
       Value Word32)

{- |
Start reading the ring buffer of serial vectors.

Returns the most recent vector (ring-buffer index 0) together with an
'Iterator' positioned on it. The iterator's "pending" vector is undefined
and its remaining-element count is zero. As a result, the first call to
'readSerialNext' immediately fetches ring-buffer vector 1.
-}
readSerialStart ::
   (LLVM.IsPrimitive a, Memory.FirstClass a, Memory.Stored a ~ am,
    LLVM.IsPrimitive am, IsSized am, SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize)) =>
   RingBuffer.T (Serial.Value n a) ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb =
   fmap (\newest -> (newest, ((newest, undefTuple, A.zero), A.zero))) $
   RingBuffer.index A.zero rb

{- |
Advance the reverse-order reader by one sample.

If the pending ring-buffer vector is exhausted, the next one
(ring-buffer index + 1) is fetched and its element count is reset to
the vector size. One element is then shifted out of the pending vector
with 'Serial.shiftUp' and shifted into the output vector, again with
'Serial.shiftUp'. The updated output vector is returned along with the
new iterator state.
-}
readSerialNext ::
   (LLVM.IsPrimitive a, Memory.FirstClass a, Memory.Stored a ~ am,
    LLVM.IsPrimitive am, IsSized am, SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize)) =>
   RingBuffer.T (Serial.Value n a) ->
   Iterator n a ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((output, pending, remaining), bufIdx) = do
   exhausted <- A.cmp LLVM.CmpEQ remaining A.zero
   ((source, count), currentIdx) <-
      C.ifThen exhausted ((pending, remaining), bufIdx) $ do
         nextIdx <- A.inc bufIdx
         fresh <- RingBuffer.index nextIdx rb
         let freshSize = valueOf (fromIntegral (Serial.size fresh) :: Word32)
         return ((fresh, freshSize), nextIdx)
   countLeft <- A.dec count
   (element, sourceRest) <- Serial.shiftUp undefTuple source
   (_, newOutput) <- Serial.shiftUp element output
   return (newOutput, ((newOutput, sourceRest, countLeft), currentIdx))