{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
#if MIN_VERSION_GLASGOW_HASKELL (8,6,0,0)
{-# LANGUAGE NoStarIsType #-}
#endif
module Haskus.Binary.Vector
( Vector (..)
, vectorBuffer
, vectorReverse
, take
, drop
, index
, fromList
, fromFilledList
, fromFilledListZ
, toList
, replicate
, concat
, zipWith
)
where
import Prelude hiding ( replicate, head, last
, tail, init, map, length, drop, take, concat
, zipWith )
import System.IO.Unsafe (unsafePerformIO)
import qualified Haskus.Utils.List as List
import Haskus.Utils.Types
import Haskus.Utils.HList
import Haskus.Utils.Maybe
import Haskus.Utils.Flow
import Haskus.Binary.Storable
import Haskus.Binary.Buffer
import Haskus.Binary.Bits
import Foreign.Ptr
import Foreign.Marshal.Alloc (mallocBytes)
-- | Vector with a type-level number of elements
--
-- The constructor only wraps a 'Buffer': both the element count @n@ and the
-- element type @a@ are phantom type parameters (they do not appear in the
-- constructor field).
data Vector (n :: Nat) a = Vector Buffer
-- | Render a vector as the 'fromList' expression applied to its elements.
--
-- 'showsPrec' is defined (instead of 'show' only) so that the enclosing
-- precedence is honoured: a vector nested in another value is parenthesised
-- like any other function application, e.g. @Just (fromList [1,2])@ rather
-- than @Just fromList [1,2]@. At top level the output is still
-- @fromList [1,2]@.
instance (Storable a, Show a, KnownNat n) => Show (Vector n a) where
  showsPrec d v =
    -- 10 is the precedence of function application
    showParen (d > 10) $
      showString "fromList " . showsPrec 11 (toList v)
-- | Extract the 'Buffer' wrapped by a vector
vectorBuffer :: Vector n a -> Buffer
vectorBuffer v = case v of
  Vector buf -> buf
-- | Reverse the order of the elements of a vector
--
-- The vector is converted into a list, the list is reversed and packed again
-- with 'fromList'. The element count is unchanged, hence the 'fromJust'.
vectorReverse :: (KnownNat n, Storable a) => Vector n a -> Vector n a
vectorReverse v = fromJust (fromList (reverse (toList v)))
-- | Offset in bytes of the element at index @ix@ in a vector of @len@
-- elements of type @e@ (i.e. @ix * SizeOf e@).
--
-- The index is checked at the type level with @ix + 1 <=? len@. When the check
-- fails, 'Assert' is presumably reducing to the custom type error given as its
-- third argument, which reports the index and the vector size.
type family ElemOffset e ix len where
  ElemOffset e ix len =
    Assert (ix+1 <=? len)
      (ix * (SizeOf e))
      (('Text "Invalid vector index: " ':<>: 'ShowType ix
        ':$$: 'Text "Vector size: " ':<>: 'ShowType len))
-- | Static storage description of a vector
--
-- The size in bytes is @SizeOf a * n@ and the alignment is the one of the
-- element type.
instance forall a n.
  ( KnownNat (SizeOf a * n)
  ) => StaticStorable (Vector n a)
  where
    type SizeOf (Vector n a) = SizeOf a * n
    type Alignment (Vector n a) = Alignment a

    -- Wrap the buffer that 'bufferPackPtr' builds from @SizeOf a * n@ bytes
    -- at the given address
    staticPeekIO src = do
      buf <- bufferPackPtr (natValue @(SizeOf a * n)) (castPtr src)
      return (Vector buf)

    -- Write the wrapped buffer at the given address with 'bufferPoke'
    staticPokeIO dst (Vector buf) = bufferPoke dst buf
-- | Dynamic storage description of a vector
--
-- The size is the element count times the element size; the alignment is the
-- one of the element type.
instance forall a n.
  ( KnownNat n
  , Storable a
  ) => Storable (Vector n a)
  where
    sizeOf _ = sizeOfT @a * natValue @n

    alignment _ = alignmentT @a

    -- Wrap the buffer that 'bufferPackPtr' builds from the vector size in
    -- bytes at the given address
    peekIO src = do
      buf <- bufferPackPtr (sizeOfT' @(Vector n a)) (castPtr src)
      return (Vector buf)

    -- Write the wrapped buffer at the given address with 'bufferPoke'
    pokeIO dst (Vector buf) = bufferPoke dst buf
-- | Yield a vector of @n@ elements from a vector of @m + n@ elements
--
-- The wrapped buffer is cut with 'bufferTake' using a byte count of
-- @SizeOf a * n@. @n@ comes first in the @forall@ so that it can be selected
-- with a type application (@take \@2 v@).
take :: forall n m a.
  ( KnownNat (SizeOf a * n)
  ) => Vector (m+n) a -> Vector n a
{-# INLINABLE take #-}
take (Vector buf) = Vector (bufferTake byteCount buf)
  where
    -- number of bytes occupied by n elements
    byteCount = natValue @(SizeOf a * n)
-- | Yield a vector of @m@ elements from a vector of @m + n@ elements
--
-- The wrapped buffer is cut with 'bufferDrop' using a byte count of
-- @SizeOf a * n@. @n@ comes first in the @forall@ so that it can be selected
-- with a type application (@drop \@2 v@).
drop :: forall n m a.
  ( KnownNat (SizeOf a * n)
  ) => Vector (m+n) a -> Vector m a
{-# INLINABLE drop #-}
drop (Vector buf) = Vector $ bufferDrop skipped buf
  where
    -- number of bytes occupied by the n dropped elements
    skipped = natValue @(SizeOf a * n)
-- | Read the element at the type-level index @i@ (e.g. @index \@2 v@)
--
-- The byte offset is computed at the type level by 'ElemOffset', which also
-- checks the index against the vector size @n@.
index :: forall i a n.
  ( KnownNat (ElemOffset a i n)
  , Storable a
  ) => Vector n a -> a
{-# INLINABLE index #-}
index (Vector buf) = bufferPeekStorableAt buf offset
  where
    -- offset of the element in bytes
    offset = natValue @(ElemOffset a i n)
-- | Convert a list into a vector
--
-- Returns 'Nothing' when the list length differs from @n@. A vector of zero
-- elements wraps 'emptyBuffer'; otherwise the elements are packed with
-- 'bufferPackStorableList'.
fromList :: forall a (n :: Nat) .
  ( KnownNat n
  , Storable a
  ) => [a] -> Maybe (Vector n a)
{-# INLINABLE fromList #-}
fromList xs =
  if | expected /= actual -> Nothing
     | expected == 0      -> Just (Vector emptyBuffer)
     | otherwise          -> Just (Vector (bufferPackStorableList xs))
  where
    -- element count required by the type
    expected = natValue' @n
    -- element count of the given list
    actual = fromIntegral (List.length xs)
-- | Convert a list into a vector, padding or truncating it to @n@ elements
--
-- The first @n@ elements of the list are used. When the list is shorter, the
-- missing elements are set to the filler value given as first argument.
fromFilledList :: forall a (n :: Nat) .
  ( KnownNat n
  , Storable a
  ) => a -> [a] -> Vector n a
{-# INLINABLE fromFilledList #-}
fromFilledList filler xs =
  Vector (bufferPackStorableList (List.take (natValue @n) padded))
  where
    -- the input followed by an endless supply of filler values
    padded = xs ++ repeat filler
-- | Same as 'fromFilledList' except that at most @n - 1@ elements are taken
-- from the list
--
-- The list is truncated to @n - 1@ elements before being handed to
-- 'fromFilledList', so the remaining slots are set to the filler value.
fromFilledListZ :: forall a (n :: Nat) .
  ( KnownNat n
  , Storable a
  ) => a -> [a] -> Vector n a
{-# INLINABLE fromFilledListZ #-}
fromFilledListZ z = fromFilledList z . List.take (natValue @n - 1)
-- | Convert a vector into the list of its @n@ elements
--
-- Elements are read from the wrapped buffer at offsets that are multiples of
-- the element size.
toList :: forall a (n :: Nat) .
  ( KnownNat n
  , Storable a
  ) => Vector n a -> [a]
{-# INLINABLE toList #-}
toList (Vector buf)
  -- handled separately: @count - 1@ must not be computed for an empty vector
  | count == 0 = []
  | otherwise  = [ bufferPeekStorableAt buf (elemSize * i) | i <- [0 .. count - 1] ]
  where
    count    = natValue @n
    elemSize = sizeOfT' @a
-- | Build a vector whose @n@ elements are all equal to the given value
--
-- Implemented as 'fromFilledList' applied to an empty list: every slot
-- receives the filler value.
replicate :: forall a (n :: Nat) .
  ( KnownNat n
  , Storable a
  ) => a -> Vector n a
{-# INLINABLE replicate #-}
replicate = (`fromFilledList` [])
-- | Tag selecting the 'Apply' instance that stores a vector in memory
--
-- It is the function argument given to 'hFoldr' in 'concat'.
data StoreVector = StoreVector
-- | Store a vector just before the address yielded by the given action
--
-- The address is moved back by the size of the vector in bytes (element count
-- times element size), the vector is poked there and the new address is
-- returned.
instance forall n v a r.
  ( v ~ Vector n a
  , r ~ IO (Ptr a)
  , KnownNat n
  , KnownNat (SizeOf a)
  , StaticStorable a
  , Storable a
  ) => Apply StoreVector (v, IO (Ptr a)) r
  where
    apply _ (vec, getEnd) = do
      end <- getEnd
      let
        count    = natValue @n
        byteSize = count * fromIntegral (sizeOfT @a)
        -- the vector ends where the previously stored data begins
        start    = end `plusPtr` negate byteSize
      poke (castPtr start) vec
      return start
-- | Sum of the element counts of a type-level list of vector types
--
-- The empty list has size 0; the element type of each vector is ignored.
type family WholeSize vs :: Nat where
  WholeSize '[]                    = 0
  WholeSize (Vector len e ': rest) = len + WholeSize rest
-- | Concatenate the vectors of a heterogeneous list into a single vector
--
-- The element count of the result is the sum of the element counts of the
-- inputs ('WholeSize'). A block of the total size is allocated with
-- 'mallocBytes', then the vectors are stored by a right fold using
-- 'StoreVector', starting from the end of the block. Finally the block is
-- handed to 'bufferUnsafePackPtr'.
--
-- NOTE(review): 'bufferUnsafePackPtr' presumably takes ownership of the
-- allocated memory, which is what would make the 'unsafePerformIO' acceptable
-- here — confirm.
concat :: forall l (n :: Nat) a .
  ( n ~ WholeSize l
  , KnownNat n
  , Storable a
  , StaticStorable a
  , HFoldr StoreVector (IO (Ptr a)) l (IO (Ptr a))
  )
  => HList l -> Vector n a
concat vs = unsafePerformIO $ do
  base <- mallocBytes (fromIntegral byteSize) :: IO (Ptr ())
  -- initial accumulator of the fold: the address just past the block
  let end = return (castPtr base `plusPtr` fromIntegral byteSize) :: IO (Ptr a)
  _ <- hFoldr StoreVector end vs :: IO (Ptr a)
  buf <- bufferUnsafePackPtr (fromIntegral byteSize) base
  return (Vector buf)
  where
    -- size of the resulting vector in bytes
    byteSize = sizeOfT @a * natValue @n
-- | Combine two vectors of the same size element-wise
--
-- Both vectors are converted into lists which are zipped with the given
-- function; the result is packed again with 'fromList'. The element count is
-- unchanged, hence the 'fromJust'.
zipWith ::
  ( KnownNat n
  , Storable a
  , Storable b
  , Storable c
  ) => (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f u v = fromJust (fromList (List.zipWith f (toList u) (toList v)))
-- | Apply a function to every element of a vector
--
-- The vector is converted into a list, mapped and packed again with
-- 'fromList'. The element count is unchanged, hence the 'fromJust'.
map ::
  ( KnownNat n
  , Storable a
  , Storable b
  ) => (a -> b) -> Vector n a -> Vector n b
map f v = fromJust (fromList [ f x | x <- toList v ])
-- | Two vectors are equal when their elements are pairwise equal
--
-- The comparison is performed on the element lists, from the first element
-- on. Both lists always have @n@ elements, so zipping them loses nothing.
instance
  ( KnownNat n
  , Storable a
  , Eq a
  )
  => Eq (Vector n a)
  where
    u == v = and (List.zipWith (==) (toList u) (toList v))
-- | Bitwise operations are applied element-wise with 'zipWith'
instance
  ( KnownNat n
  , Bitwise a
  , Storable a
  ) => Bitwise (Vector n a)
  where
    (.&.) = zipWith (.&.)
    (.|.) = zipWith (.|.)
    xor   = zipWith xor
-- | A vector is a bit field of @n * BitSize a@ bits
--
-- Leading zeros are counted from the first element and trailing zeros from
-- the last one.
instance
  ( KnownNat (BitSize a)
  , FiniteBits a
  , KnownNat n
  , Storable a
  ) => FiniteBits (Vector n a)
  where
    type BitSize (Vector n a) = n * BitSize a

    -- every element has all its bits cleared
    zeroBits = fromJust . fromList $ List.replicate (natValue @n) zeroBits

    -- every element has all its bits set
    oneBits = fromJust . fromList $ List.replicate (natValue @n) oneBits

    complement = map complement

    countLeadingZeros v = scan 0 (toList v)
      where
        elemBits = natValue @(BitSize a)
        -- accumulate the counts until an element which isn't entirely zero
        scan !acc elems = case elems of
          [] -> acc
          e : rest
            | zs == elemBits -> scan (acc + zs) rest
            | otherwise      -> acc + zs
            where
              zs = countLeadingZeros e

    countTrailingZeros v = scan 0 (reverse (toList v))
      where
        elemBits = natValue @(BitSize a)
        -- accumulate the counts until an element which isn't entirely zero
        scan !acc elems = case elems of
          [] -> acc
          e : rest
            | zs == elemBits -> scan (acc + zs) rest
            | otherwise      -> acc + zs
            where
              zs = countTrailingZeros e
-- | Bit shifts over the whole vector
--
-- The vector is handled as a single bit field whose most significant bits are
-- in the first element: a left shift moves bits towards the first element and
-- a right shift moves them towards the last one. Vacated bits are zero.
--
-- The checked shifts reduce the offset modulo the bit size of the vector
-- (@n * BitSize a@) before calling the unchecked ones. A vector without any
-- bit (e.g. @n = 0@) is returned unchanged: reducing modulo 0 would otherwise
-- raise a divide-by-zero exception.
instance
  ( Storable a
  , ShiftableBits a
  , Bitwise a
  , FiniteBits a
  , KnownNat (BitSize a)
  , KnownNat (n * BitSize a)
  , KnownNat n
  ) => ShiftableBits (Vector n a)
  where
    shiftL u c
      | total == 0 = u
      | otherwise  = uncheckedShiftL u (c `mod` total)
      where
        -- bit size of the whole vector
        total = natValue @(BitSize (Vector n a))

    shiftR u c
      | total == 0 = u
      | otherwise  = uncheckedShiftR u (c `mod` total)
      where
        -- bit size of the whole vector
        total = natValue @(BitSize (Vector n a))

    uncheckedShiftL u c = fromJust (fromList (go c n padded))
      where
        n  = natValue @n
        sa = natValue @(BitSize a)
        -- zero elements are shifted in after the last element
        padded = toList u ++ List.repeat zeroBits

        -- go s k xs: s = remaining shift, k = number of elements left to emit
        go _ 0 _  = []
        go 0 k xs = List.take k xs
        go s k xs
          -- shifting by a whole element: drop the first element
          | s >= sa   = go (s-sa) k (List.tail xs)
          -- otherwise combine the low bits of an element with the high bits
          -- of the next one
          | otherwise =
              let (x:y:zs) = xs
              in ((x `shiftL` s) .|. (y `shiftR` (sa-s))) : go s (k-1) (y:zs)

    uncheckedShiftR u c = fromJust (fromList (go c n padded))
      where
        n  = natValue @n
        sa = natValue @(BitSize a)
        -- a zero element is shifted in before the first element
        padded = zeroBits : toList u

        -- go s k xs: s = remaining shift, k = number of elements left to emit
        go _ 0 _  = []
        -- nothing left to shift: skip the padding element
        go 0 k xs = List.take k (List.tail xs)
        go s k xs
          -- shifting by a whole element: emit a zero element
          | s >= sa   = zeroBits : go (s-sa) (k-1) xs
          -- otherwise combine the low bits of an element with the high bits
          -- of the next one
          | otherwise =
              let (x:y:zs) = xs
              in ((x `shiftL` (sa-s)) .|. (y `shiftR` s)) : go s (k-1) (y:zs)
-- | Bit indexing over the whole vector
--
-- Bit 0 is the lowest bit of the last element: bit @i@ is bit @i mod sa@ of
-- the element at position @n - (i div sa) - 1@, where @sa@ is the bit size of
-- an element. Indexes past the end select no bit.
instance
  ( Storable a
  , IndexableBits a
  , FiniteBits a
  , KnownNat (BitSize a)
  , KnownNat n
  , Bitwise a
  ) => IndexableBits (Vector n a)
  where
    popCount v = sum [ popCount e | e <- toList v ]

    bit i = fromJust (fromList elems)
      where
        n      = natValue @n
        sa     = natValue @(BitSize a)
        -- q: number of elements following the one holding the bit
        -- r: index of the bit in its element
        (q, r) = i `divMod` sa
        elems
          -- out of range: no bit set
          | i >= n * sa = List.replicate (fromIntegral n) zeroBits
          | otherwise   =
              List.replicate (fromIntegral (n - q - 1)) zeroBits
                ++ [bit r]
                ++ List.replicate (fromIntegral q) zeroBits

    testBit u i
      -- out of range: the bit is considered clear
      | i >= n * sa = False
      | otherwise   = testBit (List.head (List.drop skipped (toList u))) r
      where
        n       = natValue @n
        sa      = natValue @(BitSize a)
        (q, r)  = i `divMod` sa
        -- number of elements preceding the one holding the bit
        skipped = fromIntegral (n - q - 1)
-- | Bit rotation for vectors
--
-- NOTE(review): no method is defined in the visible part of this instance;
-- presumably the default implementations of 'RotatableBits' are used — confirm.
instance
( Storable a
, Bits a
, KnownNat n
, KnownNat (n * BitSize a)
) => RotatableBits (Vector n a)