module LLVM.Extra.Storable.Private where
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative.HT as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (foldM, replicateM, replicateM_, (<=<))
import Control.Applicative (Applicative, pure)
import qualified Foreign.Storable.Record.Tuple as StoreTuple
import qualified Foreign.Storable as Store
import Foreign.Storable.FixedArray (roundUp)
import Foreign.Ptr (Ptr)
import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Orphans ()
import Data.Complex (Complex)
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Bool8 (Bool8)
{- |
Class of types @a@ for which code can be generated
that reads or writes a value through a @Ptr a@.
-}
class
   (Store.Storable a,
    Tuple.Value a,
    Tuple.Phi (Tuple.ValueOf a),
    Tuple.Undefined (Tuple.ValueOf a)) =>
      C a where
   -- | Generate code that reads the value found at the pointer.
   load ::
      Value (Ptr a) ->
      CodeGenFunction r (Tuple.ValueOf a)
   -- | Generate code that writes the value to the pointer.
   store ::
      Tuple.ValueOf a ->
      Value (Ptr a) ->
      CodeGenFunction r ()
-- | Store a value at the pointer, then return the pointer
-- moved forward by one element (see 'incrementPtr').
storeNext ::
   (C a, Tuple.ValueOf a ~ al, Value (Ptr a) ~ ptr) =>
   al -> ptr -> CodeGenFunction r ptr
storeNext a ptr = do
   store a ptr
   incrementPtr ptr
-- | Load the value at the pointer, transform it with the given
-- code generating function and store the result at the same pointer.
modify ::
   (C a, Tuple.ValueOf a ~ al) =>
   (al -> CodeGenFunction r al) ->
   Value (Ptr a) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr
-- | Variant of 'load' that wraps the result in 'MultiValue.Cons'.
loadMultiValue ::
   (C a) => Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
loadMultiValue = fmap MultiValue.Cons . load
-- | Variant of 'store' that unwraps a 'MultiValue.Cons' first.
storeMultiValue ::
   (C a) => MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()
storeMultiValue (MultiValue.Cons a) = store a
-- | Variant of 'storeNext' for 'MultiValue.T':
-- store the unwrapped value, then return the pointer moved forward by one element.
storeNextMultiValue ::
   (C a) => MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
storeNextMultiValue (MultiValue.Cons a) ptr = do
   store a ptr
   incrementPtr ptr
-- | Variant of 'modify' for 'MultiValue.T':
-- load, transform and store back through the same pointer.
modifyMultiValue ::
   (C a) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   Value (Ptr a) -> CodeGenFunction r ()
modifyMultiValue f ptr = do
   old <- loadMultiValue ptr
   new <- f old
   storeMultiValue new ptr
-- | 'load' implementation for types that LLVM can load directly:
-- bitcast the pointer and issue a single 'LLVM.load'.
loadPrimitive ::
   (LLVM.Storable a) => Value (Ptr a) -> CodeGenFunction r (Value a)
loadPrimitive ptr = do
   llvmPtr <- LLVM.bitcast ptr
   LLVM.load llvmPtr
-- | 'store' implementation for types that LLVM can store directly:
-- bitcast the pointer and issue a single 'LLVM.store'.
storePrimitive ::
   (LLVM.Storable a) => Value a -> Value (Ptr a) -> CodeGenFunction r ()
storePrimitive a ptr = do
   llvmPtr <- LLVM.bitcast ptr
   LLVM.store a llvmPtr
-- Floating point and integer types all delegate
-- to 'loadPrimitive' and 'storePrimitive'.

instance C Float where
   load = loadPrimitive
   store = storePrimitive

instance C Double where
   load = loadPrimitive
   store = storePrimitive

instance C Word where
   load = loadPrimitive
   store = storePrimitive

instance C Word8 where
   load = loadPrimitive
   store = storePrimitive

instance C Word16 where
   load = loadPrimitive
   store = storePrimitive

instance C Word32 where
   load = loadPrimitive
   store = storePrimitive

instance C Word64 where
   load = loadPrimitive
   store = storePrimitive

instance C Int where
   load = loadPrimitive
   store = storePrimitive

instance C Int8 where
   load = loadPrimitive
   store = storePrimitive

instance C Int16 where
   load = loadPrimitive
   store = storePrimitive

instance C Int32 where
   load = loadPrimitive
   store = storePrimitive

instance C Int64 where
   load = loadPrimitive
   store = storePrimitive
{- |
A 'Bool' occupies @Store.sizeOf (False :: Bool)@ bytes.
Loading reads all of these bytes and yields true if any bit is set.
Storing writes the same byte to every position.
-}
instance C Bool where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      -- read the bytes one after another, advancing the pointer in the state
      bytes <-
         MS.evalStateT
            (replicateM (Store.sizeOf (False :: Bool))
               (incPtrState >>= MT.lift . LLVM.load))
            bytePtr
      let zero = LLVM.valueOf 0
      -- OR all bytes together; a non-zero result means True
      mask <- foldM A.or zero bytes
      A.cmp LLVM.CmpNE mask zero
   store b ptr = do
      bytePtr <- castToBytePtr ptr
      -- NOTE(review): sign extension is used here,
      -- whereas the Bool8 instance uses zext - confirm this is intended.
      byte <- LLVM.sext b
      MS.evalStateT
         (replicateM_ (Store.sizeOf (False :: Bool))
            (incPtrState >>= MT.lift . LLVM.store byte))
         bytePtr
-- | Return the byte pointer currently held in the state and replace it
-- by the result of 'A.advanceArrayElementPtr'.
incPtrState :: MS.StateT BytePtr (CodeGenFunction r) BytePtr
incPtrState =
   MS.StateT $ \current -> do
      next <- A.advanceArrayElementPtr current
      return (current, next)
-- | 'Bool8' is accessed as a single byte:
-- loading compares the byte with zero, storing zero-extends the flag.
instance C Bool8 where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      byte <- LLVM.load bytePtr
      A.cmp LLVM.CmpNE (LLVM.valueOf 0) byte
   store b ptr = do
      byte <- LLVM.zext b
      bytePtr <- castToBytePtr ptr
      LLVM.store byte bytePtr
-- | The components of a complex number are accessed one after another
-- as elements of type @a@ via 'loadApplicative' and 'storeFoldable'.
instance (C a) => C (Complex a) where
   load = loadApplicative
   store = storeFoldable
-- | Wrapped tuples delegate to the methods of the 'Tuple' class.
instance (Tuple tuple) => C (StoreTuple.Tuple tuple) where
   load ptr = loadTuple ptr
   store a ptr = storeTuple a ptr
{- |
Class of tuple types that can be loaded and stored
through a pointer to a 'StoreTuple.Tuple'.
-}
class
   (StoreTuple.Storable tuple,
    Tuple.Value tuple,
    Tuple.Phi (Tuple.ValueOf tuple),
    Tuple.Undefined (Tuple.ValueOf tuple)) =>
      Tuple tuple where
   -- | Generate code that reads all tuple members.
   loadTuple ::
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r (Tuple.ValueOf tuple)
   -- | Generate code that writes all tuple members.
   storeTuple ::
      Tuple.ValueOf tuple ->
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r ()
-- | Pairs: the members are visited in order,
-- each one at the offset computed by 'elementOffset'.
instance (C a, C b) => Tuple (a,b) where
   loadTuple ptr =
      let proxies = FuncHT.unzip (proxyFromElement3 ptr)
      in  runElements ptr (App.mapPair (loadElement, loadElement) proxies)
   storeTuple (a,b) ptr =
      let (pa,pb) = FuncHT.unzip (proxyFromElement3 ptr)
      in  runElements ptr $ do
             storeElement pa a
             storeElement pb b
-- | Triples: the members are visited in order,
-- each one at the offset computed by 'elementOffset'.
instance (C a, C b, C c) => Tuple (a,b,c) where
   loadTuple ptr =
      let proxies = FuncHT.unzip3 (proxyFromElement3 ptr)
      in  runElements ptr
             (App.mapTriple (loadElement, loadElement, loadElement) proxies)
   storeTuple (a,b,c) ptr =
      let (pa,pb,pc) = FuncHT.unzip3 (proxyFromElement3 ptr)
      in  runElements ptr $ do
             storeElement pa a
             storeElement pb b
             storeElement pc c
-- | Run an element-wise action: the reader environment is the given pointer
-- cast to a byte pointer, the state is a byte offset starting at zero.
runElements ::
   Value (Ptr a) ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) c ->
   CodeGenFunction r c
runElements ptr act = do
   bytePtr <- castToBytePtr ptr
   MS.evalStateT (MR.runReaderT act bytePtr) 0
-- | Load the next member: obtain its pointer with 'elementPtr'
-- and 'load' from it.
loadElement ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) (Tuple.ValueOf a)
loadElement proxy = do
   ptr <- elementPtr proxy
   MT.lift $ MT.lift $ load ptr
-- | Store the next member: obtain its pointer with 'elementPtr'
-- and 'store' the value there.
storeElement ::
   (C a) =>
   LP.Proxy a -> Tuple.ValueOf a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) ()
storeElement proxy a = do
   ptr <- elementPtr proxy
   MT.lift $ MT.lift $ store a ptr
-- | Compute the pointer to the next member:
-- take the base byte pointer from the reader environment,
-- add the offset delivered by 'elementOffset'
-- and cast the result to the member's pointer type.
elementPtr ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr
      (MS.StateT Int (CodeGenFunction r)) (LLVM.Value (Ptr a))
elementPtr proxy = do
   base <- MR.ask
   offset <- MT.lift $ elementOffset proxy
   MT.lift $ MT.lift $ do
      memberPtr <- LLVM.getElementPtr base (offset, ())
      castFromBytePtr memberPtr
-- | Round the running byte offset up to the alignment of the element type,
-- return that aligned offset and advance the state by the element's size.
elementOffset ::
   (Monad m, Store.Storable a) => LP.Proxy a -> MS.StateT Int m Int
elementOffset proxy =
   MS.state $ \pos ->
      let offset = roundUp (Store.alignment dummy) pos
      in  (offset, offset + Store.sizeOf dummy)
   where
      -- only used as a type witness for 'Store.alignment' and 'Store.sizeOf'
      dummy = elementFromProxy proxy
-- | Vectors are accessed element by element in memory
-- and converted from and to the vector value representation
-- by the methods of the 'Vector' class.
instance
   (TypeNum.Positive n, Vector a, Tuple.VectorValue n a,
    Tuple.Phi (Tuple.VectorValueOf n a)) =>
      C (LLVM.Vector n a) where
   load ptr = do
      elems <- loadApplicative ptr
      assembleVector (proxyFromElement3 ptr) elems
   store a ptr = do
      elems <- disassembleVector (proxyFromElement3 ptr) a
      storeFoldable elems ptr
{- |
Class of element types that can be stored within an 'LLVM.Vector'.
The methods convert between a vector of element values
and the 'Tuple.VectorValueOf' representation.
-}
class (C a) => Vector a where
   -- | Combine individual element values to the vector representation.
   assembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   -- | Split the vector representation into individual element values.
   disassembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))
-- | Build an LLVM vector value from its elements:
-- start with an undefined vector and insert the elements
-- at the positions 0, 1, 2, ...
assemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   LLVM.Vector n (Value a) -> CodeGenFunction r (Value (LLVM.Vector n a))
assemblePrimitive xs =
   foldM
      (\vec (idx, x) -> LLVM.insertelement vec x (LLVM.valueOf idx))
      (LLVM.value LLVM.undef)
      (zip [0..] (Fold.toList xs))
-- | Extract every element of an LLVM vector value,
-- visiting the positions given by 'indices'.
disassemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   Value (LLVM.Vector n a) -> CodeGenFunction r (LLVM.Vector n (Value a))
disassemblePrimitive v =
   Trav.mapM (\idx -> LLVM.extractelement v (LLVM.valueOf idx)) indices
-- | A container whose slots are numbered 0, 1, 2, ... in traversal order.
indices :: (Applicative f, Trav.Traversable f) => f Word32
indices =
   MS.evalState (Trav.sequenceA $ pure counter) 0
   where
      -- yield the current count and increment it
      counter = MS.state $ \k -> (k, k+1)
-- All of the following element types delegate
-- to 'assemblePrimitive' and 'disassemblePrimitive'.

instance Vector Float where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Double where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word16 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word32 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Word64 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int16 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int32 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Int64 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Bool where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v

instance Vector Bool8 where
   assembleVector LP.Proxy xs = assemblePrimitive xs
   disassembleVector LP.Proxy v = disassemblePrimitive v
-- | Vectors of wrapped tuples: unwrap the proxy type with
-- 'StoreTuple.getTuple' and delegate to the 'TupleVector' methods.
instance
   (Tuple tuple, TupleVector tuple) =>
      Vector (StoreTuple.Tuple tuple) where
   assembleVector proxy = deinterleave (fmap StoreTuple.getTuple proxy)
   disassembleVector proxy = interleave (fmap StoreTuple.getTuple proxy)
{- |
Conversion between a vector of tuples
and the 'Tuple.VectorValueOf' representation of the tuple type.
-}
class TupleVector a where
   -- | Turn a vector of tuple values into the vector representation.
   deinterleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   -- | Turn the vector representation into a vector of tuple values.
   interleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a ->
      Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))
-- | Pairs: each component is assembled or disassembled on its own
-- and the results are paired up again.
instance (Vector a, Vector b) => TupleVector (a,b) where
   deinterleave proxy vec =
      let (pa,pb) = FuncHT.unzip proxy
          (a,b) = FuncHT.unzip vec
      in  App.lift2 (,) (assembleVector pa a) (assembleVector pb b)
   interleave proxy (a,b) =
      let (pa,pb) = FuncHT.unzip proxy
      in  App.lift2 (App.lift2 (,))
             (disassembleVector pa a)
             (disassembleVector pb b)
-- | Triples: each component is assembled or disassembled on its own
-- and the results are combined to a triple again.
instance (Vector a, Vector b, Vector c) => TupleVector (a,b,c) where
   deinterleave proxy vec =
      let (pa,pb,pc) = FuncHT.unzip3 proxy
          (a,b,c) = FuncHT.unzip3 vec
      in  App.lift3 (,,)
             (assembleVector pa a)
             (assembleVector pb b)
             (assembleVector pc c)
   interleave proxy (a,b,c) =
      let (pa,pb,pc) = FuncHT.unzip3 proxy
      in  App.lift3 (App.lift3 (,,))
             (disassembleVector pa a)
             (disassembleVector pb b)
             (disassembleVector pc c)
-- | The unit type: neither loading nor storing generates any code.
instance C () where
   load = const $ return ()
   store () = const $ return ()
-- | Load a wrapped value: cast the pointer to a pointer to the wrapped type
-- (the wrapper function only guides the types, see 'rmapPtr'),
-- load from it and apply the value-level wrapper to the result.
loadNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (al -> wrappedl) ->
   Value (Ptr wrapped) -> CodeGenFunction r wrappedl
loadNewtype wrap wrapl ptr = do
   inner <- rmapPtr wrap ptr
   fmap wrapl (load inner)
-- | Store a wrapped value: cast the pointer to a pointer to the wrapped type
-- (the wrapper function only guides the types, see 'rmapPtr')
-- and store the unwrapped value there.
storeNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (wrappedl -> al) ->
   wrappedl -> Value (Ptr wrapped) -> CodeGenFunction r ()
storeNewtype wrap unwrapl y ptr = do
   inner <- rmapPtr wrap ptr
   store (unwrapl y) inner
-- | Bitcast a pointer to @b@ into a pointer to @a@.
-- The function argument is ignored; it only determines the types.
rmapPtr :: (a -> b) -> Value (Ptr b) -> CodeGenFunction r (Value (Ptr a))
rmapPtr _ ptr = LLVM.bitcast ptr
-- | Load all elements of a container one after another.
-- The container shape is obtained from 'NonEmptyC.repeat'.
loadTraversable ::
   (NonEmptyC.Repeat f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadTraversable ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Trav.sequence (NonEmptyC.repeat loadState)) elemPtr
-- | Load all elements of a container one after another.
-- The container shape is obtained from 'pure'.
loadApplicative ::
   (Applicative f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadApplicative ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Trav.sequence (pure loadState)) elemPtr
-- | Load from the pointer held in the state
-- and leave the pointer advanced by one element.
loadState ::
   (C a, Tuple.ValueOf a ~ al) =>
   MS.StateT (Value (Ptr a)) (CodeGenFunction r) al
loadState = do
   ptr <- advancePtrState
   MT.lift $ load ptr
-- | Store all elements of a container one after another,
-- starting at the given pointer.
storeFoldable ::
   (Fold.Foldable f, C a, Tuple.ValueOf a ~ al) =>
   f al -> Value (Ptr (f a)) -> CodeGenFunction r ()
storeFoldable xs ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Fold.mapM_ storeState xs) elemPtr
-- | Store to the pointer held in the state
-- and leave the pointer advanced by one element.
storeState ::
   (C a, Tuple.ValueOf a ~ al) =>
   al -> MS.StateT (Value (Ptr a)) (CodeGenFunction r) ()
storeState a = do
   ptr <- advancePtrState
   MT.lift $ store a ptr
-- | Replace the state by the result of the given action applied to it
-- and return the state as it was before the replacement.
update :: (Monad m) => (a -> m a) -> MS.StateT a m a
update f =
   MS.StateT $ \old ->
      f old >>= \new -> return (old, new)
-- | Return the pointer held in the state
-- and replace it by the pointer to the following element.
advancePtrState ::
   (C a, Tuple.ValueOf a ~ al, Value (Ptr a) ~ ptr) =>
   MS.StateT ptr (CodeGenFunction r) ptr
advancePtrState = update (advancePtrStatic 1)
-- | Move a pointer by a number of elements that is computed at run-time.
-- The byte distance is the element count times @Store.sizeOf@ of the element.
advancePtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Value Int -> ptr -> CodeGenFunction r ptr
advancePtr n ptr =
   flip addPointer ptr =<< A.mul n (LLVM.valueOf elemSize)
   where
      elemSize = Store.sizeOf (elementFromPtr ptr)
-- | Move a pointer by a number of elements that is known
-- when the code is generated.
-- The byte distance is @Store.sizeOf@ of the element times the count.
advancePtrStatic ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Int -> ptr -> CodeGenFunction r ptr
advancePtrStatic n ptr =
   let byteCount = Store.sizeOf (elementFromPtr ptr) * n
   in  addPointer (LLVM.valueOf byteCount) ptr
-- | Move a pointer forward by one element.
incrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   ptr -> CodeGenFunction r ptr
incrementPtr ptr = advancePtrStatic 1 ptr
-- | Move a pointer backward by one element.
-- This is the inverse of 'incrementPtr'.
decrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   ptr -> CodeGenFunction r ptr
-- The element count must be negative; a count of 1 would move the pointer
-- forward exactly like 'incrementPtr'.
decrementPtr = advancePtrStatic (-1)
-- | Add a byte offset to a pointer:
-- cast to a byte pointer, index it with the offset and cast back.
addPointer :: Value Int -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
addPointer k ptr = do
   bytePtr <- castToBytePtr ptr
   moved <- LLVM.getElementPtr bytePtr (k, ())
   castFromBytePtr moved
-- | LLVM value holding a pointer to 'Word8',
-- used in this module for byte-wise address computations.
type BytePtr = Value (LLVM.Ptr Word8)
-- | Bitcast an arbitrary pointer to a byte pointer.
castToBytePtr :: Value (Ptr a) -> CodeGenFunction r BytePtr
castToBytePtr ptr = LLVM.bitcast ptr
-- | Bitcast a byte pointer to a pointer of the requested element type.
castFromBytePtr :: BytePtr -> CodeGenFunction r (Value (Ptr a))
castFromBytePtr bytePtr = LLVM.bitcast bytePtr
-- | Bitcast a pointer to a container into a pointer to its element type.
castElementPtr :: Value (Ptr (f a)) -> CodeGenFunction r (Value (Ptr a))
castElementPtr ptr = LLVM.bitcast ptr
-- | 'Store.sizeOf' for a type that is given by a proxy.
sizeOf :: (Store.Storable a) => LP.Proxy a -> Int
sizeOf proxy = Store.sizeOf (elementFromProxy proxy)
-- | Type witness for the element a pointer refers to.
-- The result must never be evaluated: forcing it raises an error.
-- Within this module it is only passed to 'Store.sizeOf'.
elementFromPtr :: LLVM.Value (Ptr a) -> a
-- The message names this function, so that an accidental evaluation
-- is not attributed to 'elementFromProxy'.
elementFromPtr _ = error "elementFromPtr"
-- | Type witness for the type selected by a proxy.
-- The result must never be evaluated: forcing it raises an error.
elementFromProxy :: LP.Proxy a -> a
elementFromProxy proxy =
   case proxy of
      LP.Proxy -> error "elementFromProxy"
-- | Proxy for the type found below two type constructors.
-- The argument is ignored.
proxyFromElement2 :: f (g a) -> LP.Proxy a
proxyFromElement2 = const LP.Proxy
-- | Proxy for the type found below three type constructors.
-- The argument is ignored.
proxyFromElement3 :: f (g (h a)) -> LP.Proxy a
proxyFromElement3 = const LP.Proxy