module Data.Vector.Mutable (
  
  MVector(..), IOVector, STVector,
  
  
  length, null,
  
  slice, init, tail, take, drop, splitAt,
  unsafeSlice, unsafeInit, unsafeTail, unsafeTake, unsafeDrop,
  
  overlaps,
  
  
  new, unsafeNew, replicate, replicateM, clone,
  
  grow, unsafeGrow,
  
  clear,
  
  read, write, swap,
  unsafeRead, unsafeWrite, unsafeSwap,
  
  
  set, copy, move, unsafeCopy, unsafeMove
) where
import           Control.Monad (when)
import qualified Data.Vector.Generic.Mutable as G
import           Data.Primitive.Array
import           Control.Monad.Primitive
import Control.DeepSeq ( NFData, rnf )
import Prelude hiding ( length, null, replicate, reverse, map, read,
                        take, drop, splitAt, init, tail )
import Data.Typeable ( Typeable )
#include "vector.h"
-- | Boxed mutable vector: a view of consecutive slots starting at an offset
-- into a @MutableArray@. Slicing produces new views over the same array.
data MVector s a
  = MVector
      !Int                 -- ^ offset of element 0 within the backing array
      !Int                 -- ^ number of elements in the view
      !(MutableArray s a)  -- ^ backing array, shared by slices of this view
  deriving (Typeable)
-- | Boxed mutable vector whose state type is @RealWorld@. The element type
-- is left unapplied, so the synonym can be used partially applied.
type IOVector = MVector RealWorld
-- | Boxed mutable vector whose state type is the parameter @s@. The element
-- type is left unapplied, so @STVector s@ can be used partially applied.
type STVector s = MVector s
instance G.MVector MVector a where
  -- The view length is stored directly in the constructor.
  basicLength (MVector _ n _) = n

  -- O(1): shift the offset, set the new length, keep the same backing array.
  basicUnsafeSlice j m (MVector i _ arr) = MVector (i+j) m arr

  -- Two views overlap when they share a backing array and the start of
  -- either one lies inside the half-open index range of the other.
  basicOverlaps (MVector i m arr1) (MVector j n arr2)
    = sameMutableArray arr1 arr2
      && (between i j (j+n) || between j i (i+m))
    where
      between x y z = x >= y && x < z

  -- Fresh array of n slots, each holding the error thunk uninitialised.
  basicUnsafeNew n
    = do
        arr <- newArray n uninitialised
        return (MVector 0 n arr)

  -- Fresh array of n slots, each holding x.
  basicUnsafeReplicate n x
    = do
        arr <- newArray n x
        return (MVector 0 n arr)

  -- Indices are relative to the view, so the view offset is added.
  basicUnsafeRead (MVector i _ arr) j = readArray arr (i+j)

  basicUnsafeWrite (MVector i _ arr) j x = writeArray arr (i+j) x

  -- Arguments are (destination, source); the destination length decides
  -- how many elements are copied.
  basicUnsafeCopy (MVector i n dst) (MVector j _ src)
    = copyMutableArray dst i src j n

  -- Move that stays correct when dst and src share a backing array.
  -- Lengths 0 to 2 are handled inline: every element is read before any
  -- write. For longer overlapping moves, a backwards move (iDst < iSrc) is
  -- done in place in increasing index order. A forward move buffers
  -- min(iDst - iSrc, n - (iDst - iSrc)) elements: the distance when it is
  -- under half of n (large overlap), otherwise the overlap itself.
  basicUnsafeMove dst@(MVector iDst n arrDst) src@(MVector iSrc _ arrSrc)
    = case n of
        0 -> return ()
        1 -> readArray arrSrc iSrc >>= writeArray arrDst iDst
        2 -> do
               x <- readArray arrSrc iSrc
               y <- readArray arrSrc (iSrc + 1)
               writeArray arrDst iDst x
               writeArray arrDst (iDst + 1) y
        _
          | overlaps dst src
             -> case compare iDst iSrc of
                  LT -> moveBackwards arrDst iDst iSrc n
                  -- same array and same offset: nothing moves
                  EQ -> return ()
                  GT | (iDst - iSrc) * 2 < n
                        -> moveForwardsLargeOverlap arrDst iDst iSrc n
                     | otherwise
                        -> moveForwardsSmallOverlap arrDst iDst iSrc n
          | otherwise -> G.basicUnsafeCopy dst src

  -- Overwrite every slot with the uninitialised placeholder via G.set,
  -- presumably so the previous elements are no longer kept alive.
  basicClear v = G.set v uninitialised
-- | Move @len@ elements within one array, from offset @srcOff@ down to the
-- lower offset @dstOff@. Elements are copied in increasing index order.
-- Because the destination lies below the source, each source slot is read
-- before any write can reach it, so no temporary buffer is needed.
moveBackwards :: PrimMonad m => MutableArray (PrimState m) a -> Int -> Int -> Int -> m ()
moveBackwards !arr !dstOff !srcOff !len =
  INTERNAL_CHECK(check) "moveBackwards" "not a backwards move" (dstOff < srcOff)
  $ loopM len shiftOne
  where
    -- copy the k-th source element into the k-th destination slot
    shiftOne k = do
      x <- readArray arr (srcOff + k)
      writeArray arr (dstOff + k) x
-- | Move @len@ elements within one array, from offset @srcOff@ up to the
-- higher offset @dstOff@, when source and destination overlap.
--
-- The source elements in the overlapping region sit at the start of the
-- destination range and would be clobbered. They are saved to a temporary
-- array of size @overlap@ first. The non-overlapping prefix is then copied
-- in place, and the saved elements are written after it. The caller picks
-- this variant when the overlap is at most half of @len@, which keeps the
-- temporary array small.
moveForwardsSmallOverlap :: PrimMonad m => MutableArray (PrimState m) a -> Int -> Int -> Int -> m ()
moveForwardsSmallOverlap !arr !dstOff !srcOff !len =
  INTERNAL_CHECK(check) "moveForwardsSmallOverlap" "not a forward move" (dstOff > srcOff)
  $ do
      -- Save the tail of the source, which currently occupies arr[dstOff ..].
      tmp <- newArray overlap uninitialised
      loopM overlap $ \ i -> readArray arr (dstOff + i) >>= writeArray tmp i
      -- All reads here come from below dstOff, which this function never writes.
      loopM nonOverlap $ \ i -> readArray arr (srcOff + i) >>= writeArray arr (dstOff + i)
      loopM overlap $ \ i -> readArray tmp i >>= writeArray arr (dstOff + nonOverlap + i)
  where
    -- distance of the move, equal to the length of the source prefix that
    -- lies before the destination range
    nonOverlap = dstOff - srcOff
    -- number of positions shared by the source and destination ranges
    overlap = len - nonOverlap
-- | Move @len@ elements within one array, from offset @srcOff@ up to the
-- higher offset @dstOff@, buffering only @dstOff - srcOff@ elements. The
-- caller picks this variant when that distance is under half of @len@.
--
-- The buffer is a circular queue that starts out holding the source prefix
-- lying before @dstOff@. The destination range is walked left to right, and
-- each slot is swapped with the queue head. The queued element is written
-- into place, and the element it displaces joins the queue. A displaced
-- element is either a later source element or a value past the source
-- range that is never written back.
moveForwardsLargeOverlap :: PrimMonad m => MutableArray (PrimState m) a -> Int -> Int -> Int -> m ()
moveForwardsLargeOverlap !arr !dstOff !srcOff !len =
  INTERNAL_CHECK(check) "moveForwardsLargeOverlap" "not a forward move" (dstOff > srcOff)
  $ do
      queue <- newArray nonOverlap uninitialised
      loopM nonOverlap $ \ i -> readArray arr (srcOff + i) >>= writeArray queue i
      -- i walks the destination range; qTop is the queue head, wrapping to 0
      let mov !i !qTop = when (i < dstOff + len) $ do
            x <- readArray arr i
            y <- readArray queue qTop
            writeArray arr i y
            writeArray queue qTop x
            mov (i+1) (if qTop + 1 >= nonOverlap then 0 else qTop + 1)
      mov dstOff 0
  where
    -- distance of the move, which is also the queue capacity
    nonOverlap = dstOff - srcOff
-- | Run @k 0@, @k 1@, ..., @k (n - 1)@ in order, discarding the results.
-- Does nothing when @n <= 0@.
loopM :: Monad m => Int -> (Int -> m a) -> m ()
loopM !n k = step 0
  where
    step i
      | i >= n    = return ()
      | otherwise = k i >> step (i + 1)
-- | Placeholder stored in slots that have been allocated but not yet
-- written. Forcing it raises an error that names this module.
uninitialised :: a
uninitialised = error message
  where
    message = "Data.Vector.Mutable: uninitialised element"
-- | Specialisation of @G.length@ to boxed mutable vectors.
length :: MVector s a -> Int
length v = G.length v

-- | Specialisation of @G.null@ to boxed mutable vectors.
null :: MVector s a -> Bool
null v = G.null v
-- | Specialisation of @G.slice@ to boxed mutable vectors.
slice :: Int -> Int -> MVector s a -> MVector s a
slice i m v = G.slice i m v

-- | Specialisation of @G.take@ to boxed mutable vectors.
take :: Int -> MVector s a -> MVector s a
take n v = G.take n v

-- | Specialisation of @G.drop@ to boxed mutable vectors.
drop :: Int -> MVector s a -> MVector s a
drop n v = G.drop n v

-- | Specialisation of @G.splitAt@ to boxed mutable vectors.
splitAt :: Int -> MVector s a -> (MVector s a, MVector s a)
splitAt n v = G.splitAt n v

-- | Specialisation of @G.init@ to boxed mutable vectors.
init :: MVector s a -> MVector s a
init v = G.init v

-- | Specialisation of @G.tail@ to boxed mutable vectors.
tail :: MVector s a -> MVector s a
tail v = G.tail v
-- | Specialisation of @G.unsafeSlice@ to boxed mutable vectors.
unsafeSlice :: Int          -- ^ starting index
            -> Int          -- ^ length of the slice
            -> MVector s a
            -> MVector s a
unsafeSlice i m v = G.unsafeSlice i m v

-- | Specialisation of @G.unsafeTake@ to boxed mutable vectors.
unsafeTake :: Int -> MVector s a -> MVector s a
unsafeTake n v = G.unsafeTake n v

-- | Specialisation of @G.unsafeDrop@ to boxed mutable vectors.
unsafeDrop :: Int -> MVector s a -> MVector s a
unsafeDrop n v = G.unsafeDrop n v

-- | Specialisation of @G.unsafeInit@ to boxed mutable vectors.
unsafeInit :: MVector s a -> MVector s a
unsafeInit v = G.unsafeInit v

-- | Specialisation of @G.unsafeTail@ to boxed mutable vectors.
unsafeTail :: MVector s a -> MVector s a
unsafeTail v = G.unsafeTail v
-- | Specialisation of @G.overlaps@ to boxed mutable vectors.
overlaps :: MVector s a -> MVector s a -> Bool
overlaps u v = G.overlaps u v
-- | Specialisation of @G.new@ to boxed mutable vectors.
new :: PrimMonad m => Int -> m (MVector (PrimState m) a)
new n = G.new n

-- | Specialisation of @G.unsafeNew@ to boxed mutable vectors.
unsafeNew :: PrimMonad m => Int -> m (MVector (PrimState m) a)
unsafeNew n = G.unsafeNew n

-- | Specialisation of @G.replicate@ to boxed mutable vectors.
replicate :: PrimMonad m => Int -> a -> m (MVector (PrimState m) a)
replicate n x = G.replicate n x

-- | Specialisation of @G.replicateM@ to boxed mutable vectors.
replicateM :: PrimMonad m => Int -> m a -> m (MVector (PrimState m) a)
replicateM n act = G.replicateM n act

-- | Specialisation of @G.clone@ to boxed mutable vectors.
clone :: PrimMonad m => MVector (PrimState m) a -> m (MVector (PrimState m) a)
clone v = G.clone v
-- | Specialisation of @G.grow@ to boxed mutable vectors.
grow :: PrimMonad m
     => MVector (PrimState m) a -> Int -> m (MVector (PrimState m) a)
grow v by = G.grow v by

-- | Specialisation of @G.unsafeGrow@ to boxed mutable vectors.
unsafeGrow :: PrimMonad m
           => MVector (PrimState m) a -> Int -> m (MVector (PrimState m) a)
unsafeGrow v by = G.unsafeGrow v by
-- | Specialisation of @G.clear@ to boxed mutable vectors.
clear :: PrimMonad m => MVector (PrimState m) a -> m ()
clear v = G.clear v
-- | Specialisation of @G.read@ to boxed mutable vectors.
read :: PrimMonad m => MVector (PrimState m) a -> Int -> m a
read v i = G.read v i

-- | Specialisation of @G.write@ to boxed mutable vectors.
write :: PrimMonad m => MVector (PrimState m) a -> Int -> a -> m ()
write v i x = G.write v i x

-- | Specialisation of @G.swap@ to boxed mutable vectors.
swap :: PrimMonad m => MVector (PrimState m) a -> Int -> Int -> m ()
swap v i j = G.swap v i j

-- | Specialisation of @G.unsafeRead@ to boxed mutable vectors.
unsafeRead :: PrimMonad m => MVector (PrimState m) a -> Int -> m a
unsafeRead v i = G.unsafeRead v i

-- | Specialisation of @G.unsafeWrite@ to boxed mutable vectors.
unsafeWrite :: PrimMonad m => MVector (PrimState m) a -> Int -> a -> m ()
unsafeWrite v i x = G.unsafeWrite v i x

-- | Specialisation of @G.unsafeSwap@ to boxed mutable vectors.
unsafeSwap :: PrimMonad m => MVector (PrimState m) a -> Int -> Int -> m ()
unsafeSwap v i j = G.unsafeSwap v i j
-- | Specialisation of @G.set@ to boxed mutable vectors.
set :: PrimMonad m => MVector (PrimState m) a -> a -> m ()
set v x = G.set v x

-- | Specialisation of @G.copy@ to boxed mutable vectors.
copy :: PrimMonad m
     => MVector (PrimState m) a -> MVector (PrimState m) a -> m ()
copy dst src = G.copy dst src

-- | Specialisation of @G.unsafeCopy@ to boxed mutable vectors.
unsafeCopy :: PrimMonad m => MVector (PrimState m) a   -- ^ target
                          -> MVector (PrimState m) a   -- ^ source
                          -> m ()
unsafeCopy dst src = G.unsafeCopy dst src

-- | Specialisation of @G.move@ to boxed mutable vectors.
move :: PrimMonad m
     => MVector (PrimState m) a -> MVector (PrimState m) a -> m ()
move dst src = G.move dst src

-- | Specialisation of @G.unsafeMove@ to boxed mutable vectors.
unsafeMove :: PrimMonad m => MVector (PrimState m) a   -- ^ target
                          -> MVector (PrimState m) a   -- ^ source
                          -> m ()
unsafeMove dst src = G.unsafeMove dst src