{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Primitive.Vec (
Vec(..),
Vec2, pattern Vec2,
Vec3, pattern Vec3,
Vec4, pattern Vec4,
Vec8, pattern Vec8,
Vec16, pattern Vec16,
listOfVec,
liftVec,
) where
import Control.Monad.ST
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Data.Text.Prettyprint.Doc
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import GHC.Base ( isTrue# )
import GHC.Int
import GHC.Prim
import GHC.ST ( ST(..) )
import GHC.TypeLits
import GHC.Word
-- | A vector of @n@ lanes of type @a@, stored unboxed in a raw 'ByteArray#'.
--
-- The lane count @n@ is tracked only at the type level; nothing in this
-- declaration checks that the payload's byte size matches @n@ lanes of @a@.
data Vec (n :: Nat) a = Vec ByteArray#
-- 'n' is nominal, so a vector cannot be 'coerce'd to a different length;
-- 'a' is representational, so lanes may be coerced to representationally
-- equal types.
type role Vec nominal representational
-- | Pretty-print a vector as its lanes between angle brackets, separated by
-- commas (e.g. @<1, 2, 3>@ when the whole thing fits on one line).
instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where
  show v = show doc
    where
      -- 'group' tries the flat layout first; 'flatAlt' chooses the tight
      -- bracket in flat mode and the padded one when the layout breaks.
      doc   = group (encloseSep open close ", " (map viaShow (listOfVec v)))
      open  = flatAlt "< " "<"
      close = flatAlt " >" ">"
-- | Read all @n@ lanes of a vector into a list, in index order.
--
-- The lane count comes from the type-level @n@, not from the payload size.
listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a]
listOfVec (Vec ba#) = [ indexByteArray payload i | i <- [0 .. lanes - 1] ]
  where
    payload = ByteArray ba#
    lanes   = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Int
-- | Equality on the raw payload: delegates to the 'Eq' instance of the
-- underlying 'ByteArray', independent of the lane type @a@.
--
-- NOTE(review): because this compares bytes rather than lanes, it may
-- disagree with lane-wise '(==)' for floating-point lanes (e.g. NaN, or
-- 0.0 vs -0.0) — confirm this is intended.
instance Eq (Vec n a) where
  (==) (Vec lhs#) (Vec rhs#) = (==) (ByteArray lhs#) (ByteArray rhs#)
-- Fixed-width aliases. Each alias has a bidirectional pattern synonym of the
-- same name in this module for building and matching vectors of that width.
type Vec2 a = Vec 2 a
type Vec3 a = Vec 3 a
type Vec4 a = Vec 4 a
type Vec8 a = Vec 8 a
type Vec16 a = Vec 16 a
-- | Two-lane vector pattern: matching reads the lanes via 'unpackVec2';
-- using it as an expression allocates a new payload via 'packVec2'.
pattern Vec2 :: Prim a => a -> a -> Vec2 a
pattern Vec2 x0 x1 <- (unpackVec2 -> (x0, x1))
  where
    Vec2 x0 x1 = packVec2 x0 x1
{-# COMPLETE Vec2 #-}
-- | Three-lane vector pattern: matching reads the lanes via 'unpackVec3';
-- using it as an expression allocates a new payload via 'packVec3'.
pattern Vec3 :: Prim a => a -> a -> a -> Vec3 a
pattern Vec3 x0 x1 x2 <- (unpackVec3 -> (x0, x1, x2))
  where
    Vec3 x0 x1 x2 = packVec3 x0 x1 x2
{-# COMPLETE Vec3 #-}
-- | Four-lane vector pattern: matching reads the lanes via 'unpackVec4';
-- using it as an expression allocates a new payload via 'packVec4'.
pattern Vec4 :: Prim a => a -> a -> a -> a -> Vec4 a
pattern Vec4 x0 x1 x2 x3 <- (unpackVec4 -> (x0, x1, x2, x3))
  where
    Vec4 x0 x1 x2 x3 = packVec4 x0 x1 x2 x3
{-# COMPLETE Vec4 #-}
-- | Eight-lane vector pattern: matching reads the lanes via 'unpackVec8';
-- using it as an expression allocates a new payload via 'packVec8'.
pattern Vec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a
pattern Vec8 x0 x1 x2 x3 x4 x5 x6 x7
  <- (unpackVec8 -> (x0, x1, x2, x3, x4, x5, x6, x7))
  where
    Vec8 x0 x1 x2 x3 x4 x5 x6 x7 = packVec8 x0 x1 x2 x3 x4 x5 x6 x7
{-# COMPLETE Vec8 #-}
-- | Sixteen-lane vector pattern: matching reads the lanes via 'unpackVec16';
-- using it as an expression allocates a new payload via 'packVec16'.
pattern Vec16
  :: Prim a
  => a -> a -> a -> a -> a -> a -> a -> a
  -> a -> a -> a -> a -> a -> a -> a -> a
  -> Vec16 a
pattern Vec16 x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15
  <- (unpackVec16 -> ( x0, x1, x2,  x3,  x4,  x5,  x6,  x7
                     , x8, x9, x10, x11, x12, x13, x14, x15 ))
  where
    Vec16 x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15 =
      packVec16 x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15
{-# COMPLETE Vec16 #-}
-- | Read both lanes of a two-lane vector, in index order.
unpackVec2 :: Prim a => Vec2 a -> (a,a)
unpackVec2 (Vec arr#) = (lane 0#, lane 1#)
  where
    lane i# = indexByteArray# arr# i#
-- | Read all three lanes of a three-lane vector, in index order.
unpackVec3 :: Prim a => Vec3 a -> (a,a,a)
unpackVec3 (Vec arr#) = (lane 0#, lane 1#, lane 2#)
  where
    lane i# = indexByteArray# arr# i#
-- | Read all four lanes of a four-lane vector, in index order.
unpackVec4 :: Prim a => Vec4 a -> (a,a,a,a)
unpackVec4 (Vec arr#) = (lane 0#, lane 1#, lane 2#, lane 3#)
  where
    lane i# = indexByteArray# arr# i#
-- | Read all eight lanes of an eight-lane vector, in index order.
unpackVec8 :: Prim a => Vec8 a -> (a,a,a,a,a,a,a,a)
unpackVec8 (Vec arr#) =
  ( lane 0#, lane 1#, lane 2#, lane 3#
  , lane 4#, lane 5#, lane 6#, lane 7#
  )
  where
    lane i# = indexByteArray# arr# i#
-- | Read all sixteen lanes of a sixteen-lane vector, in index order.
unpackVec16 :: Prim a => Vec16 a -> (a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a)
unpackVec16 (Vec arr#) =
  ( lane 0#,  lane 1#,  lane 2#,  lane 3#
  , lane 4#,  lane 5#,  lane 6#,  lane 7#
  , lane 8#,  lane 9#,  lane 10#, lane 11#
  , lane 12#, lane 13#, lane 14#, lane 15#
  )
  where
    lane i# = indexByteArray# arr# i#
-- | Allocate a fresh two-lane vector holding the arguments in lane order.
packVec2 :: Prim a => a -> a -> Vec2 a
packVec2 x0 x1 = runST $ do
  buf <- newByteArray (2 * sizeOf x0)
  let put = writeByteArray buf
  put 0 x0
  put 1 x1
  -- 'buf' is never touched after this point, so freezing in place is safe.
  ByteArray frozen# <- unsafeFreezeByteArray buf
  return $! Vec frozen#
-- | Allocate a fresh three-lane vector holding the arguments in lane order.
packVec3 :: Prim a => a -> a -> a -> Vec3 a
packVec3 x0 x1 x2 = runST $ do
  buf <- newByteArray (3 * sizeOf x0)
  let put = writeByteArray buf
  put 0 x0
  put 1 x1
  put 2 x2
  -- 'buf' is never touched after this point, so freezing in place is safe.
  ByteArray frozen# <- unsafeFreezeByteArray buf
  return $! Vec frozen#
-- | Allocate a fresh four-lane vector holding the arguments in lane order.
packVec4 :: Prim a => a -> a -> a -> a -> Vec4 a
packVec4 x0 x1 x2 x3 = runST $ do
  buf <- newByteArray (4 * sizeOf x0)
  let put = writeByteArray buf
  put 0 x0
  put 1 x1
  put 2 x2
  put 3 x3
  -- 'buf' is never touched after this point, so freezing in place is safe.
  ByteArray frozen# <- unsafeFreezeByteArray buf
  return $! Vec frozen#
-- | Allocate a fresh eight-lane vector holding the arguments in lane order.
packVec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a
packVec8 x0 x1 x2 x3 x4 x5 x6 x7 = runST $ do
  buf <- newByteArray (8 * sizeOf x0)
  let put = writeByteArray buf
  put 0 x0
  put 1 x1
  put 2 x2
  put 3 x3
  put 4 x4
  put 5 x5
  put 6 x6
  put 7 x7
  -- 'buf' is never touched after this point, so freezing in place is safe.
  ByteArray frozen# <- unsafeFreezeByteArray buf
  return $! Vec frozen#
-- | Allocate a fresh sixteen-lane vector holding the arguments in lane order.
packVec16
  :: Prim a
  => a -> a -> a -> a -> a -> a -> a -> a
  -> a -> a -> a -> a -> a -> a -> a -> a
  -> Vec16 a
packVec16 x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15 = runST $ do
  buf <- newByteArray (16 * sizeOf x0)
  let put = writeByteArray buf
  put 0  x0
  put 1  x1
  put 2  x2
  put 3  x3
  put 4  x4
  put 5  x5
  put 6  x6
  put 7  x7
  put 8  x8
  put 9  x9
  put 10 x10
  put 11 x11
  put 12 x12
  put 13 x13
  put 14 x14
  put 15 x15
  -- 'buf' is never touched after this point, so freezing in place is safe.
  ByteArray frozen# <- unsafeFreezeByteArray buf
  return $! Vec frozen#
-- | Lift a vector into a typed Template Haskell expression.
--
-- The payload bytes are embedded in the generated code as a primitive string
-- literal ('StringPrimL'). When evaluated, the generated expression:
--
-- 1. allocates a fresh mutable byte array of the same size,
-- 2. copies the literal into it, and
-- 3. freezes the result into a new 'Vec'.
liftVec :: Vec n a -> Q (TExp (Vec n a))
liftVec (Vec ba#)
  = unsafeTExpCoerce
    -- 'runST' expects an 'ST' action, not a raw state-passing function, so
    -- the state-token lambda is wrapped in the 'ST' constructor. Passing the
    -- bare lambda is a type error. It went unnoticed because untyped quotes
    -- are only typechecked where they are spliced.
    [| runST (ST (\s ->
         case newByteArray# $(liftInt# n#) s of { (# s1, mba# #) ->
         case copyAddrToByteArray# $(litE (StringPrimL bytes)) mba# 0# $(liftInt# n#) s1 of { s2 ->
         case unsafeFreezeByteArray# mba# s2 of { (# s3, ba'# #) ->
           (# s3, Vec ba'# #)
         }}}))
     |]
  where
    -- The source payload, one byte at a time, in index order.
    bytes :: [Word8]
    bytes = go 0#
      where
        go i# | isTrue# (i# <# n#) = W8# (indexWord8Array# ba# i#) : go (i# +# 1#)
              | otherwise          = []

    -- Size of the source payload in bytes.
    n# = sizeofByteArray# ba#
-- | Splice an unboxed 'Int#' into generated code as a primitive integer
-- literal expression.
liftInt# :: Int# -> ExpQ
liftInt# i# = litE . IntPrimL . toInteger $ I# i#