{-# LANGUAGE FlexibleInstances #-}

module HaskellWorks.Data.Simd.Comparison.Stock
  ( CmpEqWord8s(..)
  ) where

import Data.Word
import HaskellWorks.Data.AtIndex
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.Simd.Internal.Bits
import HaskellWorks.Data.Simd.Internal.Broadword

import qualified Data.ByteString                    as BS
import qualified Data.Vector.Storable               as DVS
import qualified HaskellWorks.Data.ByteString       as BS
import qualified HaskellWorks.Data.Simd.ChunkString as CS
import qualified HaskellWorks.Data.Vector.AsVector8 as V

-- | Containers that support comparing their contents against a single byte.
--
-- The result has the same type as the input; how the comparison outcome is
-- encoded in that result is defined by each instance.
class CmpEqWord8s v where
  -- | Compare the contents of a container against the given byte.
  cmpEqWord8s
    :: Word8 -- ^ Byte to compare against.
    -> v     -- ^ Input container.
    -> v     -- ^ Comparison result, of the same type as the input.

-- | Byte-vector instance.
--
-- Reinterprets the byte vector as a vector of 'Word64' with 'DVS.unsafeCast',
-- delegates to the @DVS.Vector Word64@ instance, and reinterprets the result
-- back as bytes. No elements are copied by the casts themselves.
--
-- NOTE(review): 'DVS.unsafeCast' only exposes as many whole 'Word64's as fit
-- in the underlying buffer, so trailing bytes of an input whose length is not
-- a multiple of 8 are presumably not examined — confirm callers pad first.
instance CmpEqWord8s (DVS.Vector Word8) where
  cmpEqWord8s w8 v = DVS.unsafeCast (cmpEqWord8s w8 (DVS.unsafeCast v :: DVS.Vector Word64))
  {-# INLINE cmpEqWord8s #-}

-- | Word-vector instance.
--
-- Produces a vector of @(DVS.length v + 7) `div` 8@ words. Each output word
-- is derived from a block of eight consecutive input words: every input word
-- is XORed with @fillWord64 w8@ and passed through 'testWord8s', the eight
-- results are OR-ed together at bit offsets 0, 8, .., 56 (first word of the
-- block in the lowest position), and the combined word is complemented.
--
-- NOTE(review): this presumably yields one bit per input byte, set where the
-- byte equals @w8@. That reading relies on 'fillWord64' broadcasting the byte
-- to all eight lanes and on 'testWord8s' returning an 8-bit "byte is
-- non-zero" mask — confirm against the Internal.Broadword / Internal.Bits
-- modules.
instance CmpEqWord8s (DVS.Vector Word64) where
  cmpEqWord8s w8 v = DVS.constructN ((DVS.length v + 7) `div` 8) go
    where
      iw = fillWord64 w8

      -- 'DVS.constructN' passes the already-built prefix of the output, so
      -- @end u@ is presumably the index of the output word being produced.
      go :: DVS.Vector Word64 -> Word64
      go u
        -- Strictly inside the input: every index vi .. vi + 7 is in range,
        -- so the plain '!!!' read is used.
        | vi + 8 < end v = pack (v !!!)
        -- Final block(s): read through 'atIndexOr' with a default of 0 for
        -- positions past the end. An exactly-full last block
        -- (vi + 8 == end v) also lands here because the test above is strict.
        | otherwise      = pack (atIndexOr 0 v)
        where
          -- Index of the first input word of this block.
          vi = end u * 8

          -- Combine the masks of the eight words starting at @vi@, fetching
          -- each word with the supplied reader. Marked INLINE so that both
          -- call sites specialise to their reader, as in the hand-unrolled
          -- form this replaces.
          pack rd = comp
            (   (probe 7 .<. 56)
            .|. (probe 6 .<. 48)
            .|. (probe 5 .<. 40)
            .|. (probe 4 .<. 32)
            .|. (probe 3 .<. 24)
            .|. (probe 2 .<. 16)
            .|. (probe 1 .<.  8)
            .|.  probe 0
            )
            where
              probe k = testWord8s (rd (vi + k) .^. iw)
          {-# INLINE pack #-}
  {-# INLINE cmpEqWord8s #-}

-- | List instance: applies the @DVS.Vector Word64@ comparison to every
-- vector in the list independently, preserving order.
instance CmpEqWord8s [DVS.Vector Word64] where
  cmpEqWord8s w8 vs = cmpEqWord8s w8 <$> vs
  {-# INLINE cmpEqWord8s #-}

-- | List instance: applies the @DVS.Vector Word8@ comparison to every
-- vector in the list independently, preserving order.
instance CmpEqWord8s [DVS.Vector Word8] where
  cmpEqWord8s w8 vs = cmpEqWord8s w8 <$> vs
  {-# INLINE cmpEqWord8s #-}

-- | List-of-bytestrings instance: each chunk is viewed as a byte vector via
-- 'V.asVector8', compared with the @DVS.Vector Word8@ instance, and converted
-- back with 'BS.toByteString'. Chunks are processed independently.
instance CmpEqWord8s [BS.ByteString] where
  cmpEqWord8s w8 vs = BS.toByteString . cmpEqWord8s w8 . V.asVector8 <$> vs
  {-# INLINE cmpEqWord8s #-}

-- | 'CS.ChunkString' instance: unpacks the chunk string with
-- 'BS.toByteStrings', runs the list-of-bytestrings comparison over the
-- chunks, and repacks the result with 'CS.toChunkString'.
instance CmpEqWord8s CS.ChunkString where
  cmpEqWord8s w8 = CS.toChunkString . cmpEqWord8s w8 . BS.toByteStrings
  {-# INLINE cmpEqWord8s #-}