{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Convolution.Adhoc (
   Transform2,
   karatsuba,
   cyclic,
   complex,
   ) where

import Data.Array.Accelerate.Convolution.Private (Transform2, indexPad, )
import Data.Array.Accelerate.Fourier.Private (Transform, )

import qualified Data.Array.Accelerate.Utility.Sliced1 as Sliced1
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import Data.Array.Accelerate.Utility.Lift.Acc (acc)

import qualified Data.Array.Accelerate.Data.Complex as Complex
import Data.Array.Accelerate.Data.Complex (Complex((:+)), )

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
          (Exp, Acc, Array,
           Z(Z), (:.)((:.)), Any(Any), All(All), Slice, Shape, (!), )


{- |
Both arrays must have the same size.
-}
karatsuba ::
   (Shape sh, Slice sh, A.Num a) =>
   Transform2 (sh :. Int) a
karatsuba x y =
   flip A.slice (A.lift $ Any :. (0::Int) :. All)
   .
   A.afst
   .
   A.awhile
      (\arrs -> A.unit $ (Sliced.length $ A.asnd arrs) A.> 1)
      (Acc.modify (acc, acc) $
       \(z, lens) ->
          (karatsubaGo (lens ! A.index1 0) (2*(lens ! A.index1 1)-1) z,
           Sliced.tail lens))
   .
   (Acc.modify ((acc, acc), acc) $
    \((x0,y0), lens) -> (A.zipWith (*) x0 y0, lens))
   .
   A.awhile
      (\arrs -> A.unit $ (Sliced.length $ A.afst $ A.afst arrs) A.> 1)
      (Acc.modify ((acc, acc), acc) $
       \((x0,y0), lens) ->
          let (x1,y1) = karatsubaReorder (x0,y0)
          in  ((x1,y1), Sliced.consExp (Sliced.length x1) lens))
   $
   A.lift
      ((A.replicate (A.lift $ Any :. (1::Int) :. All) x,
        A.replicate (A.lift $ Any :. (1::Int) :. All) y),
       A.fill (A.constant $ Z:.1) (Sliced.length x))

karatsubaReorder ::
   (Shape sh, Slice sh, A.Num a) =>
   (Acc (Array (sh :. Int :. Int) a),
    Acc (Array (sh :. Int :. Int) a)) ->
   (Acc (Array (sh :. Int :. Int) a),
    Acc (Array (sh :. Int :. Int) a))
karatsubaReorder (x,y) =
   let len2 = - div (- Sliced.length x) 2
       xl = Sliced.take len2 x
       yl = Sliced.take len2 y
       xr = Sliced.pad 0 len2 $ Sliced.drop len2 x
       yr = Sliced.pad 0 len2 $ Sliced.drop len2 y
   in  (Sliced1.append3 xl (A.zipWith (+) xl xr) xr,
        Sliced1.append3 yl (A.zipWith (+) yl yr) yr)

karatsubaGo ::
   (Shape sh, Slice sh, A.Num a) =>
   Exp Int ->
   Exp Int ->
   Transform (sh :. Int :. Int) a
karatsubaGo xlen zlen zmerged =
   let (sh:.n:._m) = Exp.unlift (expr:.expr:.expr) $ A.shape zmerged
       n3 = div n 3
       zl = Sliced1.take n3 zmerged
       zm = Sliced1.drop n3 zmerged
       zr = Sliced1.drop (2*n3) zmerged
       zc = A.zipWith (-) zm $ A.zipWith (+) zl zr
   in  A.generate (A.lift $ sh :. n3 :. zlen) $
       Exp.modify (expr:.expr) $
       \(ix:.k) ->
          indexPad (ix:.k)        zl +
          indexPad (ix:.k-xlen)   zc +
          indexPad (ix:.k-xlen*2) zr


{- |
Turn an ordinary convolution into a cyclic convolution of the same length.
-}
cyclic ::
   (Shape sh, Slice sh, A.Num a) =>
   Transform2 (sh :. Int) a ->
   Transform2 (sh :. Int) a
cyclic conv x y =
   let z = conv x y
       len = Sliced.length x
   in  A.zipWith (+) z $ Sliced.pad 0 len $ Sliced.drop len z


{- |
Turn a real-valued convolution into a complex-valued convolution.
Can be removed when we get @instance IsNum (Complex a)@.
-}
complex, _complex ::
   (Shape sh, Slice sh, A.Num a) =>
   Transform2 (sh :. Int) a ->
   Transform2 (sh :. Int) (Complex a)
complex conv x y =
   let xr = A.map Complex.real x; xi = A.map Complex.imag x
       yr = A.map Complex.real y; yi = A.map Complex.imag y
       xm = A.zipWith (+) xr xi
       ym = A.zipWith (+) yr yi
       xryr = conv xr yr
       xiyi = conv xi yi
       xmym = conv xm ym
   in  A.zipWith
          (Exp.modify2 expr expr (:+))
          (A.zipWith (-) xryr xiyi)
          (A.zipWith (-) xmym $ A.zipWith (+) xryr xiyi)

_complex conv x y =
   let xr = A.map Complex.real x; xi = A.map Complex.imag x
       yr = A.map Complex.real y; yi = A.map Complex.imag y
   in  A.zipWith
          (Exp.modify2 expr expr (:+))
          (A.zipWith (-) (conv xr yr) (conv xi yi))
          (A.zipWith (+) (conv xr yi) (conv xi yr))