{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_HADDOCK hide #-}

module Data.Array.Destination.Internal where

import Data.Unrestricted.Linear
import Data.Vector (Vector, (!))
import qualified Data.Vector as Vector
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as MVector
import GHC.Exts (RealWorld)
import GHC.Stack
import Prelude.Linear hiding (replicate)
import System.IO.Unsafe (unsafeDupablePerformIO)
import qualified Unsafe.Linear as Unsafe
import qualified Prelude as Prelude

-- | A destination array, or @DArray@, is a write-only array that is filled
-- by some computation which ultimately returns an array.
--
-- The constructor wraps a mutable vector over 'RealWorld'. Its field is
-- declared with a plain @->@ in GADT syntax, which under @LinearTypes@ makes
-- the field unrestricted: after matching on 'DArray', the inner vector may
-- be used any number of times (for instance, 'size' both reads its length
-- and re-wraps it), even though the 'DArray' itself is consumed linearly.
data DArray a where
  DArray :: MVector RealWorld a -> DArray a

-- XXX: use of Vector in types is temporary. I will probably move away from
-- vectors and implement most stuff in terms of Array# and MutableArray#
-- eventually, anyway. This would allow moving the MutableArray logic to
-- linear IO, possibly, and segregating the unsafe casts in the Linear IO
-- module.

-- | @'alloc' n k@ allocates a vector of length @n@, lends it to the writer
-- @k@ as a 'DArray', and returns the vector once @k@ has consumed the
-- destination array.
--
-- @n@ must be non-negative. A negative @n@ is reported with 'error' instead
-- of being handed to 'MVector.unsafeNew', which does not validate its
-- argument in default builds of @vector@.
alloc :: HasCallStack => Int -> (DArray a %1 -> ()) %1 -> Vector a
alloc n writer
  | n < 0 = error "Destination.alloc: negative size" $ writer
  | otherwise =
      -- `lseq` forces the writer's () before the vector is handed back, so
      -- every write has happened by the time the caller sees the result.
      (\(Ur dest, vec) -> writer (DArray dest) `lseq` vec) $
        unsafeDupablePerformIO $ do
          destArray <- MVector.unsafeNew n
          -- Frozen without copying before any write: the returned vector
          -- shares storage with `destArray`, so the writes the writer makes
          -- through the DArray show up in it.
          vec <- Vector.unsafeFreeze destArray
          Prelude.return (Ur destArray, vec)

-- | Get the size of a destination array. The destination array is returned
-- untouched alongside its length, so it can still be filled afterwards.
size :: DArray a %1 -> (Ur Int, DArray a)
size (DArray buf) = (Ur len, DArray buf)
  where
    len = MVector.length buf

-- | Fill every slot of a destination array with the same value.
replicate :: a -> DArray a %1 -> ()
replicate x = fromFunction (\_ -> x)

-- | @fill a dest@ writes @a@ into a singleton destination array.
--
-- Caution: @'fill' a dest@ fails with 'error' if @dest@ isn't of length
-- exactly one.
fill :: HasCallStack => a %1 -> DArray a %1 -> ()
fill a (DArray mvec)
  | MVector.length mvec /= 1 =
      error "Destination.fill: requires a destination of size 1" $ a
  | otherwise = Unsafe.toLinear writeSingleton a
  where
    -- Stores the value in the only slot (index 0) of the destination.
    writeSingleton x = unsafeDupablePerformIO (MVector.write mvec 0 x)

-- | @dropEmpty dest@ consumes an empty destination array and fails with
-- 'error' if @dest@ still has slots to fill.
dropEmpty :: HasCallStack => DArray a %1 -> ()
dropEmpty (DArray mvec) =
  if MVector.length mvec == 0
    then mvec `seq` ()
    else error "Destination.dropEmpty on non-empty array."

-- | @'split' n dest = (destl, destr)@ such that @destl@ has length @n@ and
-- @destr@ holds the remaining slots.
--
-- 'split' is total: if @n@ is larger than the length of @dest@, then
-- @destr@ is empty.
split :: Int -> DArray a %1 -> (DArray a, DArray a)
split n (DArray mvec) =
  case MVector.splitAt n mvec of
    (prefix, suffix) -> (DArray prefix, DArray suffix)

-- | Fills the destination array with the contents of the given vector, with
-- each element passed through @f@. Slot @i@ of the destination receives
-- @f (v ! i)@, so only the first @size dest@ elements of @v@ are used.
--
-- Errors if the given vector is smaller than the destination array.
mirror :: HasCallStack => Vector a -> (a %1 -> b) -> DArray b %1 -> ()
mirror v f arr =
  size arr & \(Ur destLen, dest) ->
    if destLen <= Vector.length v
      then fromFunction (\i -> f (v ! i)) dest
      else error "Destination.mirror: argument smaller than DArray" $ dest

-- | Fill a destination array using the given index-to-value function: slot
-- @i@ receives @f i@ for every index from @0@ up to the array's length
-- minus one.
fromFunction :: (Int -> b) -> DArray b %1 -> ()
fromFunction f (DArray mvec) =
  unsafeDupablePerformIO (Prelude.mapM_ writeAt indices)
  where
    -- Every valid index of the destination, in increasing order.
    indices = [0 .. MVector.length mvec - 1]
    -- Indices come from `indices`, so the unchecked write is in bounds.
    writeAt i = MVector.unsafeWrite mvec i (f i)

-- The use of the mutable array is linear: reading its length does not touch
-- any element, and the writes cover every index from 0 to n-1 exactly once,
-- so each slot of the destination array is filled exactly once.