{-# LANGUAGE FlexibleInstances #-}

module HaskellWorks.Simd.Cli.Comparison
  ( CmpEqWord8s(..)
  ) where

import Data.ByteString                as BS
import Data.Word
import HaskellWorks.Data.AtIndex
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.Positioning

import qualified Data.ByteString.Lazy         as LBS
import qualified Data.Vector.Storable         as DVS
import qualified HaskellWorks.Data.ByteString as BS
import qualified HaskellWorks.Data.Length     as HW

-- | Containers of bytes that can be compared, byte by byte, against a single
-- 'Word8', producing a result of the same container type.
class CmpEqWord8s a where
  -- | @cmpEqWord8s w8 v@ compares the bytes of @v@ with @w8@.
  --
  -- The first argument is the byte to look for; the second is the input.
  -- For the strict 'ByteString' instance in this module the result is a
  -- bit-packed mask: one bit per input byte, set when that byte equals @w8@.
  cmpEqWord8s :: Word8 -> a -> a

-- | Strict 'ByteString' instance.
--
-- The result is a bit-packed equality mask holding one bit per input byte:
-- bit @n@ (counting from the least significant bit) of output byte @i@ is set
-- exactly when input byte @i * 8 + n@ equals the byte being searched for.
-- The output is @(length + 7) `div` 8@ bytes long; bits of the final output
-- byte that fall beyond the end of the input are left as zero.
instance CmpEqWord8s BS.ByteString where
  cmpEqWord8s w8 bs = BS.toByteString $ DVS.constructN ((BS.length bs + 7) `div` 8) (go 0 0)
    where
      -- Length of the input in bytes, as a 'Count'.
      bsLen :: Count
      bsLen = HW.length bs

      -- Produce the next output byte.  'DVS.constructN' passes in the vector
      -- built so far, so its length is the index of the output byte being
      -- generated.  When all eight input bytes of the group are in range the
      -- loop without bounds checks is used; otherwise (final, partial group)
      -- the bounds-checked loop is used.
      go :: Word8 -> Count -> DVS.Vector Word8 -> Word8
      go w n u
        | bsi + 8 <= bsLen = goFast w n u
        | otherwise        = goSafe w n u
        where
          -- Index of the first input byte covered by this output byte.
          bsi = HW.length u * 8

      -- Accumulate bits for a full group of eight input bytes.
      -- @w@ is the mask built so far and @n@ the bit position within it.
      goFast :: Word8 -> Count -> DVS.Vector Word8 -> Word8
      goFast w n u
        | n >= 8                       = w
        | w8 == bs !!! fromIntegral wi = goFast (w .|. (1 .<. n)) (n + 1) u
        | otherwise                    = goFast w (n + 1) u
        where
          -- Index of the input byte corresponding to bit @n@.
          wi = HW.length u * 8 + n

      -- As 'goFast', but stops as soon as the input is exhausted so that the
      -- trailing partial group never indexes past the end of @bs@.
      goSafe :: Word8 -> Count -> DVS.Vector Word8 -> Word8
      goSafe w n u
        | n >= 8 || wi >= bsLen        = w
        | w8 == bs !!! fromIntegral wi = goSafe (w .|. (1 .<. n)) (n + 1) u
        | otherwise                    = goSafe w (n + 1) u
        where
          -- Index of the input byte corresponding to bit @n@.
          wi = HW.length u * 8 + n
  {-# INLINE cmpEqWord8s #-}

-- | List-of-chunks instance: applies the strict 'ByteString' comparison to
-- every element of the list independently, preserving the list's order and
-- length.
instance CmpEqWord8s [BS.ByteString] where
  cmpEqWord8s w8 vs = cmpEqWord8s w8 <$> vs
  {-# INLINE cmpEqWord8s #-}

-- | Lazy 'LBS.ByteString' instance: splits the input into its strict chunks,
-- runs the list instance over them, and reassembles the results into a lazy
-- 'LBS.ByteString'.
--
-- Each chunk is processed on its own, so every chunk contributes
-- @(chunkLength + 7) `div` 8@ output bytes.
-- NOTE(review): a chunk whose length is not a multiple of 8 therefore
-- introduces zero padding bits in the middle of the output; presumably callers
-- supply suitably sized chunks -- confirm against call sites.
instance CmpEqWord8s LBS.ByteString where
  cmpEqWord8s w8 = LBS.fromChunks . cmpEqWord8s w8 . LBS.toChunks
  {-# INLINE cmpEqWord8s #-}