{-# LANGUAGE CPP #-}
#ifdef __GLASGOW_HASKELL__
{-# LANGUAGE TypeFamilies #-}
#endif
module Data.AMT
( Vector
, empty, singleton, fromList
, fromFunction
, replicate, replicateA
, unfoldr, unfoldl, iterateN
, (<|), (|>), (><)
, viewl
, viewr
, last
, take
, lookup, index
, (!?), (!)
, update
, adjust
, map, mapWithIndex
, traverseWithIndex
, indexed
, foldMapWithIndex
, foldlWithIndex, foldrWithIndex
, foldlWithIndex', foldrWithIndex'
, zip, zipWith
, zip3, zipWith3
, unzip, unzip3
, toIndexedList
) where
import Control.Applicative (Alternative)
import qualified Control.Applicative as Applicative
import Control.Monad (MonadPlus(..))
#if !(MIN_VERSION_base(4,13,0))
import Control.Monad.Fail (MonadFail(..))
#endif
import Control.Monad.Zip (MonadZip(..))
import Data.Bits
import Data.Foldable (foldl', toList)
import Data.Functor.Classes
import Data.Functor.Compose
import Data.Functor.Identity
import Data.List.NonEmpty (NonEmpty(..), (!!))
import qualified Data.List.NonEmpty as L
import Data.Maybe (fromMaybe)
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup (Semigroup((<>)))
#endif
#ifdef __GLASGOW_HASKELL__
import Data.String (IsString)
#endif
import Data.Traversable (mapAccumL)
#ifdef __GLASGOW_HASKELL__
import GHC.Exts (IsList)
import qualified GHC.Exts as Exts
#endif
import Prelude hiding ((!!), last, lookup, map, replicate, tail, take, unzip, unzip3, zip, zipWith, zip3, zipWith3)
import qualified Prelude as P
import Text.Read (Lexeme(Ident), lexP, parens, prec, readPrec)
import Control.Monad.Trans.State.Strict (state, evalState)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
infixr 5 ><
infixr 5 <|
infixl 5 |>
-- | A node of the trie. 'Internal' nodes hold child subtrees and 'Leaf'
-- nodes hold the elements themselves. Positions inside a node are computed
-- as @(i `shiftR` sh) .&. mask@, so each level consumes 'bits' bits of an
-- element index.
data Tree a
  = Internal !(V.Vector (Tree a))
  | Leaf !(V.Vector a)
-- | A vector stored as a trie of full leaves plus a non-empty tail buffer
-- that holds the most recently appended elements.
data Vector a
  = Empty
  | Root
      {-# UNPACK #-} !Int -- ^ total number of elements
      {-# UNPACK #-} !Int -- ^ offset: number of elements stored in the trie; indices below it are looked up in the trie
      {-# UNPACK #-} !Int -- ^ height of the trie (0 while the trie is empty)
      !(Tree a) -- ^ the trie (a placeholder @Leaf V.empty@ when the height is 0)
      !(NonEmpty a) -- ^ the tail, stored newest-first (the last element is at the head)
-- | Abort, reporting that the function whose name is given received a
-- negative length argument.
errorNegativeLength :: String -> a
errorNegativeLength fun =
  error (concat ["AMT.", fun, ": expected a nonnegative length"])
-- | Number of index bits consumed by each trie level.
bits :: Int
bits = 4
{-# INLINE bits #-}

-- | Branching factor of the trie and capacity of the tail: @2 ^ bits@.
tailSize :: Int
tailSize = shiftL 1 bits

-- | Mask selecting the low 'bits' bits of an index, i.e. a position inside
-- one node.
mask :: Int
mask = pred tailSize
-- | Vectors are rendered as @fromList [...]@, the form the 'Read' instance
-- parses.
instance Show1 Vector where
  liftShowsPrec showElem showElems d xs =
    showsUnaryWith (liftShowsPrec showElem showElems) "fromList" d (toList xs)

instance Show a => Show (Vector a) where
  showsPrec d = showsPrec1 d
  {-# INLINE showsPrec #-}
-- | Parses the @fromList [...]@ form produced by the 'Show1' instance.
instance Read1 Vector where
  liftReadsPrec readElem readElems =
    readsData (readsUnaryWith (liftReadsPrec readElem readElems) "fromList" fromList)

instance Read a => Read (Vector a) where
#ifdef __GLASGOW_HASKELL__
  readPrec = parens . prec 10 $ do
    Ident "fromList" <- lexP
    fromList <$> readPrec
#else
  readsPrec = readsPrec1
  {-# INLINE readsPrec #-}
#endif
-- | Vectors are equal when they have the same length and pairwise-equal
-- elements. The O(1) length comparison rejects vectors of different lengths
-- without walking their elements.
instance Eq1 Vector where
  liftEq eq xs ys
    | length xs /= length ys = False
    | otherwise = liftEq eq (toList xs) (toList ys)

instance Eq a => Eq (Vector a) where
  xs == ys = eq1 xs ys
  {-# INLINE (==) #-}
-- | Vectors are ordered like their element lists (lexicographically).
instance Ord1 Vector where
  liftCompare cmp xs ys = liftCompare cmp (toList xs) (toList ys)

instance Ord a => Ord (Vector a) where
  compare xs ys = compare1 xs ys
  {-# INLINE compare #-}
-- | Concatenation, with the empty vector as identity.
instance Semigroup (Vector a) where
  xs <> ys = xs >< ys
  {-# INLINE (<>) #-}

instance Monoid (Vector a) where
  mempty = Empty
  {-# INLINE mempty #-}
  mappend = (<>)
  {-# INLINE mappend #-}
-- | Folds visit the trie (lower indices) first and then the tail. The tail is
-- stored newest-first, so it is reversed before folding.
instance Foldable Vector where
  foldr _ z Empty = z
  foldr f z (Root _ _ _ trie newestFirst) = goTree trie (foldr f z (L.reverse newestFirst))
    where
      goTree node rest = case node of
        Leaf elems -> foldr f rest elems
        Internal children -> foldr goTree rest children

  null v = case v of
    Empty -> True
    Root{} -> False
  {-# INLINE null #-}

  -- The size is cached in the root, so this is O(1).
  length v = case v of
    Empty -> 0
    Root n _ _ _ _ -> n
  {-# INLINE length #-}
-- | Mapping preserves the trie shape; see 'map'.
instance Functor Vector where
  fmap f v = map f v
  {-# INLINE fmap #-}
-- | Effects run in index order: the trie first, then the tail from oldest to
-- newest element. The original size, offset and height are kept.
instance Traversable Vector where
  traverse _ Empty = pure Empty
  traverse f (Root n off height trie newestFirst) =
    rebuild <$> goTree trie <*> goTail
    where
      rebuild trie' tail' = Root n off height trie' tail'
      -- Traverse the tail oldest-first, then restore newest-first storage.
      goTail = L.reverse <$> traverse f (L.reverse newestFirst)
      goTree (Internal children) = Internal <$> traverse goTree children
      goTree (Leaf elems) = Leaf <$> traverse f elems
#ifdef __GLASGOW_HASKELL__
-- | Supports @OverloadedLists@ literals. The conversions build by repeated
-- appends and read back in 'Foldable' order.
instance IsList (Vector a) where
  type Item (Vector a) = a
  fromList = foldl' (|>) empty
  {-# INLINE fromList #-}
  toList = foldr (:) []
  {-# INLINE toList #-}

-- | Supports @OverloadedStrings@ literals for character vectors.
instance a ~ Char => IsString (Vector a) where
  fromString str = fromList str
  {-# INLINE fromString #-}
#endif
-- | List-like (Cartesian product) semantics. Each function is applied to
-- every element, and the results are concatenated in function order.
instance Applicative Vector where
  pure x = singleton x
  {-# INLINE pure #-}
  fs <*> xs = foldl' appendApplied Empty fs
    where
      appendApplied acc g = acc >< map g xs

-- | Binding concatenates the vectors produced for each element, in order.
instance Monad Vector where
  xs >>= k = foldl' appendResult Empty xs
    where
      appendResult acc x = acc >< k x
-- | Choice is concatenation; the empty vector is the failing alternative.
instance Alternative Vector where
  empty = Empty
  {-# INLINE empty #-}
  xs <|> ys = xs >< ys
  {-# INLINE (<|>) #-}

instance MonadPlus Vector

-- | A failed pattern match in @do@ notation yields the empty vector and
-- discards the message.
instance MonadFail Vector where
  fail = const Empty
  {-# INLINE fail #-}
-- | Zipping truncates to the shorter vector, exactly like 'zipWith'.
instance MonadZip Vector where
  mzip xs ys = zip xs ys
  {-# INLINE mzip #-}
  mzipWith f xs ys = zipWith f xs ys
  {-# INLINE mzipWith #-}
  munzip ps = unzip ps
  {-# INLINE munzip #-}
-- | The vector with no elements.
empty :: Vector a
empty = Empty
{-# INLINE empty #-}

-- | A one-element vector. The element lives in the tail and the trie is an
-- empty placeholder leaf at height 0.
singleton :: a -> Vector a
singleton x = Root 1 0 0 emptyLeaf (pure x)
  where
    emptyLeaf = Leaf V.empty
{-# INLINE singleton #-}
-- | Build a vector from a list by appending its elements left to right.
fromList :: [a] -> Vector a
fromList xs = foldl' (|>) Empty xs
{-# INLINE fromList #-}
-- | @fromFunction n f@ is the vector @[f 0, f 1, ..., f (n - 1)]@.
-- Calls 'error' when @n@ is negative.
fromFunction :: Int -> (Int -> a) -> Vector a
fromFunction n f
  | n < 0 = errorNegativeLength "fromFunction"
  | otherwise = loop 0 Empty
  where
    loop i acc
      | i >= n = acc
      | otherwise = loop (i + 1) (acc |> f i)
{-# INLINE fromFunction #-}
-- | @replicate n x@ is a vector of @n@ copies of @x@. Calls 'error' when
-- @n@ is negative.
replicate :: Int -> a -> Vector a
replicate n
  | n < 0 = errorNegativeLength "replicate"
  -- Delegate to replicateA in the effect-free Identity functor.
  | otherwise = runIdentity . replicateA n . Identity
{-# INLINE replicate #-}
-- | Run the action @n@ times, left to right, collecting the results in
-- order. Calls 'error' when @n@ is negative.
replicateA :: Applicative f => Int -> f a -> f (Vector a)
replicateA n x
  | n < 0 = errorNegativeLength "replicateA"
  | otherwise = loop n (pure Empty)
  where
    -- Count down the remaining repetitions; each one appends x's result.
    loop remaining acc
      | remaining <= 0 = acc
      | otherwise = loop (remaining - 1) ((|>) <$> acc <*> x)
{-# INLINE replicateA #-}
-- | Build a vector from left to right. Each step yields the next element
-- (appended at the right end) and the next seed; 'Nothing' stops.
unfoldr :: (b -> Maybe (a, b)) -> b -> Vector a
unfoldr step seed0 = loop Empty seed0
  where
    loop acc seed = maybe acc (\(x, seed') -> loop (acc |> x) seed') (step seed)
{-# INLINE unfoldr #-}
-- | Build a vector from right to left. Each step yields the next seed and an
-- element; that element is placed to the left of every element produced
-- earlier, so the first element produced ends up last. 'Nothing' stops.
--
-- The elements are accumulated in a list that is already in final order,
-- and the vector is built once at the end. This keeps the loop tail
-- recursive instead of leaving an O(n)-deep chain of pending '|>' calls.
unfoldl :: (b -> Maybe (b, a)) -> b -> Vector a
unfoldl f = go []
  where
    go acc seed = case f seed of
      Nothing -> fromList acc
      Just (seed', x) -> go (x : acc) seed'
{-# INLINE unfoldl #-}
-- | @iterateN n f x@ is @[x, f x, f (f x), ...]@ with @n@ elements. Calls
-- 'error' when @n@ is negative.
iterateN :: Int -> (a -> a) -> a -> Vector a
iterateN n f x
  | n < 0 = errorNegativeLength "iterateN"
  | otherwise = evalState (replicateA n step) x
  where
    -- Emit the current value and advance the state to its successor.
    step = state (\cur -> (cur, f cur))
{-# INLINE iterateN #-}
-- | Prepend an element. This rebuilds the whole vector from a list, so it
-- is O(n).
(<|) :: a -> Vector a -> Vector a
(<|) x = fromList . (x :) . toList
-- | Split off the first element, or return 'Nothing' for an empty vector.
--
-- The remainder is rebuilt from a list, which is O(n). Matching on the
-- element list replaces the partial 'head' / 'tail' pair, so the function
-- is total by construction.
viewl :: Vector a -> Maybe (a, Vector a)
viewl v = case toList v of
  [] -> Nothing
  x : rest -> Just (x, fromList rest)
-- | Append an element at the right end.
--
-- New elements are pushed onto the tail (stored newest-first). When the size
-- is a multiple of 'tailSize', the full tail is moved into the trie as a leaf
-- (in index order) and a fresh one-element tail is started.
(|>) :: Vector a -> a -> Vector a
Empty |> x = singleton x
Root s offset h tree tail |> x
  -- Size is not a multiple of tailSize, so the tail still has room.
  | s .&. mask /= 0 = Root (s + 1) offset h tree (x L.<| tail)
  -- Tail is full and the trie is empty: the tail becomes the single leaf.
  | offset == 0 = Root (s + 1) s (h + 1) (Leaf $ V.fromList (toList $ L.reverse tail)) (x :| [])
  -- The trie holds 2^(bits*h) elements, i.e. it is full at height h: grow a
  -- new root whose children are the old trie and a fresh path to the leaf.
  | offset == 1 `shiftL` (bits * h) = Root (s + 1) s (h + 1) (Internal $ V.fromList [tree, newPath h]) (x :| [])
  -- Otherwise descend along the digits of offset and place the leaf there.
  | otherwise = Root (s + 1) s h (insertTail (bits * (h - 1)) tree) (x :| [])
  where
    -- A chain of single-child Internal nodes of the given height (1 = leaf)
    -- ending in a leaf that holds the old tail in index order.
    newPath 1 = Leaf $ V.fromList (toList $ L.reverse tail)
    newPath h = Internal $ V.singleton (newPath (h - 1))
    -- sh is the bit shift for the current level. Recurse into an existing
    -- child slot, or append a new path when the slot does not exist yet.
    insertTail sh (Internal v)
      | index < V.length v = Internal $ V.modify (\v -> M.modify v (insertTail (sh - bits)) index) v
      | otherwise = Internal $ V.snoc v (newPath (sh `div` bits))
      where
        index = offset `shiftR` sh .&. mask
    -- Replaces an existing leaf with the tail's leaf.
    -- NOTE(review): appears unreachable for well-formed vectors — confirm.
    insertTail _ (Leaf _) = Leaf $ V.fromList (toList $ L.reverse tail)
-- | Split off the last element, or return 'Nothing' for an empty vector.
viewr :: Vector a -> Maybe (Vector a, a)
viewr Empty = Nothing
viewr (Root s offset h tree (x :| tail))
  -- The tail has more than one element: drop its newest one.
  | not (null tail) = Just (Root (s - 1) offset h tree (L.fromList tail), x)
  -- x was the only element.
  | s == 1 = Just (Empty, x)
  -- The tail held only x, so the trie holds exactly one leaf. That leaf
  -- becomes the new tail and the trie becomes empty.
  | s == tailSize + 1 = Just (Root (s - 1) 0 0 (Leaf V.empty) (getTail tree), x)
  -- Otherwise move the trie's last leaf into the tail, remove it from the
  -- trie, and drop a redundant root level if one appears.
  | otherwise =
      let sh = bits * (h - 1)
      in Just (normalize $ Root (s - 1) (offset - tailSize) h (unsnocTree sh tree) (getTail tree), x)
  where
    -- Index of the last element that stays in the trie.
    index' = offset - tailSize - 1
    -- Keep only the children up to the path to index', at every level.
    unsnocTree sh (Internal v) =
      let subIndex = index' `shiftR` sh .&. mask
          new = V.take (subIndex + 1) v
      in Internal $ V.modify (\v -> M.modify v (unsnocTree (sh - bits)) subIndex) new
    unsnocTree _ (Leaf v) = Leaf v
    -- The rightmost leaf of the trie, converted to newest-first tail order.
    getTail (Internal v) = getTail (V.last v)
    getTail (Leaf v) = L.fromList . reverse $ toList v
    -- Collapse a single-child root by one level (only one level is checked).
    normalize (Root s offset h (Internal v) tail)
      | length v == 1 = Root s offset (h - 1) (v V.! 0) tail
    normalize v = v
-- | The last element, or 'Nothing' for an empty vector. This is O(1) because
-- the newest element sits at the head of the tail.
last :: Vector a -> Maybe a
last v = case v of
  Empty -> Nothing
  Root _ _ _ _ newestFirst -> Just (L.head newestFirst)
{-# INLINE last #-}
-- | Keep the first @n@ elements. A non-positive @n@ gives the empty vector;
-- an @n@ of at least the length returns the vector unchanged.
take :: Int -> Vector a -> Vector a
take _ Empty = Empty
take n root@(Root s offset h tree tail)
  | n <= 0 = Empty
  | n >= s = root
  -- The cut falls inside the tail: drop its (s - n) newest elements.
  | n > offset = Root n offset h tree (L.fromList $ L.drop (s - n) tail)
  -- The result fits in a single tail: copy the prefix out of the first leaf
  -- and leave the trie empty.
  | n <= tailSize = Root n 0 0 (Leaf V.empty) (getTail (bits * (h - 1)) tree)
  -- General case: the leaf containing index n - 1 becomes the new tail, and
  -- the trie is truncated to the leaves before it.
  | otherwise =
      let sh = bits * (h - 1)
      in normalize $ Root n ((n - 1) .&. complement mask) h (takeTree sh tree) (getTail sh tree)
  where
    -- Index of the last element that is kept.
    index = n - 1
    -- Index tailSize positions earlier, i.e. inside the leaf preceding index's.
    index' = index - tailSize
    -- Keep only the children up to the path to index', at every level.
    takeTree sh (Internal v) =
      let subIndex = index' `shiftR` sh .&. mask
          new = V.take (subIndex + 1) v
      in Internal $ V.modify (\v -> M.modify v (takeTree (sh - bits)) subIndex) new
    takeTree _ (Leaf v) = Leaf v
    -- Walk to index's leaf and return its elements up to and including index,
    -- newest-first.
    getTail sh (Internal v) = getTail (sh - bits) (v V.! (index `shiftR` sh .&. mask))
    getTail _ (Leaf v) = L.fromList . reverse . P.take (index .&. mask + 1) $ toList v
    -- Repeatedly collapse a single-child root, reducing the height each time.
    normalize (Root s offset h (Internal v) tail)
      | length v == 1 = normalize $ Root s offset (h - 1) (v V.! 0) tail
    normalize v = v
-- | The element at index @i@, or 'Nothing' when @i@ is out of range.
-- Indices at or above the offset are read from the newest-first tail;
-- lower indices are found by walking the trie one 'bits'-wide digit per
-- level.
lookup :: Int -> Vector a -> Maybe a
lookup i v = case v of
  Empty -> Nothing
  Root s offset h tree tl
    | i < 0 || i >= s -> Nothing
    | i >= offset -> Just (tl !! (s - i - 1))
    | otherwise -> Just (descend (bits * (h - 1)) tree)
  where
    descend sh node = case node of
      Leaf elems -> elems V.! (i .&. mask)
      Internal children -> descend (sh - bits) (children V.! ((i `shiftR` sh) .&. mask))
-- | The element at index @i@. Calls 'error' when @i@ is out of range.
index :: Int -> Vector a -> a
index i v = fromMaybe (error "AMT.index: index out of range") (lookup i v)

-- | Flipped 'lookup'.
(!?) :: Vector a -> Int -> Maybe a
v !? i = lookup i v
{-# INLINE (!?) #-}

-- | Flipped 'index'.
(!) :: Vector a -> Int -> a
(!) v i = index i v
{-# INLINE (!) #-}
-- | Replace the element at the given index. Out-of-range indices leave the
-- vector unchanged (see 'adjust').
update :: Int -> a -> Vector a -> Vector a
update i x = adjust i (\_ -> x)
{-# INLINE update #-}
-- | Apply a function to the element at the given index. Out-of-range indices
-- leave the vector unchanged.
--
-- Trie elements are updated by copying the nodes along the path to the leaf;
-- tail elements are updated by splitting the tail list.
adjust :: Int -> (a -> a) -> Vector a -> Vector a
adjust _ _ Empty = Empty
adjust i f root@(Root s offset h tree tail)
  | i < 0 || i >= s = root
  | i < offset = Root s offset h (adjustTree (bits * (h - 1)) tree) tail
  | otherwise =
      -- The tail is newest-first, so element i sits at position s - i - 1.
      case L.splitAt (s - i - 1) tail of
        (newer, x : older) -> Root s offset h tree (L.fromList (newer ++ f x : older))
        -- Unreachable: offset <= i < s puts the position inside the tail.
        -- Handled explicitly so the match is total (no incomplete pattern).
        (_, []) -> root
  where
    adjustTree sh (Internal children) =
      let slot = i `shiftR` sh .&. mask
      in Internal $ V.modify (\mv -> M.modify mv (adjustTree (sh - bits)) slot) children
    adjustTree _ (Leaf elems) =
      let slot = i .&. mask
      in Leaf $ V.modify (\mv -> M.modify mv f slot) elems
-- | Concatenate two vectors by appending the elements of the second onto
-- the first. This is O(length of the second) unless either vector is empty.
(><) :: Vector a -> Vector a -> Vector a
xs >< ys = case (xs, ys) of
  (Empty, _) -> ys
  (_, Empty) -> xs
  _ -> foldl' (|>) xs ys
{-# INLINE (><) #-}
-- | Apply a function to every element, keeping the size, offset, height and
-- trie shape unchanged.
map :: (a -> b) -> Vector a -> Vector b
map _ Empty = Empty
map f (Root s offset h tree tl) = Root s offset h (go tree) (f <$> tl)
  where
    go node = case node of
      Internal children -> Internal (go <$> children)
      Leaf elems -> Leaf (f <$> elems)
-- | Like 'map', but the function also receives each element's 0-based index.
mapWithIndex :: (Int -> a -> b) -> Vector a -> Vector b
mapWithIndex f v = snd (mapAccumL step 0 v)
  where
    step i x = i `seq` (i + 1, f i x)
-- | Map each element together with its index into a monoid and combine the
-- results from left to right.
foldMapWithIndex :: Monoid m => (Int -> a -> m) -> Vector a -> m
foldMapWithIndex f = foldrWithIndex (\i x rest -> f i x `mappend` rest) mempty
-- | Left fold that also passes each element's index. The underlying 'foldl'
-- builds a function of the current index, starting from the last index.
foldlWithIndex :: (b -> Int -> a -> b) -> b -> Vector a -> b
foldlWithIndex f z xs = foldl step (const z) xs (length xs - 1)
  where
    step k x i = i `seq` f (k (i - 1)) i x

-- | Right fold that also passes each element's index. The underlying 'foldr'
-- builds a function of the current index, starting from index 0.
foldrWithIndex :: (Int -> a -> b -> b) -> b -> Vector a -> b
foldrWithIndex f z xs = foldr step (const z) xs 0
  where
    step x k i = i `seq` f i x (k (i + 1))
-- | Strict left fold with indices. It is expressed as a right fold that
-- threads a continuation, and it forces the accumulator at each step.
foldlWithIndex' :: (b -> Int -> a -> b) -> b -> Vector a -> b
foldlWithIndex' f z xs = foldrWithIndex step id xs z
  where
    step i x k acc = k $! f acc i x
{-# INLINE foldlWithIndex' #-}

-- | Strict right fold with indices. It is expressed as a left fold that
-- threads a continuation, and it forces the accumulator at each step.
foldrWithIndex' :: (Int -> a -> b -> b) -> b -> Vector a -> b
foldrWithIndex' f z xs = foldlWithIndex step id xs z
  where
    step k i x acc = k $! f i x acc
{-# INLINE foldrWithIndex' #-}
-- | Like 'traverse', but the function also receives each element's 0-based
-- index. A 'State' counter layered under the user's applicative (via
-- 'Compose') supplies the indices in traversal order.
traverseWithIndex :: Applicative f => (Int -> a -> f b) -> Vector a -> f (Vector b)
traverseWithIndex f v = evalState (getCompose (traverse numbered v)) 0
  where
    numbered x = Compose (state (\i -> i `seq` (f i x, i + 1)))
-- | Pair every element with its 0-based index.
indexed :: Vector a -> Vector (Int, a)
indexed v = mapWithIndex (\i x -> (i, x)) v
{-# INLINE indexed #-}
-- | Pair up corresponding elements, truncating to the shorter vector.
zip :: Vector a -> Vector b -> Vector (a, b)
zip xs ys = zipWith (,) xs ys
{-# INLINE zip #-}
-- | Combine corresponding elements, truncating to the shorter vector.
-- The shorter vector is traversed (so the result has its shape), while the
-- longer one is consumed as a list in step.
zipWith :: (a -> b -> c) -> Vector a -> Vector b -> Vector c
zipWith f v1 v2
  | length v2 <= length v1 = snd (mapAccumL pairWithNext (toList v1) v2)
  | otherwise = snd (mapAccumL pairWithPrev (toList v2) v1)
  where
    -- Traverse v2; pull the matching element of v1 from the list.
    pairWithNext (x : xs) y = (xs, f x y)
    pairWithNext [] _ = error "unreachable"
    -- Traverse v1; pull the matching element of v2 from the list.
    pairWithPrev (y : ys) x = (ys, f x y)
    pairWithPrev [] _ = error "unreachable"
-- | Triple up corresponding elements, truncating to the shortest vector.
zip3 :: Vector a -> Vector b -> Vector c -> Vector (a, b, c)
zip3 xs ys zs = zipWith3 (,,) xs ys zs
{-# INLINE zip3 #-}
-- | Combine corresponding elements of three vectors, truncating to the
-- shortest. It zips the first two into partial applications and then
-- applies them to the third.
zipWith3 :: (a -> b -> c -> d) -> Vector a -> Vector b -> Vector c -> Vector d
zipWith3 f v1 v2 v3 = zipWith (\g z -> g z) (zipWith f v1 v2) v3
-- | Split a vector of pairs into a vector of first and a vector of second
-- components.
unzip :: Vector (a, b) -> (Vector a, Vector b)
unzip v = (fst <$> v, snd <$> v)
{-# INLINE unzip #-}
-- | Split a vector of triples into three vectors of components.
unzip3 :: Vector (a, b, c) -> (Vector a, Vector b, Vector c)
unzip3 v =
  ( map (\(x, _, _) -> x) v
  , map (\(_, y, _) -> y) v
  , map (\(_, _, z) -> z) v
  )
{-# INLINE unzip3 #-}
-- | The elements, in order, each paired with its 0-based index.
toIndexedList :: Vector a -> [(Int, a)]
toIndexedList = foldrWithIndex (\i x rest -> (i, x) : rest) []
{-# INLINE toIndexedList #-}