{-# LANGUAGE CPP #-}
#ifdef __GLASGOW_HASKELL__
{-# LANGUAGE TypeFamilies #-}
#endif

{- |
= Finite vectors

The @'Vector' a@ type represents a finite vector (or dynamic array) of elements of type @a@.
A 'Vector' is strict in its spine.

The class instances are based on those for lists.

This module should be imported qualified, to avoid name clashes with the 'Prelude'.

> import qualified Data.AMT as Vector

== Performance

The worst-case running time complexities are given, with /n/ referring to the number of elements in the vector.
A 'Vector' is particularly efficient for applications that require a lot of indexing and updates.
All logarithms are base 16, which means that /O(log n)/ behaves like /O(1)/ in practice.

== Warning

The length of a 'Vector' must not exceed @'maxBound' :: 'Int'@.
Violation of this condition is not detected and if the length limit is exceeded, the behaviour of the vector is undefined.

== Implementation

The implementation of 'Vector' uses array mapped tries.
-}

module Data.AMT
    ( Vector
    -- * Construction
    , empty, singleton, fromList
    , fromFunction
    , replicate, replicateA
    , unfoldr, unfoldl, iterateN
    , (<|), (|>), (><)
    -- * Deconstruction/Subranges
    , viewl
    , viewr
    , last
    , take
    -- * Indexing
    , lookup, index
    , (!?), (!)
    , update
    , adjust
    -- * Transformations
    , map, mapWithIndex
    , traverseWithIndex
    , indexed
    -- * Folds
    , foldMapWithIndex
    , foldlWithIndex, foldrWithIndex
    , foldlWithIndex', foldrWithIndex'
    -- * Zipping/Unzipping
    , zip, zipWith
    , zip3, zipWith3
    , unzip, unzip3
    -- * To Lists
    , toIndexedList
    ) where

import Control.Applicative (Alternative)
import qualified Control.Applicative as Applicative
import Control.Monad (MonadPlus(..))
#if !(MIN_VERSION_base(4,13,0))
import Control.Monad.Fail (MonadFail(..))
#endif
import Control.Monad.Zip (MonadZip(..))

import Data.Bits
import Data.Foldable (foldl', toList)
import Data.Functor.Classes
import Data.Functor.Compose
import Data.Functor.Identity
import Data.List.NonEmpty (NonEmpty(..), (!!))
import qualified Data.List.NonEmpty as L
import Data.Maybe (fromMaybe)
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup (Semigroup((<>)))
#endif
#ifdef __GLASGOW_HASKELL__
import Data.String (IsString)
#endif
import Data.Traversable (mapAccumL)
#ifdef __GLASGOW_HASKELL__
import GHC.Exts (IsList)
import qualified GHC.Exts as Exts
#endif
import Prelude hiding ((!!), last, lookup, map, replicate, tail, take, unzip, unzip3, zip, zipWith, zip3, zipWith3)
import qualified Prelude as P
import Text.Read (Lexeme(Ident), lexP, parens, prec, readPrec)

import Control.Monad.Trans.State.Strict (state, evalState)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M

-- Concatenation and consing associate to the right; snoc associates to the
-- left, so @v |> x |> y@ parses as @(v |> x) |> y@.
infixr 5 ><, <|
infixl 5 |>

-- | A node of the trie.
--
-- An 'Internal' node holds child subtrees and a 'Leaf' holds elements. Children
-- and leaf elements are selected with 'mask', so each node uses at most
-- 'tailSize' slots.
data Tree a
    = Internal !(V.Vector (Tree a))
    | Leaf !(V.Vector a)

-- | An array mapped trie.
--
-- A non-empty vector keeps its first @offset@ elements in a trie of the given
-- height. The remaining @size - offset@ elements are kept in a tail buffer,
-- stored newest-first, so appending and reading the last element are cheap.
-- A height of 0 (with an empty 'Leaf' as the tree) means the trie is unused and
-- every element lives in the tail.
data Vector a
    = Empty
    | Root
        {-# UNPACK #-} !Int  -- size (total number of elements)
        {-# UNPACK #-} !Int  -- offset (number of elements in the tree)
        {-# UNPACK #-} !Int  -- height (of the tree)
        !(Tree a)  -- tree
        !(NonEmpty a)  -- tail (reversed: the last element of the vector comes first)

-- | Abort with a uniform message when a construction function receives a
-- negative length. The argument is the name of the calling function.
errorNegativeLength :: String -> a
errorNegativeLength fun = error (concat ["AMT.", fun, ": expected a nonnegative length"])

-- Number of index bits consumed at each level of the trie.
bits :: Int
bits = 4
{-# INLINE bits #-}

-- Capacity of the tail buffer and of each trie node: @2 ^ bits@.
tailSize :: Int
tailSize = 2 ^ bits

-- Low-bit mask (@tailSize - 1@) that extracts a slot number from an index.
mask :: Int
mask = pred tailSize

-- | Renders a vector as @fromList [...]@, mirroring "Data.Sequence".
instance Show1 Vector where
    liftShowsPrec showElem showElems d =
        showsUnaryWith (liftShowsPrec showElem showElems) "fromList" d . toList

instance Show a => Show (Vector a) where
    showsPrec d v = showsPrec1 d v
    {-# INLINE showsPrec #-}

-- | Parses the @fromList [...]@ form produced by the 'Show1' instance.
instance Read1 Vector where
    liftReadsPrec rp rl = readsData $ readsUnaryWith (liftReadsPrec rp rl) "fromList" fromList

-- | Parses the same @fromList [...]@ syntax that the 'Show' instance produces.
instance Read a => Read (Vector a) where
#ifdef __GLASGOW_HASKELL__
    -- At precedence 10 expect the identifier fromList followed by a list.
    -- Any other token makes the pattern match below fail, which fails the parse.
    readPrec = parens $ prec 10 $ do
        Ident "fromList" <- lexP
        xs <- readPrec
        pure (fromList xs)
#else
    readsPrec = readsPrec1
    {-# INLINE readsPrec #-}
#endif

-- | Two vectors are equal when they have the same length and pairwise-equal
-- elements. The O(1) length check short-circuits mismatched sizes.
instance Eq1 Vector where
    liftEq eq xs ys
        | length xs /= length ys = False
        | otherwise = liftEq eq (toList xs) (toList ys)

instance Eq a => Eq (Vector a) where
    v1 == v2 = eq1 v1 v2
    {-# INLINE (==) #-}

-- | Lexicographic ordering on the element sequence, as for lists.
instance Ord1 Vector where
    liftCompare cmp xs ys = liftCompare cmp (toList xs) (toList ys)

instance Ord a => Ord (Vector a) where
    compare v1 v2 = compare1 v1 v2
    {-# INLINE compare #-}

-- | Concatenation.
instance Semigroup (Vector a) where
    v1 <> v2 = v1 >< v2
    {-# INLINE (<>) #-}

-- | The empty vector is the identity of concatenation.
instance Monoid (Vector a) where
    mempty = Empty
    {-# INLINE mempty #-}

    mappend = (<>)
    {-# INLINE mappend #-}

-- | Folds visit the trie leaves left to right, then the tail in index order.
instance Foldable Vector where
    foldr _ z Empty = z
    foldr f z (Root _ _ _ tree tail) = goTree tree tailFold
      where
        -- the tail is stored newest-first, so un-reverse it before folding
        tailFold = foldr f z (L.reverse tail)
        goTree (Leaf xs) rest = foldr f rest xs
        goTree (Internal ts) rest = foldr goTree rest ts

    null v = case v of
        Empty -> True
        Root{} -> False
    {-# INLINE null #-}

    length v = case v of
        Empty -> 0
        Root size _ _ _ _ -> size
    {-# INLINE length #-}

-- | Delegates to the structure-preserving 'map'.
instance Functor Vector where
    fmap f v = map f v
    {-# INLINE fmap #-}

-- | Effects run over the trie leaves first, then over the tail in index order.
-- The shape (size, offset, height) is kept unchanged.
instance Traversable Vector where
    traverse _ Empty = pure Empty
    traverse f (Root s offset h tree tail) =
        Root s offset h <$> goTree tree <*> tailT
      where
        tailT = L.reverse <$> traverse f (L.reverse tail)
        goTree (Leaf xs) = Leaf <$> traverse f xs
        goTree (Internal ts) = Internal <$> traverse goTree ts

#ifdef __GLASGOW_HASKELL__
-- | List-literal support (for use with OverloadedLists).
--
-- The right-hand sides below are not self-recursive. The GHC.Exts methods are
-- only imported qualified, so the unqualified names refer to this module's
-- fromList and to Data.Foldable.toList.
instance IsList (Vector a) where
    type Item (Vector a) = a

    fromList = fromList
    {-# INLINE fromList #-}

    toList = toList
    {-# INLINE toList #-}

-- | String-literal support. The @a ~ Char@ constraint lets an overloaded
-- string literal be resolved to @Vector Char@.
instance a ~ Char => IsString (Vector a) where
    fromString = fromList
    {-# INLINE fromString #-}
#endif

-- | Cartesian-product semantics, as for lists: every function applied to every
-- argument, grouped by function.
instance Applicative Vector where
    pure x = singleton x
    {-# INLINE pure #-}

    fs <*> xs = foldl' step empty fs
      where
        step acc f = acc >< map f xs

-- | List-like bind: concatenate the vectors produced for each element, in order.
instance Monad Vector where
    xs >>= f = foldl' step empty xs
      where
        step acc x = acc >< f x

-- | The monoid of concatenation, viewed as choice.
instance Alternative Vector where
    empty = Empty
    {-# INLINE empty #-}

    xs <|> ys = xs >< ys
    {-# INLINE (<|>) #-}

-- | Uses the default methods, so 'mzero' and 'mplus' are the 'Alternative'
-- operations (the empty vector and concatenation).
instance MonadPlus Vector

-- | A failed pattern match in @do@-notation yields the empty vector.
instance MonadFail Vector where
    fail _ = Empty
    {-# INLINE fail #-}

-- | Element-wise zipping, truncated to the shorter vector.
instance MonadZip Vector where
    mzip xs ys = zip xs ys
    {-# INLINE mzip #-}

    mzipWith f xs ys = zipWith f xs ys
    {-# INLINE mzipWith #-}

    munzip v = unzip v
    {-# INLINE munzip #-}


-- | /O(1)/. The vector with no elements.
--
-- > empty = fromList []
empty :: Vector a
empty = Empty
{-# INLINE empty #-}

-- | /O(1)/. A vector holding exactly one element.
--
-- The element is placed in the tail; the trie is an unused empty leaf of height 0.
--
-- > singleton x = fromList [x]
singleton :: a -> Vector a
singleton x = Root 1 0 0 (Leaf V.empty) (pure x)
{-# INLINE singleton #-}

-- | /O(n * log n)/. Build a vector by appending the list's elements left to right.
fromList :: [a] -> Vector a
fromList xs = foldl' (|>) Empty xs
{-# INLINE fromList #-}

-- | /O(n * log n)/. Build a vector of length @n@ whose element at index @i@ is @f i@.
-- Calls 'error' if @n@ is negative.
fromFunction :: Int -> (Int -> a) -> Vector a
fromFunction n f
    | n < 0 = errorNegativeLength "fromFunction"
    | otherwise = foldl' (\acc i -> acc |> f i) empty [0 .. n - 1]
{-# INLINE fromFunction #-}

-- | /O(n * log n)/. @replicate n x@ is a vector of @n@ copies of @x@.
-- Calls 'error' if @n@ is negative.
replicate :: Int -> a -> Vector a
replicate n
    | n < 0 = errorNegativeLength "replicate"
    | otherwise = runIdentity . replicateA n . Identity
{-# INLINE replicate #-}

-- | @replicateA n x@ runs the action @x@ @n@ times, left to right, and collects
-- the results into a vector. Calls 'error' if @n@ is negative.
replicateA :: Applicative f => Int -> f a -> f (Vector a)
replicateA n x
    | n < 0 = errorNegativeLength "replicateA"
    | otherwise = loop n (pure empty)
  where
    -- count down the remaining repetitions, appending one result per step
    loop k acc
        | k > 0 = loop (k - 1) ((|>) <$> acc <*> x)
        | otherwise = acc
{-# INLINE replicateA #-}

-- | /O(n * log n)/. Build a vector left to right: each @Just (x, seed')@ appends
-- @x@ and continues with @seed'@, and 'Nothing' stops.
unfoldr :: (b -> Maybe (a, b)) -> b -> Vector a
unfoldr f = build empty
  where
    build acc seed = maybe acc (\(x, seed') -> build (acc |> x) seed') (f seed)
{-# INLINE unfoldr #-}

-- | /O(n * log n)/. Build a vector right to left: each @Just (seed', x)@ makes
-- @x@ the last element after everything generated from @seed'@, and 'Nothing' stops.
unfoldl :: (b -> Maybe (b, a)) -> b -> Vector a
unfoldl f = build
  where
    build seed = maybe empty (\(seed', x) -> build seed' |> x) (f seed)
{-# INLINE unfoldl #-}

-- | @iterateN n f x@ is the vector @[x, f x, f (f x), ...]@ of length @n@.
-- Calls 'error' if @n@ is negative.
iterateN :: Int -> (a -> a) -> a -> Vector a
iterateN n f x
    | n < 0 = errorNegativeLength "iterateN"
    | otherwise = evalState (replicateA n step) x
  where
    -- emit the current value and advance the state to its successor
    step = state (\y -> (y, f y))
{-# INLINE iterateN #-}

-- | /O(n * log n)/. Prepend an element. The whole vector is rebuilt from a list,
-- because the trie only supports cheap growth at the right end.
(<|) :: a -> Vector a -> Vector a
(<|) x = fromList . (x :) . toList

-- | /O(n * log n)/. The first element and the vector without the first element,
-- or 'Nothing' if the vector is empty.
--
-- The remainder is rebuilt from a list, which is where the /O(n * log n)/ cost
-- comes from. Pattern matching on the list keeps this total, with no partial
-- 'head'/'tail' calls.
viewl :: Vector a -> Maybe (a, Vector a)
viewl v = case toList v of
    [] -> Nothing
    x : rest -> Just (x, fromList rest)

-- | /O(log n)/. Add an element to the right end of the vector.
--
-- The new element always goes into the (reversed) tail. When the tail is
-- already full, its contents are first moved into the trie as a new leaf and
-- the tail restarts with just the new element.
(|>) :: Vector a -> a -> Vector a
Empty |> x = singleton x
Root s offset h tree tail |> x
    -- s is not a multiple of tailSize, so the tail still has room: push onto it
    | s .&. mask /= 0 = Root (s + 1) offset h tree (x L.<| tail)
    -- the trie is unused: the full tail becomes its first leaf (height 1)
    | offset == 0 = Root (s + 1) s (h + 1) (Leaf $ V.fromList (toList $ L.reverse tail)) (x :| [])
    -- the trie holds 16^h elements and is full: add a new root above it
    | offset == 1 `shiftL` (bits * h) = Root (s + 1) s (h + 1) (Internal $ V.fromList [tree, newPath h]) (x :| [])
    -- otherwise: push the full tail into the next free leaf slot
    | otherwise = Root (s + 1) s h (insertTail (bits * (h - 1)) tree) (x :| [])
  where
    -- create a new path from the old tail
    -- (a chain of single-child Internal nodes of the given height ending in the leaf)
    newPath 1 = Leaf $ V.fromList (toList $ L.reverse tail)
    newPath h = Internal $ V.singleton (newPath (h - 1))

    -- Descend using the bits of offset, which is the index the first element of
    -- the new leaf will have. An existing child on the path is updated (copied);
    -- a missing one is appended as a fresh path.
    insertTail sh (Internal v)
        | index < V.length v = Internal $ V.modify (\v -> M.modify v (insertTail (sh - bits)) index) v
        | otherwise = Internal $ V.snoc v (newPath (sh `div` bits))
      where
        index = offset `shiftR` sh .&. mask
    -- NOTE(review): appears unreachable. At the level just above the leaves the
    -- slot index equals the current child count, so the snoc branch is taken.
    insertTail _ (Leaf _) = Leaf $ V.fromList (toList $ L.reverse tail)

-- | /O(log n)/. The vector without the last element and the last element or 'Nothing' if the vector is empty.
--
-- The last element is always the head of the reversed tail. When the tail holds
-- only that element, the rightmost leaf of the trie is detached and becomes the
-- new tail.
viewr :: Vector a -> Maybe (Vector a, a)
viewr Empty = Nothing
viewr (Root s offset h tree (x :| tail))
    -- the tail has further elements: just drop its head
    | not (null tail) = Just (Root (s - 1) offset h tree (L.fromList tail), x)
    -- x was the only element
    | s == 1 = Just (Empty, x)
    -- the trie holds exactly one leaf (offset == tailSize): it becomes the tail
    | s == tailSize + 1 = Just (Root (s - 1) 0 0 (Leaf V.empty) (getTail tree), x)
    -- remove the rightmost leaf from the trie and use its contents as the new tail
    | otherwise =
        let sh = bits * (h - 1)
        in Just (normalize $ Root (s - 1) (offset - tailSize) h (unsnocTree sh tree) (getTail tree), x)
  where
    -- index of the last element that stays in the trie
    index' = offset - tailSize - 1

    -- keep only the children up to the one on the path to index', recursing into it
    unsnocTree sh (Internal v) =
        let subIndex = index' `shiftR` sh .&. mask
            new = V.take (subIndex + 1) v
        in Internal $ V.modify (\v -> M.modify v (unsnocTree (sh - bits)) subIndex) new
    unsnocTree _ (Leaf v) = Leaf v

    -- contents of the rightmost leaf, reversed so the newest element comes first
    getTail (Internal v) = getTail (V.last v)
    getTail (Leaf v) = L.fromList . reverse $ toList v

    -- collapse a root with a single child, lowering the height by one
    normalize (Root s offset h (Internal v) tail)
        | length v == 1 = Root s offset (h - 1) (v V.! 0) tail
    normalize v = v

-- | /O(1)/. The final element of the vector, or 'Nothing' when it is empty.
-- The newest element is the head of the reversed tail.
last :: Vector a -> Maybe a
last v = case v of
    Empty -> Nothing
    Root _ _ _ _ newestFirst -> Just (L.head newestFirst)
{-# INLINE last #-}

-- | /O(log n)/. Take the first n elements of the vector or the vector if n is larger than the length of the vector.
-- Returns the empty vector if n is zero or negative.
--
-- Depending on where the cut falls, the result either keeps the whole trie with
-- a shortened tail, keeps only a prefix of the first leaf (as the tail), or
-- keeps a pruned trie whose last leaf, cut at n, becomes the new tail.
take :: Int -> Vector a -> Vector a
take _ Empty = Empty
take n root@(Root s offset h tree tail)
    | n <= 0 = Empty
    | n >= s = root
    -- the cut falls inside the tail: keep the trie, drop the newest s - n tail elements
    | n > offset = Root n offset h tree (L.fromList $ L.drop (s - n) tail)
    -- at most tailSize elements remain: they all fit in the tail
    | n <= tailSize = Root n 0 0 (Leaf V.empty) (getTail (bits * (h - 1)) tree)
    -- general case: prune the trie and turn the leaf holding index n - 1 into the tail
    | otherwise =
        let sh = bits * (h - 1)
        in normalize $ Root n ((n - 1) .&. complement mask) h (takeTree sh tree) (getTail sh tree)  -- n - 1 because if 'n .&. mask == 0', we need to subtract tailSize
  where
    -- index of the last element in the new vector
    index = n - 1

    -- an index inside the leaf just before the one that becomes the tail,
    -- i.e. on the path to the last element kept in the trie
    index' = index - tailSize

    -- keep only the children up to the one on the path to index', recursing into it
    takeTree sh (Internal v) =
        let subIndex = index' `shiftR` sh .&. mask
            new = V.take (subIndex + 1) v
        in Internal $ V.modify (\v -> M.modify v (takeTree (sh - bits)) subIndex) new
    takeTree _ (Leaf v) = Leaf v

    -- descend to the leaf containing index and keep its elements up to index, newest first
    getTail sh (Internal v) = getTail (sh - bits) (v V.! (index `shiftR` sh .&. mask))
    getTail _ (Leaf v) = L.fromList . reverse . P.take (index .&. mask + 1) $ toList v

    -- repeatedly collapse single-child roots, since pruning may free several levels
    normalize (Root s offset h (Internal v) tail)
        | length v == 1 = normalize $ Root s offset (h - 1) (v V.! 0) tail
    normalize v = v

-- | /O(log n)/. Safe indexing: 'Just' the element at the given position, or
-- 'Nothing' when the position is negative or not less than the length.
lookup :: Int -> Vector a -> Maybe a
lookup _ Empty = Nothing
lookup i (Root s offset h tree tail)
    | i < 0 || i >= s = Nothing
    | i >= offset = Just (tail !! (s - i - 1))  -- tail is stored newest-first
    | otherwise = Just (descend (bits * (h - 1)) tree)
  where
    -- walk down the trie, consuming 'bits' index bits per level, most significant first
    descend level (Internal children) = descend (level - bits) (children V.! ((i `shiftR` level) .&. mask))
    descend _ (Leaf elems) = elems V.! (i .&. mask)

-- | /O(log n)/. The element at the given position.
-- Calls 'error' if the position is out of range.
index :: Int -> Vector a -> a
index i v = case lookup i v of
    Just x -> x
    Nothing -> error "AMT.index: index out of range"

-- | /O(log n)/. Infix, vector-first form of 'lookup'.
(!?) :: Vector a -> Int -> Maybe a
v !? i = lookup i v
{-# INLINE (!?) #-}

-- | /O(log n)/. Flipped version of 'index': @v ! i@ is the element at index @i@.
-- Calls 'error' if the index is out of range; use '!?' for a total variant.
(!) :: Vector a -> Int -> a
(!) = flip index
{-# INLINE (!) #-}

-- | /O(log n)/. Replace the element at the given position with a new value.
-- An out-of-range position leaves the vector unchanged.
update :: Int -> a -> Vector a -> Vector a
update i x = adjust i (\_ -> x)
{-# INLINE update #-}

-- | /O(log n)/. Adjust the element at the index by applying the function to it.
-- Returns the original vector if the index is out of range.
--
-- Only the path from the root to the affected leaf is copied; all other nodes
-- are shared with the input.
adjust :: Int -> (a -> a) -> Vector a -> Vector a
adjust _ _ Empty = Empty
adjust i f root@(Root s offset h tree tail)
    | i < 0 || i >= s = root
    | i < offset = Root s offset h (adjustTree (bits * (h - 1)) tree) tail
    | otherwise =
        -- The tail is stored reversed, so index i sits at position s - i - 1.
        case L.splitAt (s - i - 1) tail of
            (before, x : after) -> Root s offset h tree (L.fromList (before ++ f x : after))
            -- Unreachable: offset <= i < s keeps the position inside the tail.
            (_, []) -> root
  where
    -- Copy the nodes on the path to index i, applying f in the leaf.
    adjustTree sh (Internal children) =
        let slot = (i `shiftR` sh) .&. mask
        in Internal $ V.modify (\mv -> M.modify mv (adjustTree (sh - bits)) slot) children
    adjustTree _ (Leaf elems) =
        let slot = i .&. mask
        in Leaf $ V.modify (\mv -> M.modify mv f slot) elems

-- | /O(m * log n)/. Concatenate two vectors by appending the elements of the
-- second (of length /m/) to the first, one at a time.
(><) :: Vector a -> Vector a -> Vector a
v1 >< v2 = case (v1, v2) of
    (Empty, _) -> v2
    (_, Empty) -> v1
    _ -> foldl' (|>) v1 v2
{-# INLINE (><) #-}

-- | /O(n)/. Apply a function to every element, keeping the trie shape unchanged.
map :: (a -> b) -> Vector a -> Vector b
map f v = case v of
    Empty -> Empty
    Root s offset h tree newestFirst -> Root s offset h (goTree tree) (f <$> newestFirst)
  where
    goTree (Leaf xs) = Leaf (f <$> xs)
    goTree (Internal ts) = Internal (goTree <$> ts)

-- | /O(n)/. Like 'map', but the function also receives each element's index.
mapWithIndex :: (Int -> a -> b) -> Vector a -> Vector b
mapWithIndex f v = snd (mapAccumL step 0 v)
  where
    -- thread the running index, forcing it to avoid a thunk chain
    step i x = i `seq` (i + 1, f i x)

-- | /O(n)/. Map each element (with its index) into a monoid and combine the
-- results from left to right.
foldMapWithIndex :: Monoid m => (Int -> a -> m) -> Vector a -> m
foldMapWithIndex f = foldrWithIndex (\i x rest -> f i x `mappend` rest) mempty

-- | /O(n)/. Lazy left-associative fold whose function also receives each element's index.
foldlWithIndex :: (b -> Int -> a -> b) -> b -> Vector a -> b
foldlWithIndex f z v = foldl step (const z) v (length v - 1)
  where
    -- k folds everything before x; it is given the index preceding x's
    step k x i = i `seq` f (k (i - 1)) i x

-- | /O(n)/. Lazy right-associative fold whose function also receives each element's index.
foldrWithIndex :: (Int -> a -> b -> b) -> b -> Vector a -> b
foldrWithIndex f z v = foldr step (const z) v 0
  where
    -- k folds everything after x; it is given the index following x's
    step x k i = i `seq` f i x (k (i + 1))

-- | /O(n)/. Strict variant of 'foldlWithIndex': each intermediate accumulator
-- is forced to weak head normal form before the next step.
foldlWithIndex' :: (b -> Int -> a -> b) -> b -> Vector a -> b
foldlWithIndex' f acc v = foldrWithIndex step id v acc
  where
    step i x k z = let z' = f z i x in z' `seq` k z'
{-# INLINE foldlWithIndex' #-}

-- | /O(n)/. Strict variant of 'foldrWithIndex': each intermediate accumulator
-- is forced to weak head normal form before the next step.
foldrWithIndex' :: (Int -> a -> b -> b) -> b -> Vector a -> b
foldrWithIndex' f acc v = foldlWithIndex step id v acc
  where
    step k i x z = let z' = f i x z in z' `seq` k z'
{-# INLINE foldrWithIndex' #-}

-- | /O(n)/. Like 'traverse', but the action also receives each element's index.
-- Indices are supplied by a 'State' counter layered under the user's applicative.
traverseWithIndex :: Applicative f => (Int -> a -> f b) -> Vector a -> f (Vector b)
traverseWithIndex f v = getCompose (traverse step v) `evalState` 0
  where
    step x = Compose (state (\i -> i `seq` (f i x, i + 1)))

-- | /O(n)/. Tag every element with its position.
indexed :: Vector a -> Vector (Int, a)
indexed = mapWithIndex (\i x -> (i, x))
{-# INLINE indexed #-}

-- | /O(n)/. Pair up corresponding elements; the result has the shorter length.
zip :: Vector a -> Vector b -> Vector (a, b)
zip = zipWith (\x y -> (x, y))
{-# INLINE zip #-}

-- | /O(n)/. Combine corresponding elements with a function; the result has the
-- length of the shorter input. The shorter vector is traversed (so its shape is
-- reused) while the longer one is consumed as a list.
zipWith :: (a -> b -> c) -> Vector a -> Vector b -> Vector c
zipWith f v1 v2
    | length v1 < length v2 = zipWith (flip f) v2 v1
    | otherwise = snd (mapAccumL step (toList v1) v2)
  where
    step (x : rest) y = (rest, f x y)
    step [] _ = error "unreachable"

-- | /O(n)/. Group corresponding elements of three vectors into triples,
-- truncated to the shortest input.
zip3 :: Vector a -> Vector b -> Vector c -> Vector (a, b, c)
zip3 = zipWith3 (\x y z -> (x, y, z))
{-# INLINE zip3 #-}

-- | /O(n)/. Combine corresponding elements of three vectors with a function,
-- truncated to the shortest input.
zipWith3 :: (a -> b -> c -> d) -> Vector a -> Vector b -> Vector c -> Vector d
zipWith3 f v1 v2 v3 = zipWith (\g z -> g z) (zipWith f v1 v2) v3

-- | /O(n)/. Split a vector of pairs into the vector of first components and the
-- vector of second components.
unzip :: Vector (a, b) -> (Vector a, Vector b)
unzip v = (map (\(a, _) -> a) v, map (\(_, b) -> b) v)
{-# INLINE unzip #-}

-- | /O(n)/. Split a vector of triples into three vectors of components,
-- analogous to 'unzip'.
unzip3 :: Vector (a, b, c) -> (Vector a, Vector b, Vector c)
unzip3 v =
    ( map (\(a, _, _) -> a) v
    , map (\(_, b, _) -> b) v
    , map (\(_, _, c) -> c) v
    )
{-# INLINE unzip3 #-}

-- | /O(n)/. List the elements in order, each paired with its index.
toIndexedList :: Vector a -> [(Int, a)]
toIndexedList = foldrWithIndex (\i x rest -> (i, x) : rest) []
{-# INLINE toIndexedList #-}