{-# LANGUAGE NoMonomorphismRestriction #-}

module Tutorial where

import qualified Prelude as P

import Feldspar
import Feldspar.Compiler
import Feldspar.Vector

-- This file contains the examples of vector programming
-- covered in the draft Feldspar Tutorial

-- | Feldspar unsigned integer: a 'Data' value of the default word type.
type UInt = Data DefaultWord  -- Feldspar unsigned int

-- | Vector of 'DefaultInt' values; the element type of most examples below.
type VInt = DVector DefaultInt


-- some functions for manipulating bits of indices
-- Remove those that are not used when the tutorial is finished

-- | @complN k i@ flips the @k@ least significant bits of @i@.
complN :: Data Index -> Data Index -> Data Index
complN k i = i `xor` oneBitsN k

-- | @oneBitsN k@ has its @k@ lowest bits set and every other bit clear.
oneBitsN :: Data Index -> Data Index
oneBitsN k = complement (zeroBitsN k)

-- | @zeroBitsN k@ has its @k@ lowest bits clear and every other bit set.
zeroBitsN :: Data Index -> Data Index
zeroBitsN k = allOnes << k

-- | The index value with every bit set.
allOnes :: Data Index
allOnes = complement 0

-- | @complTo k i@ flips bits 0 through @k@ (inclusive) of @i@.
complTo :: Data Index -> Data Index -> Data Index
complTo k i = i `xor` oneBitsTo k

-- | @flipBit k i@ toggles bit @k@ of @i@.
flipBit :: Data Index -> Data Index -> Data Index
flipBit k i = i `xor` (1 << k)

-- | True when bit @k@ of @i@ is clear.
bitZero :: Data Index -> Data Index -> Data Bool
bitZero k i = masked == 0
  where
    masked = i .&. (1 << k)

-- | True when bit @k@ of @i@ is set.
bitOne :: Data Index -> Data Index -> Data Bool
bitOne k i = masked /= 0
  where
    masked = i .&. (1 << k)

-- | True when the least significant bit of @i@ is clear.
lsbZero :: Data Index -> Data Bool
lsbZero i = lsb i == 0

-- | @zeroBitsTo k@ clears bits 0 through @k@ (k+1 bits) and sets the rest.
zeroBitsTo :: Data Index -> Data Index
zeroBitsTo n = zeroBitsN n << 1

-- | @oneBitsTo k@ sets bits 0 through @k@ (k+1 bits) and clears the rest.
oneBitsTo :: Data Index -> Data Index
oneBitsTo n = complement (zeroBitsTo n)

-- | @lsbsTo k i@ keeps only bits 0 through @k@ of @i@.
lsbsTo :: Data Index -> Data Index -> Data Index
lsbsTo k i = i .&. oneBitsTo k

-- | @lsbsN k i@ keeps only the @k@ lowest bits of @i@.
lsbsN :: Data Index -> Data Index -> Data Index
lsbsN k = (.&. oneBitsN k)

-- | The least significant bit of @i@ (0 or 1).
lsb :: Data Index -> Data Index
lsb i = i .&. 1

-- | @rotBitFrom0 k i@ rotates bits 0..k of @i@ one place to the right:
-- the lsb moves from position 0 to position @k@, bits 1 to @k@ shift one
-- bit right, and bits above @k@ are left untouched.
rotBitFrom0 :: Data Index -> Data Index -> Data Index
rotBitFrom0 k i = above .|. moved .|. shifted
  where
    highMask = zeroBitsTo k           -- k+1 0s on the right, 1s above
    lowMask  = complement highMask    -- 1s in bits 0..k
    moved    = (i .&. 1) << k         -- old bit 0, now at position k
    shifted  = (i .&. lowMask) >> 1   -- old bits 1..k, now at 0..k-1
    above    = i .&. highMask         -- bits above k, unchanged




-- | A constant nested list of 'Int32' (three rows of two elements),
-- lifted with 'value'.
v1 = value [[1, 2], [3, 4], [5, (6 :: Int32)]]

-- | True when @a + b@ equals @c@.
func :: Data Int32 -> Data Int32 -> Data Int32 -> Data Bool
func a b c = (a + b) == c

-- | A 25-element array built with 'parallel'; element @i@ is @50 - 2*i@.
testParallel :: Data [Index]
testParallel = parallel 25 element
  where
    element i = 50 - 2 * i

-- | Give a vector function a fixed input length: the 'Data' array argument
-- is turned into a vector of length @n@ with 'unfreezeVector'' before @f@
-- is applied (presumably so the compiler can use the static length).
setSize :: (Type a, Type b)
        => Length
        -> (Vector (Data a) -> Vector (Data b))
        -> Data [a]
        -> Vector (Data b)
setSize n f arr = f (unfreezeVector' n arr)

-- | Two-argument variant of 'setSize': both input arrays are given length
-- @n@ before @f@ is applied.
setSize'' :: (Type a1, Type a2, Type a)
          => Length
          -> (Vector (Data a1) -> Vector (Data a2) -> Data a)
          -> Data [a1]
          -> Data [a2]
          -> Data a
setSize'' n f xs ys = f (unfreezeVector' n xs) (unfreezeVector' n ys)


-- | Scalar product as an explicit 'forLoop' over the shorter of the two
-- vector lengths, accumulating from 0.
scalarProduct1 a b = forLoop (min (length a) (length b)) 0 step
  where
    step ix acc = acc + a ! ix * b ! ix

-- | Scalar product by pointwise multiplication followed by 'sum'.
-- This function is actually built in as scalarProd
scalarProduct as bs = sum $ zipWith (*) as bs

-- Compiled with both inputs fixed to length 256.
sc = icompile (setSize'' 256 (scalarProduct :: VInt -> VInt -> Data DefaultInt))

-- Compiled without fixing the input lengths.
sc' = icompile (scalarProduct :: VInt -> VInt -> Data DefaultInt)

-- | The vector @[0, 1, ..., n-1]@.
countUp :: Data Length -> DVector Index
countUp n = indexed n id

-- | The vector @[1, 2, ..., n]@.
countUp1 :: Data Length -> DVector Index
countUp1 = map (+1) . countUp

-- | The @n@-element vector @[m, m+1, ..., m+n-1]@.
countUpFrom :: Data Index -> Data Length -> DVector Index
countUpFrom m n = indexed n (\i -> i + m)

-- | The vector @[n-1, ..., 1, 0]@.
countDown :: Data Length -> DVector Index
countDown = reverse . countUp

ex1 = eval (countUp1 6)

ex2 = eval (countDown 6)


-- | Reverse a vector and add one to every element (pointful version).
revmap0 :: DVector Float -> DVector Float
revmap0 xs = map (+1) $ reverse xs

-- | Same as 'revmap0', written point-free.
revmap :: DVector Float -> DVector Float
revmap = map (+1) . reverse

-- revmap1 :: DVector a -> DVector a     This type is incorrect
-- (presumably because adding 1 needs a numeric element type, hence the
-- 'Numeric' constraint below).
revmap1 :: (Numeric a) => DVector a -> DVector a
revmap1 xs = map (+1) (reverse xs)

-- Missing type declaration: the element type of the input is left open.
ex3' = eval (revmap1 (vector [0..5]))

ex3 = eval (revmap1 (vector [0..5] :: DVector Int32))

-- | Evaluate a vector function @f@ on the input @[0 .. n]@.
ex4 f n = eval (f (vector [0..n] :: VInt))

-- Compiled without fixing the input length ...
cx1 = icompile (revmap1 :: VInt -> VInt)

-- ... and with the input fixed to 256 elements.
cx2 = icompile (setSize 256 (revmap1 :: VInt -> VInt))


-- | Split the vector at half its length (rounded down) and combine the
-- two halves pointwise with @f@. For an odd length the second half is one
-- longer; its extra element is presumably dropped by 'zipWith'.
halveZip :: (Syntactic a) => (a -> a -> c) -> Vector a -> Vector c
halveZip f as = zipWith f front back
  where
    (front, back) = splitAt (length as `div` 2) as

-- Pointwise minimum of the two halves of a 256-element input.
cx3 = icompile (setSize 256 (halveZip min :: VInt -> VInt))






-- | Range-propagation function for two-argument primitives: the result
-- range is always 'universal', whatever the argument ranges.
propUniv2 _ _ = universal

-- | Minimum as a named two-argument primitive "min" (via 'function2'),
-- evaluated with Haskell's 'P.min'.
mmin :: (P.Ord a, Type a) => Data a -> Data a -> Data a
mmin = function2 "min" propUniv2 P.min

-- | Maximum as a named two-argument primitive "max" (via 'function2'),
-- evaluated with Haskell's 'P.max'.
mmax :: (P.Ord a, Type a) => Data a -> Data a -> Data a
mmax = function2 "max" propUniv2 P.max

-- Even input length: the two halves pair up exactly.
cx4 = icompile (setSize 256  (halveZip mmin :: VInt -> VInt))

-- Odd input length: the halves differ in length by one.
cx4' = icompile (setSize 255  (halveZip mmin :: VInt -> VInt))

-- Only the first 3 results are kept.
cx4'' = icompile (setSize 256 ((take 3 . halveZip mmin) :: VInt -> VInt))




-- | The pointwise @f@-combination of the two halves of the vector,
-- followed by the pointwise @g@-combination of the same halves.
both :: (Syntactic a) => (a -> a -> c) -> (a -> a -> c) -> Vector a -> Vector c
both f g as = halveZip f as ++ halveZip g as

ex5 n = eval (both mmin mmax (reverse (vector [1..n] :: VInt)))

cx5 = icompile (setSize 256 (both mmin mmax :: VInt -> VInt))

-- Same as cx5, with the two concatenated parts merged by 'mergeSegments'.
cx6 = icompile (setSize 256 (mergeSegments . both mmin mmax :: VInt -> VInt))



-- | Re-index a vector: element @i@ of the result is element @f i@ of the
-- input, and the length is unchanged.
--
-- Only single-segment vectors and the empty vector are supported. A vector
-- with several segments (presumably what '++' produces) must first be
-- flattened with 'mergeSegments'; otherwise an explicit error is raised
-- instead of an anonymous pattern-match failure.
premap :: (Data Index -> Data Index) -> Vector a -> Vector a
premap f (Indexed l ixf Empty) = indexed l (ixf . f)
premap _ Empty                 = Empty
premap _ _                     =
  P.error "premap: multi-segment vector; apply mergeSegments first"

-- | Swap each even-indexed element with its odd neighbour, using explicit
-- indexing into @v@.
-- NOTE(review): for an odd length, the last (even) index l-1 reads
-- @v ! l@, one past the end.
swapOE1 :: (Syntactic a) => Vector a -> Vector a
swapOE1 v = indexed (length v) swapped
  where
    swapped i = condition (i `mod` 2 == 0) (v ! (i + 1)) (v ! (i - 1))

-- | 'swapOE1' expressed as an index permutation with 'premap'.
swapOE2 :: Vector a -> Vector a
swapOE2 = premap partner
  where
    partner i = condition (i `mod` 2 == 0) (i + 1) (i - 1)

-- | 'swapOE2' without a conditional: flipping bit 0 pairs 2j with 2j+1.
swapOE3 :: Vector a -> Vector a
swapOE3 = premap (\i -> i `xor` 1)

ex6 = eval (swapOE3 (vector [0..15] :: VInt))

-- 17 elements: index 16 is paired with 17, which is out of range.
ex7 = eval (swapOE3 (vector [1..17] :: VInt))

cx7'   = icompile $ setSize 256 (swapOE1 :: VInt -> VInt)

cx7''  = icompile $ setSize 256 (swapOE2 :: VInt -> VInt)

cx7''' = icompile $ setSize 256 (swapOE3 :: VInt -> VInt)



-- | Keep the elements at even indices 0, 2, 4, ...; the result has
-- @(length as + 1) `div` 2@ elements.
selEvenIx :: Vector a -> Vector a
selEvenIx as = take evenCount $ premap (* 2) as
  where
    evenCount = (length as + 1) `div` 2

exer1a = eval (selEvenIx (vector [0..16] :: VInt))

cexer1 = icompile (setSize 256 (selEvenIx :: VInt -> VInt))


-- | Reverse a vector by re-indexing: element @i@ comes from @l-1-i@.
rev1 :: Vector a -> Vector a
rev1 v = premap (\i -> lastIx - i) v
  where
    lastIx = length v - 1






ex8 = eval $ map (complN 4) (vector [0..15])

ex9 = eval $ map (complN 2) (vector [0..15])

-- | Reverse every consecutive block of @2^k@ elements by complementing the
-- @k@ low bits of each index.
revi :: Data Index -> Vector a -> Vector a
revi k v = premap (complN k) v

ex10 = eval (revi 2 (vector [0..15] :: VInt))

ex11 = eval (revi 2 (vector [0..31] :: VInt))

cx7 = icompile (setSize 256 (revi 4 . map (+1) :: VInt -> VInt))



ex12 = eval $ fold (+) 0 (vector [0..15] :: VInt)

-- | Sum of the even elements; odd elements contribute 0.
sumEven :: Vector UInt -> UInt
sumEven v = sum (map keepEven v)
  where
    keepEven x = condition (x `mod` 2 == 0) x 0

exer2a = eval (sumEven (vector [0..31] :: Vector UInt))

cexer2 = icompile sumEven


-- | Branch-free selection: @m@ when @b@ holds, 0 otherwise. The mask is the
-- negation of @b2i b@; assuming 'b2i' maps True to 1, that wraps to all
-- ones in the unsigned 'UInt'.
onCond :: Data Bool -> UInt -> UInt
onCond b m = m .&. mask
  where
    mask = - (b2i b)

-- | True when the lowest bit of @i@ is clear.
isEven i = (i .&. 1) == 0

exer2b i = eval (onCond (isEven i) i)

-- | 'sumEven' without a conditional, using 'onCond'.
sumEven1 :: Vector UInt -> UInt
sumEven1 v = sum (map (\x -> onCond (isEven x) x) v)

cexer2a = icompile sumEven1

-- | Thread a state through one step per element of an index vector, in
-- vector order: @pipe f [i0, i1, ...] x@ is @... (f i1 (f i0 x))@
-- ('fold' here takes an accumulator-first step function).
pipe :: (Syntactic a) => (Data Index -> a -> a) -> Vector (Data Index) -> a -> a
pipe step ixs start = fold (\acc ix -> step ix acc) start ixs



-- | Factorial: starting from 1, multiply by each of @1 .. i@.
fact :: UInt -> UInt
fact i = pipe (\k acc -> acc * k) (countUp1 i) 1

-- | Factorial over the range vector @2 ... i@, starting from 1.
fact1 :: UInt -> UInt
fact1 i = pipe (\k acc -> acc * k) (2 ... i) 1

-- | Factorial over the factors @[2 .. i]@ produced by 'countUpFrom'.
--
-- The accumulator must start at 1: the factors already include @i@, so
-- starting at @i@ would give @i * i!@. The length @i - 1@ is clamped to 0
-- for @i == 0@, because 'UInt' is unsigned and @0 - 1@ would wrap around
-- to an enormous length.
fact2 :: UInt -> UInt
fact2 i = pipe f (countUpFrom 2 len) 1
  where
    f k = (* k)
    len = condition (i == 0) 0 (i - 1)


-- | Factorial over @[2 .. i]@, built by shifting @countUp (i-1)@ up by 2.
--
-- The length @i - 1@ is clamped to 0 for @i == 0@, because 'UInt' is
-- unsigned and @0 - 1@ would wrap around to an enormous length.
fact3 :: UInt -> UInt
fact3 i = pipe f (map (+2) (countUp len)) 1
  where
    f k = (* k)
    len = condition (i == 0) 0 (i - 1)


-- | Factorial as a 'fold1' (no explicit start value) over @[1 .. i]@.
-- NOTE(review): for @i == 0@ the vector is empty; what 'fold1' does on an
-- empty vector is not visible here -- confirm.
fact4 :: UInt -> UInt
fact4 = fold1 (*) . countUp1

ex13' = eval (fact 5)

cx8 = icompile fact

cx8' = icompile fact1

cx8'' = icompile fact2



-- | @bitr n i@ reverses the @n@ least significant bits of @i@. The result
-- starts as @i >> n@, so the bits above position @n-1@ are kept as they are.
bitr :: Data Index -> Data Index -> Data Index
bitr n i = snd (pipe step (countUp n) (i, i >> n))
  where
    -- move the next low bit of the remaining input into the result
    step _ (rest, acc) = (rest >> 1, (rest .&. 1) .|. (acc << 1))

-- | Permute a vector by reversing the @n@ low bits of every index.
bitRev :: Data Index -> Vector a -> Vector a
bitRev n v = premap (bitr n) v

cx9 = icompile (bitRev :: Data Index -> VInt -> VInt)

ex13''  = eval (bitRev 8 (vector [0..255] :: VInt))
ex13''' = eval (bitRevLog 3 (vector [0..255] :: VInt))
ex13''''= eval (bitRevH 8 (vector [0..255] :: VInt))


-- | Bitwise select: bits of @a@ where the mask @m@ is 1, bits of @b@ where
-- it is 0.
mergeBy :: Data Index -> Data Index -> Data Index -> Data Index
mergeBy m a b = fromA .|. fromB
  where
    fromA = a .&. m
    fromB = b .&. complement m


-- If you know that the number of bits to be reversed is a power of two,
-- you can do this nice trick: swap adjacent bits, then swap pairs of bits,
-- etc. (the stages commute; the code runs them from the widest swap,
-- 2^(n-1) bits, down to single bits). The number of bits to be reversed is
-- 2^n.
--
-- The bits at or above position 2^n are split off first and put back
-- unchanged at the end, as 'bitr' does. Without this, the shift/mask stages
-- would also move those high bits (e.g. index 8 with n = 1 became 4), so the
-- permutation would not work on sub-arrays of a longer vector.
bitrLog :: Data Index -> Data Index -> Data Index
bitrLog n i = high .|. snd (pipe stage (map (1<<) (countDown n)) (allOnes, low))
  where
    width = 1 << n                  -- number of bits being reversed (2^n)
    high  = i .&. zeroBitsN width   -- bits at or above position 2^n, kept
    low   = lsbsN width i           -- the 2^n-bit field to reverse
    -- swap adjacent blocks of s bits, selected by the alternating mask
    stage s (mask, v) = (mask', mergeBy mask' (v>>s) (v<<s))
      where
        mask' = (mask `xor` (mask << s)) .|. zeroBitsN width

-- | Permute a vector by reversing the low 2^n bits of every index.
bitRevLog :: Data Index -> Vector a -> Vector a
bitRevLog n = premap (bitrLog n)

cx10 = icompile (bitRevLog :: Data Index -> VInt -> VInt)


-- | @composeN n f@ applies @f@ exactly @n@ times, unrolled in Haskell.
composeN :: Index -> (a -> a) -> a -> a
composeN 0 _ = id
composeN k g = composeN (k - 1) g . g




-- Make this one an exercise!
-- | Compose a list of functions so that they run left to right: the head
-- of the list is applied first.
composeList :: [ a -> a ] -> a -> a
composeList = P.foldr (\f rest -> rest . f) id


-- same as bitr but now n is a Haskell level value: the n stages are
-- unrolled in Haskell by composeN rather than iterated with pipe
bitrH n i = snd (composeN n step (i, i >> value n))
  where
    -- move the next low bit of the remaining input into the result
    step (rest, acc) = (rest >> 1, (rest .&. 1) .|. (acc << 1))

-- | Permute a vector by reversing the @n@ low bits of every index, with
-- @n@ known at Haskell level.
bitRevH :: Index -> Vector a -> Vector a
bitRevH n v = premap (bitrH n) v

cx11 = icompile (bitRevH 8 :: VInt -> VInt)



-- | Haskell-level version of 'bitrLog': reverses the low 2^n bits of an
-- index using n swap stages unrolled with 'composeList'.
--
-- Bits at or above position 2^n are split off and restored unchanged (as in
-- 'bitr'). The stage list is built from @takeWhile (< n) [0 ..]@ rather than
-- @[0 .. n-1]@: 'Index' is unsigned, so @n - 1@ wraps to maxBound when
-- @n == 0@ and would produce billions of stages.
bitrLogH :: Index -> Data Index -> Data Index
bitrLogH n i = high .|. snd (composeList fns (allOnes, low))
  where
    width = 1 << value n             -- number of bits being reversed (2^n)
    high  = i .&. zeroBitsN width    -- bits at or above position 2^n, kept
    low   = lsbsN width i            -- the 2^n-bit field to reverse
    -- shift amounts 2^(n-1), ..., 2, 1 (widest swap first)
    fns = [stage (1 << (value ix)) | ix <- P.reverse (P.takeWhile (P.< n) [0 ..])]
    stage s (mask, v) = (mask', mergeBy mask' (v>>s) (v<<s))
      where
        mask' = (mask `xor` (mask << s)) .|. zeroBitsN width

-- | Permute a vector by reversing the low 2^n bits of every index, with
-- @n@ known at Haskell level.
bitRevLogH :: Index -> Vector a -> Vector a
bitRevLogH n = premap (bitrLogH n)

cx12 = icompile (bitRevLogH 4 :: VInt -> VInt)





-- | Combine every element with a partner element. At index @i@ the result is
-- @f (v!i) (v!(p i))@ when @c i@ holds and @g (v!i) (v!(p i))@ otherwise.
--
-- Only single-segment vectors and the empty vector are supported. Other
-- vectors raise an explicit error (flatten them with 'mergeSegments') rather
-- than an anonymous pattern-match failure.
comb :: (Syntactic a)
     => (t -> t -> a) -> (t -> t -> a)
     -> (Data Index -> Data Bool) -> (Data Index -> Data Index)
     -> Vector t
     -> Vector a
comb f g c p (Indexed l ixf Empty) = indexed l ixf'
  where
    ixf' i = condition (c i) (f a b) (g a b)
      where
        a = ixf i
        b = ixf (p i)
comb _ _ _ _ Empty = Empty
comb _ _ _ _ _     =
  P.error "comb: multi-segment vector; apply mergeSegments first"

-- | Pair each element with the one whose index differs only in bit @k@.
-- Where bit @k@ of the index is clear the result is @f own partner@,
-- otherwise @g own partner@.
apart :: (Syntactic a)
      => (t -> t -> a) -> (t -> t -> a)
      -> Data Index
      -> Vector t
      -> Vector a
apart f g k = comb f g inLowerHalf partner
  where
    inLowerHalf = bitZero k
    partner     = flipBit k

ex13 k = eval (apart mmin mmax k (vector [7,6,5,4,3,2,1,0] :: VInt))

ex14 = [ex13 (value b) | b <- [0 .. 2]]


-- | Min/max 'apart' stages on index bits n-1, n-2, ..., 0, in that order
-- (the name suggests Batcher's merger).
batMerge :: (P.Ord a, Type a) => Data Index -> DVector a -> DVector a
batMerge n v = pipe (apart mmin mmax) (countDown n) v

-- | Within each block of 2^n elements, keep the first half and reverse the
-- second half (complement the low n-1 index bits when bit n-1 is set).
-- Expects n >= 1; n - 1 wraps around for n == 0.
halfRev :: (Type a) => Data Index -> DVector a -> DVector a
halfRev n = premap reindex
  where
    top = n - 1
    reindex i = condition (bitZero top i) i (complN top i)

-- works on 2^n length sub-arrays
-- | Same permutation as 'halfRev', without a conditional: XOR the index
-- with the low n-1 ones exactly when bit n-1 is set (via 'onCond').
-- Expects n >= 1; n - 1 wraps around for n == 0.
halfRev1 :: (Type a) => Data Index -> DVector a -> DVector a
halfRev1 n = premap (\i -> i `xor` onCond (bitOne top i) (oneBitsN top))
  where
    top = n - 1

ex15 = [eval (halfRev1 (value k) (vector [0..15] :: VInt)) | k <- [1..4]]



-- | Merge stage @n@: reverse the upper half of each 2^n block with
-- 'halfRev1', then apply 'batMerge'.
merge :: (P.Ord a, Type a) => Data Index -> DVector a -> DVector a
merge n v = batMerge n (halfRev1 n v)

-- | Sort each 2^n-element block by running merge stages 1, 2, ..., n.
sortV :: (P.Ord a, Type a) => Data Index -> DVector a -> DVector a
sortV n v = pipe merge (countUp1 n) v

ex16 k = eval (sortV k (vector [0,1,2,3,12,5,6,7,1,14,13,12,11,19,9,8] :: VInt))

cx13 = icompile (sortV :: Data Index -> VInt -> VInt)



-- sorter on each 2^n length sub-array of inputs, n > 0
-- inside the merger, one loop body is unwound to permit fusion with halfRev1:
-- batMerge (m-1) . apart mmin mmax (m-1) performs the same stages as
-- batMerge m, with the first stage (bit m-1) written out explicitly.
sort1 :: (P.Ord a, Type a) => Data Index -> DVector a -> DVector a
sort1 n = pipe mergeStage (countUp1 n)
  where
    mergeStage m = batMerge (m - 1) . apart mmin mmax (m - 1) . halfRev1 m

cx13' = icompile (sort1 :: Data Index -> VInt -> VInt)


-- | Three min/max 'apart' stages on index bits 2, 1 and 0, composed
-- directly with no intermediate storage.
fex v = apart mmin mmax 0 (apart mmin mmax 1 (apart mmin mmax 2 v))

cx14 = icompile (setSize 256 (fex :: VInt -> VInt))

-- | The same stages as 'fex', with 'force' between them (presumably to
-- materialise each intermediate vector instead of fusing the stages).
fexforce :: (P.Ord a, Type a) => DVector a -> DVector a
fexforce = stage 0 . force . stage 1 . force . stage 2
  where
    stage k = apart mmin mmax k

cx15 = icompile (setSize 256 (fexforce :: VInt -> VInt))


-- | Two-tap FIR filter: each output is @a0 * x + a1 * y@ for a pair of
-- neighbouring input elements @(x, y)@, taken from @zip vec (tail vec)@.
fir1 :: Data Float -> Data Float -> DVector Float -> DVector Float
fir1 a0 a1 vec = map tap (zip vec (tail vec))
  where
    tap (x, y) = a0 * x + a1 * y

-- | FIR with weights @x@ and @1 - x@.
lowPass :: Data Float -> DVector Float -> DVector Float
lowPass x v = fir1 x (1 - x) v

-- | FIR with weights @x@ and @x - 1@.
highPass :: Data Float -> DVector Float -> DVector Float
highPass x v = fir1 x (x - 1) v

-- | 'lowPass' followed directly by 'highPass'.
bandPass1 :: Data Float -> DVector Float -> DVector Float
bandPass1 x v = highPass x (lowPass x v)

-- | Like 'bandPass1', but the intermediate vector is passed through 'force'.
bandPass2 :: Data Float -> DVector Float -> DVector Float
bandPass2 x v = highPass x (force (lowPass x v))

cx16 = icompile (setSize 256 (lowPass  0.5))

cx17 = icompile (setSize 256 (highPass 0.5))

cx18 = icompile (setSize 256 (bandPass1 0.5))

cx19 = icompile (setSize 256 (bandPass2 0.5))





-- | Permute a vector by rotating bits 0..k of every index with
-- 'rotBitFrom0'.
riffle :: Data Index -> Vector a -> Vector a
riffle k v = premap (rotBitFrom0 k) v

-- | Bit-reversal permutation built from 'riffle' stages k = 1, ..., n.
-- NOTE(review): tracing the index maps, this reverses the low n+1 bits
-- (e.g. n = 2 reverses 3 bits), unlike @bitRev n@, which reverses n bits.
-- Confirm which is intended.
bitRev1 :: Type a => Data Index -> Vector (Data a) -> Vector (Data a)
bitRev1 n v = pipe riffle (countUp1 n) v

ex17 = eval (riffle 3 (countUp 16))

cx20 = icompile (bitRev1 :: Data Index -> VInt -> VInt)



-- | Like 'comb', but @f@ and @g@ also receive an extra per-index value
-- @x i@ (for example, an FFT twiddle factor).
--
-- Note the argument order: where @c i@ holds the result is
-- @f own partner extra@; otherwise it is @g partner own extra@.
--
-- Only single-segment vectors and the empty vector are supported. Other
-- vectors raise an explicit error (flatten them with 'mergeSegments').
-- The local names avoid shadowing the 'pi' constant used by 'fft'.
combx :: (Syntactic a)
      => (t -> t -> e -> a) -> (t -> t -> e -> a)
      -> (Data Index -> Data Bool) -> (Data Index -> Data Index)
      -> (Data Index -> e)
      -> Vector t
      -> Vector a
combx f g c p x (Indexed l ixf Empty) = indexed l ixf'
  where
    ixf' i = condition (c i) (f own partner extra) (g partner own extra)
      where
        own     = ixf i
        partner = ixf (p i)
        extra   = x i
combx _ _ _ _ _ Empty = Empty
combx _ _ _ _ _ _     =
  P.error "combx: multi-segment vector; apply mergeSegments first"


-- | The vector @[1, 2, 4, ..., 2^(k-1)]@ of length @k@.
pows2 :: Data Length -> DVector Index
pows2 k = indexed k (\i -> 1 << i)

-- 2^l input FFT. Applies to sub-parts of input vector
-- of length 2^l. Produces each of the results in bit reversed order.
-- There is currently no check that the input vector is at least of length 2^l
--
-- Structure: one 'combx' butterfly stage per index bit k, for
-- k = l-1, ..., 0 (pipe over countDown l). This appears to be the
-- decimation-in-frequency arrangement: sum/difference first, then the
-- difference is multiplied by a twiddle factor.
fft :: Data Index ->  DVector (Complex Float) -> DVector (Complex Float) 
fft l = pipe stage (countDown l)
  where
    -- butterfly stage on index bit k
    stage k = combx f g (bitZero k) (`xor` p)  twid
      where
        -- distance to the partner element: 2^k
        p = 1<<k
        -- bit k clear: own + partner
        f a b _ = a + b
        -- bit k set: combx passes (partner, own, twiddle), so this is
        -- twiddle * (lower element - upper element)
        g a b t = t * (a-b)
        -- twiddle at angle -pi*m/2^k, with m the k low bits of the index
        -- (assuming cis t = exp(i*t))
        twid i  = cis (-pi*(i2f (lsbsN k i)) / i2f p)



-- 8-point FFT of 0..7, reordered into natural order with bitRev.
ex18' = eval ((bitRev 3 . fft 3) (testseq1 3))




-- 2^l input IFFT. Produces output in bit reversed order.
-- Same butterfly stages as 'fft', but with the twiddle angle negated
-- (+pi*m/2^k instead of -pi*m/2^k), followed by dividing every element by
-- 2^l.
ifft :: Data Index ->  DVector (Complex Float) -> DVector (Complex Float) 
ifft l = map (/ (complex (i2f (2^l)) 0)) . pipe stage (countDown l)
  where
    -- butterfly stage on index bit k
    stage k = combx f g (bitZero k) (`xor` p)  twid
      where
        -- distance to the partner element: 2^k
        p = 1<<k
        -- bit k clear: own + partner
        f a b _ = a + b
        -- bit k set: twiddle * (lower element - upper element)
        g a b t = t * (a-b)
        -- twiddle at angle +pi*m/2^k, with m the k low bits of the index
        twid i  = cis (pi*(i2f (lsbsN k i)) / i2f p)





cx21 = icompile fft

-- | Test input of length 2^n: the ramp 0, 1, ..., 2^(n-1) - 1 followed by
-- the same ramp reversed, as complex values with imaginary part 0.
-- Expects n >= 1 (n - 1 wraps around for n == 0).
testseq :: Data DefaultWord -> DVector (Complex Float)
testseq n = mergeSegments (ramp ++ reverse ramp)
  where
    ramp = indexed (2 ^ (n - 1)) (\i -> complex (i2f i) 0)

-- | The ramp 0, 1, ..., 2^n - 1 as complex values with imaginary part 0.
testseq1 :: Data DefaultWord -> DVector (Complex Float)
testseq1 n = indexed (2 ^ n) (\i -> complex (i2f i) 0)

-- | Round trip: FFT, reorder, inverse FFT, reorder. The result is expected
-- to reproduce @testseq k@ up to rounding.
ex18 k = eval ((bitRev k . ifft k . bitRev k . fft k) (testseq k) :: DVector (Complex Float))