{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}
module Internal.Vector where

import Zhp hiding (empty)

import Data.Vector.Endian.Class (Endian)
import Internal.Endianness      (from, to)
import Text.Read                (Read(..))

import Control.DeepSeq      (NFData)
import Control.Monad.ST     (ST)
import Data.Traversable     (Traversable)
import Data.Vector.Storable (Storable)
import Foreign.ForeignPtr   (ForeignPtr)
import Foreign.Ptr          (Ptr)
import GHC.Exts             (IsList(..))

import qualified Data.Vector.Generic          as GV
import qualified Data.Vector.Generic.Mutable  as GMV
import qualified Data.Vector.Storable         as SV
import qualified Data.Vector.Storable.Mutable as SMV

-- | Mutable counterpart of 'Vector': a newtype over a storable mutable vector
-- ('SMV.MVector'). The 'GMV.MVector' instance in this module converts
-- elements with 'to' when they are written and with 'from' when they are read.
newtype MVector s a = MVector (SMV.MVector s a)
    deriving(NFData)

-- | Immutable vector backed by a storable vector ('SV.Vector') whose elements
-- are held in the representation produced by 'to'; the accessors in this
-- module apply 'from' before handing an element back.
--
-- All derived instances delegate to the wrapped storable vector.
newtype Vector a = Vector (SV.Vector a)
    deriving(Eq, NFData, Semigroup, Monoid)

-- | List conversions. Elements pass through 'to' when a list is turned into a
-- 'Vector', and through 'from' when a 'Vector' is turned back into a list.
instance (Storable a, Endian a) => IsList (Vector a) where
    type Item (Vector a) = a

    -- Convert each element to its stored form, then build the wrapped vector.
    fromList items = Vector (fromList (map to items))
    -- Same as 'fromList', forwarding the length hint to the wrapped vector.
    fromListN size items = Vector (fromListN size (map to items))
    -- List the stored elements, converting each one back with 'from'.
    toList vec = case vec of
        Vector raw -> map from (SV.toList raw)

-- | Shows the vector exactly as the wrapped storable vector would be shown
-- once every element has been converted with 'from'.
instance (Show a, SV.Storable a, Endian a) => Show (Vector a) where
    show vec = case vec of
        Vector raw -> show (SV.map from raw)

-- | Parses the same syntax as the wrapped storable vector.
instance (Read a, SV.Storable a, Endian a) => Read (Vector a) where
    -- Parse a plain storable vector first, then convert every element to the
    -- storage format we need before wrapping it.
    readPrec = fmap toStorage readPrec
      where
        toStorage parsed = Vector (SV.map to parsed)

-- | Ties the immutable 'Vector' to its mutable form 'MVector' for the generic
-- vector API.
type instance GV.Mutable Vector = MVector

-- | Mutable-vector operations for the endian-aware 'MVector'.
--
-- Structural operations are forwarded unchanged to the wrapped storable
-- vector. Every operation that accepts an element converts it with 'to'
-- before storing it, and every operation that yields an element converts it
-- with 'from' before returning it.
instance (SV.Storable a, Endian a) => GMV.MVector MVector a where
    -- These are all trivial wrappers around the underlying vector.
    basicLength (MVector v) = GMV.basicLength v
    basicUnsafeSlice i len (MVector v) = MVector (GMV.basicUnsafeSlice i len v)
    basicOverlaps (MVector l) (MVector r) = GMV.basicOverlaps l r
    basicUnsafeNew = fmap MVector . GMV.basicUnsafeNew
    basicInitialize (MVector v) = GMV.basicInitialize v
    basicClear (MVector v) = GMV.basicClear v
    basicUnsafeCopy (MVector l) (MVector r) = GMV.basicUnsafeCopy l r
    basicUnsafeMove (MVector l) (MVector r) = GMV.basicUnsafeMove l r
    basicUnsafeGrow (MVector v) i = MVector <$> GMV.basicUnsafeGrow v i

    -- These have to do endianness conversion:

    -- The result must be wrapped in 'MVector' so that the inner call is
    -- resolved at the underlying storable vector's instance. Without the
    -- wrapper the inner call has this instance's own result type, resolves
    -- to this very method, and recurses forever.
    basicUnsafeReplicate !i !e =
        MVector <$> (GMV.basicUnsafeReplicate i $! to e)
    basicUnsafeRead (MVector v) i = from <$> GMV.basicUnsafeRead v i
    basicUnsafeWrite (MVector v) i e = GMV.basicUnsafeWrite v i $! to e
    basicSet (MVector v) e = GMV.basicSet v $! to e

-- | Immutable-vector operations for the endian-aware 'Vector'. Everything is
-- forwarded to the wrapped storable vector; only element lookup has to
-- convert, using 'from'.
instance (SV.Storable a, Endian a) => GV.Vector Vector a where
    -- Plain forwarding, with the newtypes unwrapped and re-wrapped as needed.
    basicUnsafeFreeze (MVector mraw) = fmap Vector (GV.basicUnsafeFreeze mraw)
    basicUnsafeThaw (Vector raw) = fmap MVector (GV.basicUnsafeThaw raw)
    basicLength (Vector raw) = GV.basicLength raw
    basicUnsafeSlice start size (Vector raw) =
        Vector $ GV.basicUnsafeSlice start size raw
    basicUnsafeCopy (MVector dst) (Vector src) = GV.basicUnsafeCopy dst src
    elemseq (Vector raw) el rest = GV.elemseq raw el rest

    -- Element lookup: convert the stored value back with 'from', forcing the
    -- converted value before it is returned.
    basicUnsafeIndexM (Vector raw) ix =
        (<$!>) from (GV.basicUnsafeIndexM raw ix)

---------------
-- Accessors --
---------------


-- Length information
---------------------

-- | Number of elements, as reported by 'SV.length' on the wrapped vector.
length :: (Storable a, Endian a) => Vector a -> Int
length vec = case vec of
    Vector raw -> SV.length raw

-- | Whether the vector is empty, as reported by 'SV.null' on the wrapped
-- vector.
null :: (Storable a, Endian a) => Vector a -> Bool
null wrapped = SV.null raw
  where
    Vector raw = wrapped

-- Indexing
-----------

-- | Index into the vector with 'SV.!' and convert the stored element with
-- 'from'. The stored element is forced before conversion.
(!) :: (Storable a, Endian a) => Vector a -> Int -> a
(!) (Vector raw) ix = from $! (SV.!) raw ix

-- | Safe indexing via 'SV.!?'. A found element is converted with 'from' and
-- the converted value is forced before it is wrapped in 'Just'.
(!?) :: (Storable a, Endian a) => Vector a -> Int -> Maybe a
(!?) (Vector raw) ix = case raw SV.!? ix of
    Nothing     -> Nothing
    Just stored -> Just $! from stored

-- | First element, obtained with 'SV.head' and converted with 'from'.
-- Behaviour on an empty vector is whatever 'SV.head' does.
head :: (Storable a, Endian a) => Vector a -> a
head vec = case vec of
    Vector raw -> from $! SV.head raw

-- | Final element, obtained with 'SV.last' and converted with 'from'.
-- Behaviour on an empty vector is whatever 'SV.last' does.
last :: (Storable a, Endian a) => Vector a -> a
last (Vector raw) =
    let !stored = SV.last raw
    in from stored

-- | Like '(!)', but built on 'SV.unsafeIndex'; see that function for its
-- preconditions. The stored element is converted with 'from'.
unsafeIndex :: (Storable a, Endian a) => Vector a -> Int -> a
unsafeIndex (Vector raw) ix = from $! stored
  where
    stored = SV.unsafeIndex raw ix

-- | Like 'head', but built on 'SV.unsafeHead'; see that function for its
-- preconditions. The stored element is converted with 'from'.
unsafeHead :: (Storable a, Endian a) => Vector a -> a
unsafeHead vec = case vec of
    Vector raw -> from $! SV.unsafeHead raw

-- | Like 'last', but built on 'SV.unsafeLast'; see that function for its
-- preconditions. The stored element is converted with 'from'.
unsafeLast :: (Storable a, Endian a) => Vector a -> a
unsafeLast (Vector raw) =
    let !stored = SV.unsafeLast raw
    in from stored

-- Monadic indexing
-------------------

-- | Monadic indexing via 'SV.indexM'. The element is converted with 'from'
-- and the converted value is forced inside the monad.
indexM :: (Storable a, Endian a, Monad m) => Vector a -> Int -> m a
indexM vec ix = case vec of
    Vector raw -> from <$!> SV.indexM raw ix

-- | Monadic access to the first element via 'SV.headM'; the result is
-- converted with 'from' and forced inside the monad.
headM :: (Storable a, Endian a, Monad m) => Vector a -> m a
headM (Vector raw) = (<$!>) from (SV.headM raw)

-- | Monadic access to the final element via 'SV.lastM'; the result is
-- converted with 'from' and forced inside the monad.
lastM :: (Storable a, Endian a, Monad m) => Vector a -> m a
lastM (Vector raw) = from <$!> fetched
  where
    fetched = SV.lastM raw

-- | Like 'indexM', but built on 'SV.unsafeIndexM'; see that function for its
-- preconditions. The result is converted with 'from' and forced.
unsafeIndexM :: (Storable a, Endian a, Monad m) => Vector a -> Int -> m a
unsafeIndexM vec ix = case vec of
    Vector raw -> (<$!>) from (SV.unsafeIndexM raw ix)

-- | Like 'headM', but built on 'SV.unsafeHeadM'; see that function for its
-- preconditions. The result is converted with 'from' and forced.
unsafeHeadM :: (Storable a, Endian a, Monad m) => Vector a -> m a
unsafeHeadM (Vector raw) = from <$!> fetched
  where
    fetched = SV.unsafeHeadM raw

-- | Like 'lastM', but built on 'SV.unsafeLastM'; see that function for its
-- preconditions. The result is converted with 'from' and forced.
unsafeLastM :: (Storable a, Endian a, Monad m) => Vector a -> m a
unsafeLastM (Vector raw) = (<$!>) from (SV.unsafeLastM raw)

-- Extracting subvectors (slicing)
----------------------------------

-- | Sub-vector selected by 'SV.slice' with the given start index and length.
-- Elements are passed through untouched, so no conversion is involved.
slice :: (Storable a, Endian a) => Int -> Int -> Vector a -> Vector a
slice start size vec = case vec of
    Vector raw -> Vector $ SV.slice start size raw

-- | Applies 'SV.init' to the wrapped vector; stored elements are not
-- converted.
init :: (Storable a, Endian a) => Vector a -> Vector a
init (Vector raw) = Vector shortened
  where
    shortened = SV.init raw

-- | Applies 'SV.tail' to the wrapped vector; stored elements are not
-- converted.
tail :: (Storable a, Endian a) => Vector a -> Vector a
tail vec = case vec of
    Vector raw -> Vector $ SV.tail raw

-- | Applies 'SV.take' with the given count to the wrapped vector; stored
-- elements are not converted.
take :: (Storable a, Endian a) => Int -> Vector a -> Vector a
take n (Vector raw) = Vector $ SV.take n raw

-- | Applies 'SV.drop' with the given count to the wrapped vector; stored
-- elements are not converted.
drop :: (Storable a, Endian a) => Int -> Vector a -> Vector a
drop n vec = case vec of
    Vector raw -> Vector (SV.drop n raw)

-- | Splits the wrapped vector with 'SV.splitAt' and re-wraps both halves.
-- Stored elements are not converted.
splitAt :: (Storable a, Endian a) => Int -> Vector a -> (Vector a, Vector a)
splitAt pos (Vector raw) = (Vector front, Vector back)
  where
    -- Lazy pattern: the pair is only taken apart when a half is demanded.
    (front, back) = SV.splitAt pos raw

-- | Like 'slice', but built on 'SV.unsafeSlice'; see that function for its
-- preconditions. Stored elements are not converted.
unsafeSlice :: (Storable a, Endian a) => Int -> Int -> Vector a -> Vector a
unsafeSlice start size vec = case vec of
    Vector raw -> Vector $ SV.unsafeSlice start size raw

-- | Like 'init', but built on 'SV.unsafeInit'; see that function for its
-- preconditions. Stored elements are not converted.
unsafeInit :: (Storable a, Endian a) => Vector a -> Vector a
unsafeInit (Vector raw) = Vector shortened
  where
    shortened = SV.unsafeInit raw

-- | Like 'tail', but built on 'SV.unsafeTail'; see that function for its
-- preconditions. Stored elements are not converted.
unsafeTail :: (Storable a, Endian a) => Vector a -> Vector a
unsafeTail vec = case vec of
    Vector raw -> Vector $ SV.unsafeTail raw

-- | Like 'take', but built on 'SV.unsafeTake'; see that function for its
-- preconditions. Stored elements are not converted.
unsafeTake :: (Storable a, Endian a) => Int -> Vector a -> Vector a
unsafeTake n (Vector raw) = Vector $ SV.unsafeTake n raw

-- | Like 'drop', but built on 'SV.unsafeDrop'; see that function for its
-- preconditions. Stored elements are not converted.
unsafeDrop :: (Storable a, Endian a) => Int -> Vector a -> Vector a
unsafeDrop n vec = case vec of
    Vector raw -> Vector (SV.unsafeDrop n raw)

---------------
-- Construction
---------------

-- Initialisation
-----------------

-- | The vector with no elements: 'SV.empty', wrapped.
empty :: (Storable a, Endian a) => Vector a
empty = Vector raw
  where
    raw = SV.empty

-- | One-element vector. The element is converted with 'to' before being
-- stored.
singleton :: (Storable a, Endian a) => a -> Vector a
singleton = Vector . SV.singleton . to

-- | Vector of the given length holding the same element throughout. The
-- element is converted with 'to' once and then handed to 'SV.replicate'.
replicate :: (Storable a, Endian a) => Int -> a -> Vector a
replicate n item = Vector $ SV.replicate n stored
  where
    stored = to item

-- | Builds a vector of the given length with 'SV.generate', converting each
-- value produced by the index function with 'to' before it is stored.
generate :: (Storable a, Endian a) => Int -> (Int -> a) -> Vector a
generate n mk = Vector $ SV.generate n (\ix -> to (mk ix))

-- | Builds a vector by repeated application of a function, via
-- 'SV.iterateN'. The seed is converted with 'to'; the step function is
-- wrapped so it receives a 'from'-converted value and its result is
-- converted back with 'to'.
iterateN :: (Storable a, Endian a) => Int -> (a -> a) -> a -> Vector a
iterateN n step seed = Vector (SV.iterateN n storedStep (to seed))
  where
    -- Lift the caller's step function to operate on stored values.
    storedStep stored = to (step (from stored))

-- Monadic Initialisation
-------------------------

-- | Runs the action the given number of times via 'SV.replicateM',
-- converting each result with 'to' before it is stored.
replicateM :: (Monad m, Storable a, Endian a) => Int -> m a -> m (Vector a)
replicateM n action = fmap Vector (SV.replicateM n (fmap to action))

-- | Monadic 'generate': runs the action for each index via 'SV.generateM',
-- converting each result with 'to' before it is stored.
generateM :: (Monad m, Storable a, Endian a) => Int -> (Int -> m a) -> m (Vector a)
generateM n mk = fmap Vector (SV.generateM n (\ix -> fmap to (mk ix)))

-- | Monadic counterpart of 'iterateN', built on 'SV.iterateNM'.
--
-- The wrapped step function takes a stored value, converts it with 'from',
-- runs the caller's action, and converts the result back with 'to'.
-- The seed therefore has to be converted with 'to' as well, exactly as
-- 'iterateN' does. Otherwise the first element would be stored unconverted
-- and the first step would apply 'from' to a value that was never passed
-- through 'to'.
iterateNM :: (Monad m, Storable a, Endian a) => Int -> (a -> m a) -> a -> m (Vector a)
iterateNM count m orig =
    Vector <$> SV.iterateNM count (fmap to . m . from) (to orig)

-- | Runs an 'ST' action that builds an 'MVector' and returns the result as a
-- 'Vector'. This is 'GV.create' used at this module's vector types.
create :: (Storable a, Endian a) => (forall s. ST s (MVector s a)) -> Vector a
create build = GV.create build

-- | Like 'create', but the 'ST' action yields a traversable container of
-- mutable vectors. This is 'GV.createT' used at this module's vector types.
createT :: (Traversable f, Storable a, Endian a)
    => (forall s. ST s (f (MVector s a))) -> f (Vector a)
createT build = GV.createT build

-- Unfolding
------------

-- | 'GV.unfoldr' used at this module's 'Vector' type. Element conversion is
-- whatever this module's generic-vector instances perform.
unfoldr :: (Storable a, Endian a) => (b -> Maybe (a, b)) -> b -> Vector a
unfoldr step seed = GV.unfoldr step seed

-- | 'GV.unfoldrN' used at this module's 'Vector' type. Element conversion is
-- whatever this module's generic-vector instances perform.
unfoldrN :: (Storable a, Endian a) => Int -> (b -> Maybe (a, b)) -> b -> Vector a
unfoldrN limit step seed = GV.unfoldrN limit step seed

-- | 'GV.unfoldrM' used at this module's 'Vector' type. Element conversion is
-- whatever this module's generic-vector instances perform.
unfoldrM :: (Monad m, Storable a, Endian a) => (b -> m (Maybe (a, b))) -> b -> m (Vector a)
unfoldrM step seed = GV.unfoldrM step seed

-- | 'GV.unfoldrNM' used at this module's 'Vector' type. Element conversion
-- is whatever this module's generic-vector instances perform.
unfoldrNM :: (Monad m, Storable a, Endian a) => Int -> (b -> m (Maybe (a, b))) -> b -> m (Vector a)
unfoldrNM limit step seed = GV.unfoldrNM limit step seed

-- | 'GV.constructN' used at this module's 'Vector' type. Element conversion
-- is whatever this module's generic-vector instances perform.
constructN :: (Storable a, Endian a) => Int -> (Vector a -> a) -> Vector a
constructN n next = GV.constructN n next

-- | 'GV.constructrN' used at this module's 'Vector' type. Element conversion
-- is whatever this module's generic-vector instances perform.
constructrN :: (Storable a, Endian a) => Int -> (Vector a -> a) -> Vector a
constructrN n next = GV.constructrN n next

-- Enumeration
--------------
-- | 'GV.enumFromN' used at this module's 'Vector' type. Element conversion
-- is whatever this module's generic-vector instances perform.
enumFromN :: (Storable a, Endian a, Num a) => a -> Int -> Vector a
enumFromN start n = GV.enumFromN start n

-- | 'GV.enumFromStepN' used at this module's 'Vector' type. Element
-- conversion is whatever this module's generic-vector instances perform.
enumFromStepN :: (Storable a, Endian a, Num a) => a -> a -> Int -> Vector a
enumFromStepN start step n = GV.enumFromStepN start step n

-- | 'GV.enumFromTo' used at this module's 'Vector' type. Element conversion
-- is whatever this module's generic-vector instances perform.
enumFromTo :: (Storable a, Endian a, Enum a) => a -> a -> Vector a
enumFromTo lo hi = GV.enumFromTo lo hi

-- | 'GV.enumFromThenTo' used at this module's 'Vector' type. Element
-- conversion is whatever this module's generic-vector instances perform.
enumFromThenTo :: (Storable a, Endian a, Enum a) => a -> a -> a -> Vector a
enumFromThenTo lo next hi = GV.enumFromThenTo lo next hi

---------------
-- Raw pointers
---------------

-- | Wraps the result of 'SV.unsafeFromForeignPtr' (pointer, offset, length).
-- No conversion is applied, so the memory is taken as already holding
-- elements in their stored representation.
unsafeFromForeignPtr :: (SV.Storable a, Endian a)
    => ForeignPtr a -> Int -> Int -> Vector a
unsafeFromForeignPtr fptr offset size =
    Vector (SV.unsafeFromForeignPtr fptr offset size)

-- | Wraps the result of 'SV.unsafeFromForeignPtr0' (pointer, length).
-- No conversion is applied, so the memory is taken as already holding
-- elements in their stored representation.
unsafeFromForeignPtr0 :: (SV.Storable a, Endian a)
    => ForeignPtr a -> Int -> Vector a
unsafeFromForeignPtr0 fptr size =
    Vector (SV.unsafeFromForeignPtr0 fptr size)

-- | Exposes the wrapped vector through 'SV.unsafeToForeignPtr'. No
-- conversion is applied, so the memory holds elements in their stored
-- representation.
unsafeToForeignPtr :: (SV.Storable a, Endian a)
    => Vector a -> (ForeignPtr a, Int, Int)
unsafeToForeignPtr vec = case vec of
    Vector raw -> SV.unsafeToForeignPtr raw

-- | Exposes the wrapped vector through 'SV.unsafeToForeignPtr0'. No
-- conversion is applied, so the memory holds elements in their stored
-- representation.
unsafeToForeignPtr0 :: (SV.Storable a, Endian a)
    => Vector a -> (ForeignPtr a, Int)
unsafeToForeignPtr0 vec = case vec of
    Vector raw -> SV.unsafeToForeignPtr0 raw

-- | Hands a pointer to the vector's storage to the given action, exactly as
-- 'SV.unsafeWith' does. Note well: the data behind the pointer is in its
-- *wire format*, which may not match the host CPU's endianness.
unsafeWith :: (SV.Storable a, Endian a) => Vector a -> (Ptr a -> IO b) -> IO b
unsafeWith vec action = case vec of
    Vector raw -> SV.unsafeWith raw action