module Data.Packed.ST (
STVector, newVector, thawVector, freezeVector, runSTVector,
readVector, writeVector, modifyVector, liftSTVector,
STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
newUndefinedVector,
unsafeReadVector, unsafeWriteVector,
unsafeThawVector, unsafeFreezeVector,
newUndefinedMatrix,
unsafeReadMatrix, unsafeWriteMatrix,
unsafeThawMatrix, unsafeFreezeMatrix
) where
import Data.Packed.Internal
import Control.Monad.ST(ST, runST)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
#if MIN_VERSION_base(4,4,0)
import Control.Monad.ST.Unsafe(unsafeIOToST)
#else
import Control.Monad.ST(unsafeIOToST)
#endif
-- | Read element @k@ of a vector by peeking into its underlying buffer.
-- No bounds check is performed.
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV v k = unsafeWith v (`peekElemOff` k)
-- | Overwrite element @k@ of a vector's underlying buffer with @x@.
-- No bounds check is performed.
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV v k x = unsafeWith v store
  where
    store ptr = pokeElemOff ptr k x
-- | A mutable vector usable inside the 'ST' monad. It wraps an ordinary
-- 'Vector'; the phantom @s@ ties it to one 'ST' state thread.
newtype STVector s t = STVector (Vector t)
-- | Build a mutable vector from a copy of the input (made with
-- 'cloneVector'). Contrast 'unsafeThawVector', which wraps the input itself.
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector v = unsafeIOToST (fmap STVector (cloneVector v))
-- | Wrap an immutable vector as a mutable one without copying. Writes
-- through the result go to the input vector's own storage.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector v = unsafeIOToST (return (STVector v))
-- | Run an 'ST' computation that produces a mutable vector and return it as
-- an immutable vector (via 'unsafeFreezeVector', i.e. without copying).
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (unsafeFreezeVector =<< st)
-- | Read element @k@ without a bounds check ('readVector' is the checked
-- variant).
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector x) k = unsafeIOToST (ioReadV x k)
-- | Write element @k@ without a bounds check ('writeVector' is the checked
-- variant).
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector x) k val = unsafeIOToST (ioWriteV x k val)
-- | Replace element @k@ with @f@ applied to its current value.
-- The read is bounds-checked, so the unchecked write to the same index is
-- safe.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = do
    old <- readVector x k
    unsafeWriteVector x k (f old)
-- | Apply a pure function to a copy (made with 'cloneVector') of the
-- mutable vector's current contents.
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
liftSTVector f (STVector x) = unsafeIOToST (fmap f (cloneVector x))
-- | Return an immutable copy of the mutable vector's current contents.
freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
freezeVector = liftSTVector id
-- | Return the wrapped vector itself, without copying. Later writes through
-- the 'STVector' are visible in the returned value.
unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
unsafeFreezeVector (STVector x) = unsafeIOToST (return x)
-- | Bounds-checking wrapper around an unchecked vector accessor.
--
-- Calls @f@ with the same vector and index when @0 <= k < dim v@.
-- Otherwise it raises 'error' with a message reporting the dimension and
-- the offending position.
safeIndexV :: Storable t
           => (STVector s t -> Int -> a)  -- ^ unchecked accessor
           -> STVector s t                -- ^ vector being indexed
           -> Int                         -- ^ element index
           -> a
safeIndexV f (STVector v) k
    | k < 0 || k >= dim v = error $ "out of range error in vector (dim="
                                 ++ show (dim v) ++ ", pos=" ++ show k ++ ")"
    | otherwise = f (STVector v) k
-- | Read element @k@, raising an error if @k@ is outside the vector.
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector v k = safeIndexV unsafeReadVector v k
-- | Write element @k@, raising an error if @k@ is outside the vector.
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector v k x = safeIndexV unsafeWriteVector v k x
-- | Allocate a mutable vector of length @n@ with 'createVector'. This
-- function does not write any element, so contents must be set before
-- being read.
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector n = unsafeIOToST (STVector `fmap` createVector n)
-- | Allocate a mutable vector of length @n@ with every element set to @x@.
--
-- Elements are written with the unchecked 'unsafeWriteVector', from index
-- @n-1@ down to @0@. The countdown stops as soon as the index becomes
-- negative, so a non-positive @n@ performs no writes. In particular, it
-- cannot run past the buffer.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    let fill k
            | k < 0     = return v
            | otherwise = unsafeWriteVector v k x >> fill (k - 1)
    fill (n - 1)
-- | Read element (@r@, @c@) of a matrix. The flat offset into the element
-- vector depends on the storage order:
--
-- * @r*nc + c@ for 'RowMajor'
-- * @c*nr + r@ for 'ColumnMajor'
--
-- No bounds check is performed.
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM m r c =
    case m of
        Matrix _ nc cv RowMajor    -> ioReadV cv (r * nc + c)
        Matrix nr _ fv ColumnMajor -> ioReadV fv (c * nr + r)
-- | Write @val@ at element (@r@, @c@) of a matrix, using the same
-- order-dependent offset as 'ioReadM'. No bounds check is performed.
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM m r c val =
    case m of
        Matrix _ nc buf RowMajor    -> ioWriteV buf (r * nc + c) val
        Matrix nr _ buf ColumnMajor -> ioWriteV buf (c * nr + r) val
-- | A mutable matrix usable inside the 'ST' monad. It wraps an ordinary
-- 'Matrix'; the phantom @s@ ties it to one 'ST' state thread.
newtype STMatrix s t = STMatrix (Matrix t)
-- | Build a mutable matrix from a copy of the input (made with
-- 'cloneMatrix'). Contrast 'unsafeThawMatrix', which wraps the input itself.
thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
thawMatrix m = unsafeIOToST (fmap STMatrix (cloneMatrix m))
-- | Wrap an immutable matrix as a mutable one without copying. Writes
-- through the result go to the input matrix's own storage.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix m = unsafeIOToST (return (STMatrix m))
-- | Run an 'ST' computation that produces a mutable matrix and return it as
-- an immutable matrix (via 'unsafeFreezeMatrix', i.e. without copying).
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (unsafeFreezeMatrix =<< st)
-- | Read element (@r@, @c@) without a bounds check ('readMatrix' is the
-- checked variant).
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix x) r c = unsafeIOToST (ioReadM x r c)
-- | Write element (@r@, @c@) without a bounds check ('writeMatrix' is the
-- checked variant).
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix x) r c val = unsafeIOToST (ioWriteM x r c val)
-- | Replace element (@r@, @c@) with @f@ applied to its current value.
-- The read is bounds-checked, so the unchecked write to the same position
-- is safe.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = do
    old <- readMatrix x r c
    unsafeWriteMatrix x r c (f old)
-- | Apply a pure function to a copy (made with 'cloneMatrix') of the
-- mutable matrix's current contents.
liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
liftSTMatrix f (STMatrix x) = unsafeIOToST (fmap f (cloneMatrix x))
-- | Return the wrapped matrix itself, without copying. Later writes through
-- the 'STMatrix' are visible in the returned value.
unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
unsafeFreezeMatrix (STMatrix x) = unsafeIOToST (return x)
-- | Return an immutable copy of the mutable matrix's current contents.
freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
freezeMatrix = liftSTMatrix id
-- | Copy a matrix's element vector with 'cloneVector', keeping its row
-- count, column count and storage order unchanged.
cloneMatrix :: Storable t => Matrix t -> IO (Matrix t)
cloneMatrix (Matrix r c d o) = fmap (\d' -> Matrix r c d' o) (cloneVector d)
-- | Bounds-checking wrapper around an unchecked matrix accessor.
--
-- Calls @f@ with the same matrix and position when
-- @0 <= r < rows m@ and @0 <= c < cols m@. Otherwise it raises 'error'
-- with a message reporting the matrix size and the offending position.
safeIndexM :: Storable t
           => (STMatrix s t -> Int -> Int -> a)  -- ^ unchecked accessor
           -> STMatrix s t                       -- ^ matrix being indexed
           -> Int                                -- ^ row index
           -> Int                                -- ^ column index
           -> a
safeIndexM f (STMatrix m) r c
    | r < 0 || r >= rows m ||
      c < 0 || c >= cols m = error $ "out of range error in matrix (size="
                                  ++ show (rows m, cols m) ++ ", pos="
                                  ++ show (r, c) ++ ")"
    | otherwise = f (STMatrix m) r c
-- | Read element (@r@, @c@), raising an error if the position is outside
-- the matrix.
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix m r c = safeIndexM unsafeReadMatrix m r c
-- | Write element (@r@, @c@), raising an error if the position is outside
-- the matrix.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix m r c x = safeIndexM unsafeWriteMatrix m r c x
-- | Allocate an @r@ x @c@ mutable matrix with the given storage order using
-- 'createMatrix'. This function does not write any element.
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix ord r c = unsafeIOToST (fmap STMatrix (createMatrix ord r c))
-- | Allocate an @r@ x @c@ mutable matrix with every element set to @v@.
--
-- A vector of @r*c@ elements is filled by 'newVector' and then laid out
-- with @reshape c@.
--
-- The allocation is sequenced inside this 'ST' action rather than captured
-- as a pure value. Each run of the action therefore yields an independent
-- buffer, even when the same action value is shared or floated out of a
-- loop.
--
-- NOTE(review): 'reshape''s behaviour for @c == 0@ is not visible here;
-- confirm it does not divide by zero.
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c =
    newVector v (r * c) >>= \(STVector buf) -> unsafeThawMatrix (reshape c buf)