{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Destination
(
DArray
, alloc
, size
, replicate
, split
, mirror
, fromFunction
, fill
, dropEmpty
)
where
import Data.Vector (Vector, (!))
import qualified Data.Vector as Vector
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as MVector
import GHC.Exts (RealWorld)
import qualified Prelude as Prelude
import System.IO.Unsafe (unsafeDupablePerformIO)
import GHC.Stack
import Data.Unrestricted.Linear
import Prelude.Linear hiding (replicate)
import qualified Unsafe.Linear as Unsafe
-- | A destination array: a write-only handle onto a mutable vector that is
-- filled by a linear writer function handed to 'alloc'.
--
-- The constructor is not exported (only the type is), so callers can only
-- interact with a 'DArray' through the functions of this module.
--
-- The field is declared with an unrestricted arrow (@->@, not @%1->@), so
-- pattern matching on a 'DArray' binds the underlying 'MVector' with
-- multiplicity Many; 'size', 'fill', 'fromFunction' etc. rely on this to
-- use the vector more than once.
data DArray a where
  DArray :: MVector RealWorld a -> DArray a
-- | @'alloc' n writer@ allocates an array of @n@ slots, hands a destination
-- for it to @writer@, and returns the array.
--
-- The @()@ produced by @writer@ is forced (via 'lseq') before the array is
-- returned, so the writes performed through the destination are done by the
-- time the caller sees the vector.
--
-- Calls 'error' if @n@ is negative: 'MVector.unsafeNew' does not validate its
-- length argument, so the check has to happen here.
alloc :: Int -> (DArray a %1-> ()) %1-> Vector a
alloc n writer
  | n < 0 = error "Destination.alloc: size must be non-negative" $ writer
  | otherwise =
      (\(Ur dest, vec) -> writer (DArray dest) `lseq` vec) $
        unsafeDupablePerformIO Prelude.$ do
          destArray <- MVector.unsafeNew n
          -- The frozen vector shares storage with destArray, which is why
          -- writes made through the DArray show up in the returned vector.
          vec <- Vector.unsafeFreeze destArray
          -- Ur lets the mutable vector be used unrestrictedly inside the
          -- linear continuation above.
          Prelude.return (Ur destArray, vec)
-- | Report how many slots a destination array has.
--
-- The length is returned unrestricted (wrapped in 'Ur') together with the
-- very same destination, so it can still be filled afterwards.
size :: DArray a %1-> (Ur Int, DArray a)
size (DArray mvec) = (Ur len, DArray mvec)
  where
    len = MVector.length mvec
-- | Fill every slot of a destination array with the same value.
replicate :: a -> DArray a %1-> ()
replicate x = fromFunction (\_ -> x)
-- | @'fill' x dest@ writes @x@ into a destination array that has exactly
-- one slot.
--
-- Calls 'error' if @dest@ does not have length exactly one.
fill :: HasCallStack => a %1-> DArray a %1-> ()
fill a (DArray mvec)
  | MVector.length mvec /= 1 =
      error "Destination.fill: requires a destination of size 1" $ a
  | otherwise = Unsafe.toLinear writeFirst a
  where
    -- writeFirst is an ordinary (unrestricted) function; Unsafe.toLinear
    -- lets it consume the linear argument.
    writeFirst x = unsafeDupablePerformIO (MVector.write mvec 0 x)
-- | Consume a destination array that has no slots.
--
-- Calls 'error' if the destination still has one or more slots.
dropEmpty :: HasCallStack => DArray a %1-> ()
dropEmpty (DArray mvec)
  | MVector.null mvec = mvec `seq` ()
  | otherwise = error "Destination.dropEmpty on non-empty array."
-- | @'split' n dest@ cuts @dest@ into @(front, back)@, where @front@ covers
-- the first @n@ slots and @back@ the remaining ones.
--
-- No bounds check is done here; an out-of-range @n@ is handled by
-- 'MVector.splitAt' (documented by vector as behaving like @take@/@drop@).
split :: Int -> DArray a %1-> (DArray a, DArray a)
split n (DArray mvec) =
  case MVector.splitAt n mvec of
    (front, back) -> (DArray front, DArray back)
-- | @'mirror' v f dest@ fills @dest@ by mapping @f@ over the leading
-- elements of @v@: slot @i@ receives @f (v ! i)@.
--
-- Calls 'error' if @v@ has fewer elements than @dest@ has slots; elements of
-- @v@ beyond the destination's length are ignored.
mirror :: HasCallStack => Vector a -> (a %1-> b) -> DArray b %1-> ()
mirror v f arr =
  size arr & \(Ur len, dest) ->
    if len > Vector.length v
      then error "Destination.mirror: argument smaller than DArray" $ dest
      else fromFunction (\i -> f (v ! i)) dest
-- | @'fromFunction' f dest@ writes @f i@ into slot @i@ of @dest@ for every
-- index from 0 up to (but excluding) the destination's length.
fromFunction :: (Int -> b) -> DArray b %1-> ()
fromFunction f (DArray mvec) =
  unsafeDupablePerformIO (Prelude.mapM_ writeAt [0 .. lastIx])
  where
    -- For an empty destination this is -1, giving an empty index range.
    lastIx = MVector.length mvec - 1
    writeAt i = MVector.unsafeWrite mvec i (f i)