{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Vectors whose number of elements is tracked in the type
module Haskus.Format.Binary.Vector
   ( Vector (..)
   , vectorBuffer
   , take
   , drop
   , index
   , fromList
   , fromFilledList
   , fromFilledListZ
   , toList
   , replicate
   , concat
   )
where

import Prelude hiding (replicate, head, last, tail, init, map, length, drop, take, concat)

import System.IO.Unsafe (unsafePerformIO)

import qualified Haskus.Utils.List as List
import Haskus.Utils.Types
import Haskus.Utils.HList
import Haskus.Format.Binary.Storable
import Haskus.Format.Binary.Ptr
import Haskus.Format.Binary.Buffer

-- | A vector of @n@ elements of type @a@, represented by a 'Buffer'.
--
-- Both @n@ and @a@ are phantom parameters: the constructor only wraps the
-- buffer.
data Vector (n :: Nat) a = Vector Buffer

-- | Rendered as @fromList@ followed by the shown list of elements.
instance (Storable a, Show a, KnownNat n) => Show (Vector n a) where
   show = ("fromList " ++) . show . toList

-- | Extract the buffer wrapped by the vector
vectorBuffer :: Vector n a -> Buffer
vectorBuffer vec = case vec of
   Vector buf -> buf

-- | Byte offset of the element at position @idx@ in a stored vector of @len@
-- elements of type @e@.
--
-- Reduces to @idx * SizeOf e@ when @idx + 1 <= len@, and to a custom type
-- error otherwise.
type family ElemOffset e idx len where
   ElemOffset e idx len =
      IfNat (idx+1 <=? len)
         (idx * (SizeOf e))
         (TypeError
            (     'Text "Invalid vector index: " ':<>: 'ShowType idx
            ':$$: 'Text "Vector size: "          ':<>: 'ShowType len
            ))

-- | Statically sized storage: the size is @SizeOf a * n@ and the alignment is
-- the alignment of the element type.
instance forall a n.
   ( KnownNat (SizeOf a * n)
   ) => StaticStorable (Vector n a)
   where
      type SizeOf (Vector n a)    = SizeOf a * n
      type Alignment (Vector n a) = Alignment a

      -- Pack @SizeOf a * n@ bytes found at the pointer into a buffer
      staticPeekIO src =
         fmap Vector (bufferPackPtr (natValue @(SizeOf a * n)) (castPtr src))

      -- Poke the backing buffer at the pointer
      staticPokeIO dst (Vector buf) = bufferPoke dst buf

-- | Storage with a size computed from the type-level element count.
instance forall a n.
   ( KnownNat n
   , Storable a
   ) => Storable (Vector n a)
   where
      -- Element count times the size of a single element
      sizeOf _    = natValue @n * sizeOfT @a

      -- Same alignment as a single element
      alignment _ = alignmentT @a

      -- Pack as many bytes as the vector size into a buffer
      peekIO src  =
         fmap Vector (bufferPackPtr (sizeOfT' @(Vector n a)) (castPtr src))

      -- Poke the backing buffer at the pointer
      pokeIO dst (Vector buf) = bufferPoke dst buf

-- | Keep the first @n@ elements of a vector of @m+n@ elements
take :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector n a
{-# INLINE take #-}
take (Vector buf) = Vector $ bufferTake (natValue @(SizeOf a * n)) buf

-- | Discard the first @n@ elements of a vector of @m+n@ elements
drop :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector m a
{-# INLINE drop #-}
drop (Vector buf) = Vector $ bufferDrop (natValue @(SizeOf a * n)) buf

-- | Read the element at the type-level index @i@.
--
-- The byte offset is computed by 'ElemOffset', which yields a type error
-- when @i@ is not a valid index for a vector of size @n@.
index :: forall i a n.
   ( KnownNat (ElemOffset a i n)
   , Storable a
   ) => Vector n a -> a
{-# INLINE index #-}
index (Vector buf) = bufferPeekStorableAt buf offset
   where
      offset = natValue @(ElemOffset a i n)

-- | Build a vector from a list.
--
-- Returns 'Nothing' when the list length differs from @n@. An empty vector
-- is backed by 'emptyBuffer'.
fromList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => [a] -> Maybe (Vector n a)
{-# INLINE fromList #-}
fromList xs
   | expected /= actual = Nothing
   | otherwise          = Just (Vector contents)
   where
      expected = natValue' @n
      actual   = fromIntegral (List.length xs)
      contents
         | expected == 0 = emptyBuffer
         | otherwise     = bufferPackStorableList xs

-- | Build a vector from at most @n@ leading elements of the list, using the
-- given filler value for the missing ones.
fromFilledList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
{-# INLINE fromFilledList #-}
fromFilledList filler xs = Vector (bufferPackStorableList padded)
   where
      -- the infinite filler tail guarantees that n elements are available
      padded = List.take (natValue @n) (xs ++ repeat filler)

-- | Build a vector from at most @n-1@ leading elements of the list, using the
-- given filler value for the remaining ones.
fromFilledListZ :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
{-# INLINE fromFilledListZ #-}
fromFilledListZ filler xs =
   fromFilledList filler (List.take (natValue @n - 1) xs)

-- | List the elements of a vector, in index order
toList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => Vector n a -> [a]
{-# INLINE toList #-}
toList (Vector buf)
   | count == 0 = []
   | otherwise  =
      [ bufferPeekStorableAt buf (elemSize * i)
      | i <- [0 .. count - 1]
      ]
   where
      count    = natValue @n
      elemSize = sizeOfT' @a

-- | Build a vector whose @n@ elements are all the given value
replicate :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> Vector n a
{-# INLINE replicate #-}
replicate filler = filler `fromFilledList` []

-- | Function tag for the fold performed by 'concat'
data StoreVector = StoreVector

-- | Write a vector just before the pointer produced by the accumulator.
--
-- The pointer is moved back by the byte size of the vector (element count
-- times element size), the vector is poked there and the moved pointer is
-- returned.
instance forall n v a r.
   ( v ~ Vector n a
   , r ~ IO (Ptr a)
   , KnownNat n
   , KnownNat (SizeOf a)
   , StaticStorable a
   , Storable a
   ) => Apply StoreVector (v, IO (Ptr a)) r
   where
      apply _ (vec, nextPtr) = do
         end <- nextPtr
         let count = natValue @n
             start = end `indexPtr'` (-1 * count * sizeOfT @a)
         poke (castPtr start) vec
         pure start

-- | Sum of the sizes of the vectors in a type-level list
type family WholeSize vs :: Nat where
   WholeSize '[]                  = 0
   WholeSize (Vector k e ': rest) = k + WholeSize rest

-- | Concatenate the vectors of a heterogeneous list into a single vector.
--
-- A block of @sizeOfT \@a * n@ bytes is allocated with 'mallocBytes'. The
-- list is then folded with 'StoreVector', starting from a pointer placed at
-- the end of the block, so that each step writes its vector just before the
-- previous pointer. The block is finally wrapped with
-- 'bufferUnsafePackPtr'.
--
-- NOTE(review): presumably 'bufferUnsafePackPtr' takes ownership of the
-- allocated block, which would make the 'unsafePerformIO' benign -- confirm.
concat :: forall l (n :: Nat) a .
   ( n ~ WholeSize l
   , KnownNat n
   , Storable a
   , StaticStorable a
   , HFoldr StoreVector (IO (Ptr a)) l (IO (Ptr a))
   )
   => HList l -> Vector n a
concat vecs = unsafePerformIO $ do
   let total = sizeOfT @a * natValue @n
   block <- mallocBytes (fromIntegral total) :: IO (Ptr ())
   _ <- hFoldr StoreVector (pure (castPtr block `indexPtr'` total) :: IO (Ptr a)) vecs :: IO (Ptr a)
   fmap Vector (bufferUnsafePackPtr (fromIntegral total) block)