{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE UnliftedDatatypes #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-|
   This module defines 'IntToIntMap', a variant of "Data.IntMap" in which both
   the keys and the values are unboxed integers ('Int#').

   We make use of this structure in "Data.Equality.Graph.ReprUnionFind" to
   improve performance by a constant factor.
 -}
module Data.Equality.Utils.IntToIntMap
  ( IntToIntMap(Nil)
  , Key, Val
  , find, insert, (!)
  , unliftedFoldr
  ) where

import GHC.Exts
import Data.Bits

-- | A map from unboxed integer keys ('Key') to unboxed integer values ('Val').
--
-- This is an unlifted datatype (its kind is @TYPE ('BoxedRep 'Unlifted)@), so a
-- value of this type is always evaluated and can never be a thunk.
type IntToIntMap :: TYPE ('BoxedRep 'Unlifted)
data IntToIntMap
  = Bin Prefix Mask IntToIntMap IntToIntMap
    -- ^ Internal node: the common prefix of the keys below it, the branching
    -- bit, and the subtrees for keys with that bit clear (left) and set (right).
  | Tip InternalKey Val
    -- ^ A single key/value binding.
  | Nil
    -- ^ An empty 'IntToIntMap'. Ideally this would be provided as a function
    -- rather than an exported constructor, but top-level bindings of unlifted
    -- datatypes are currently not possible.

-- Internal aliases naming the role each 'Word#' plays inside the tree.
type InternalKey = Word#  -- a 'Key' reinterpreted as a word (via 'int2Word#')
type Prefix      = Word#  -- key bits shared by every key below a 'Bin' node
type Mask        = Word#  -- the single bit on which a 'Bin' node branches

-- | Keys of an 'IntToIntMap' (unboxed 'Int's).
type Key = Int#

-- | Values stored in an 'IntToIntMap' (unboxed 'Int's).
type Val = Int#

-- | \(O(\min(n,W))\). Look up the value stored at a key; operator form of
-- 'find' with the map as the first argument.
--
-- Calls 'error' when the key is not present in the map.
(!) :: IntToIntMap -> Key -> Val
tree ! key = find key tree
{-# INLINE (!) #-}

-- | Look up the 'Val' stored at a 'Key' in an 'IntToIntMap'.
--
-- The key is reinterpreted as a 'Word#' before searching. Calls 'error' when
-- the key is not present in the map.
find :: Key -> IntToIntMap -> Val
find key = find' (int2Word# key)
{-# INLINE find #-}

-- | Insert a 'Val' at a 'Key' in an 'IntToIntMap', replacing any value
-- previously stored at that key.
insert :: Key -> Val -> IntToIntMap -> IntToIntMap
insert (int2Word# -> wordKey) = insert' wordKey
{-# INLINE insert #-}

-- | Worker for 'insert' operating on the word representation of the key.
--
-- * At a 'Bin' whose prefix does not cover the key, a new 'Tip' is linked
--   next to the whole node.
-- * Otherwise the insertion descends left or right depending on the node's
--   branching bit.
-- * At a 'Tip' with the same key, the value is overwritten; at a different
--   key, the two tips are linked under a new 'Bin'.
-- * Inserting into 'Nil' yields a singleton 'Tip'.
insert' :: InternalKey -> Val -> IntToIntMap -> IntToIntMap
insert' key val tree = case tree of
  Bin prefix mask left right
    | nomatch key prefix mask -> link key (Tip key val) prefix tree
    | zero key mask           -> Bin prefix mask (insert' key val left) right
    | otherwise               -> Bin prefix mask left (insert' key val right)
  Tip tipKey _
    | isTrue# (key `eqWord#` tipKey) -> Tip tipKey val
    | otherwise                      -> link key (Tip key val) tipKey tree
  Nil -> Tip key val

-- DANGEROUS NOTE:
-- At the time of writing, this function took about 10% of runtime, so we
-- considered improving constant factors by removing the comparison that checks
-- that the tip we reach is the tip we are looking for. This is a very custom
-- map, so the idea was to assume the tip we find is ALWAYS the one we are
-- looking for. Of course, that would make unexpected uses silently return wrong
-- results instead of blowing up; the hope was that the testsuite would warn us.
--
-- Update: the speedup was not noticeable, so the comparison is kept. This note
-- is left here for reference.
-- | Worker for 'find' operating on the word representation of the key.
--
-- Descends by the branching bit of each 'Bin' without checking the node's
-- prefix. The key equality is checked only at the 'Tip' that is reached. A
-- missing key (including lookups in 'Nil') ends in 'error'. The message shows
-- the key converted back to the 'Int' the caller passed in.
find' :: InternalKey -> IntToIntMap -> Val
find' k (Bin _p m l r)
  | zero k m  = find' k l
  | otherwise = find' k r
find' k (Tip kx x)
  | isTrue# (k `eqWord#` kx) = x
find' k _ =
  error ("IntToIntMap.find: key " ++ show (I# (word2Int# k))
           ++ " is not an element of the map")

-- * Other stuff taken from IntMap

-- | Join two trees whose keys are identified by the prefixes @p1@ and @p2@
-- under a new 'Bin' node. The node branches on the highest bit in which the
-- two prefixes differ.
--
-- The prefixes are expected to differ: with equal prefixes the xor is zero,
-- which 'highestBitMask' does not handle.
link :: Prefix -> IntToIntMap -> Prefix -> IntToIntMap -> IntToIntMap
link p1 t1 p2 t2 =
  let branchBit = highestBitMask (xor# p1 p2)
  in linkWithMask branchBit p1 t1 t2
{-# INLINE link #-}

-- | Variant of 'link' for when the branching bit has already been computed.
--
-- Builds a 'Bin' branching on @mask@, with the prefix taken from @p1@. @t1@
-- (the tree @p1@ belongs to) is placed on the left when @p1@ has the @mask@
-- bit clear, and on the right otherwise.
linkWithMask :: Mask -> Prefix -> IntToIntMap -> IntToIntMap -> IntToIntMap
linkWithMask mask p1 t1 t2 =
  let prefix = maskW p1 mask
  in if zero p1 mask
       then Bin prefix mask t1 t2
       else Bin prefix mask t2 t1
{-# INLINE linkWithMask #-}


-- The highestBitMask implementation is based on
-- http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
-- which has been put in the public domain.

-- | Return a word in which only the highest set bit of the argument is set.
--
-- The argument must be non-zero. For @0##@, 'clz#' returns the word size and
-- the computed shift amount would be @-1@.
highestBitMask :: Word# -> Word#
highestBitMask w = case finiteBitSize (0 :: Word) of
  I# bits ->
    let topBitIndex = bits -# 1# -# word2Int# (clz# w)
    in shiftL# 1## topBitIndex
{-# INLINE highestBitMask #-}

-- | 'True' when @key@, masked down to the bits above the branching bit @mask@,
-- differs from @prefix@, i.e. the key cannot live under that 'Bin' node.
nomatch :: InternalKey -> Prefix -> Mask -> Bool
nomatch key prefix mask = isTrue# (neWord# (maskW key mask) prefix)
{-# INLINE nomatch #-}

-- | The prefix of key @i@ up to (but not including) the switching bit @m@.
--
-- @(0 - m) xor m@ keeps exactly the bits strictly above a single-bit @m@. The
-- subtraction wraps, which matches two's-complement negation bit for bit.
maskW :: Word# -> Word# -> Prefix
maskW i m =
  let bitsAbove = (0## `minusWord#` m) `xor#` m
  in i `and#` bitsAbove
{-# INLINE maskW #-}

-- | 'True' when key @i@ has none of the bits of @m@ set; for a branching bit,
-- this means the key belongs to the left subtree.
zero :: InternalKey -> Mask -> Bool
zero i m = case i `and#` m of
  0## -> True
  _   -> False
{-# INLINE zero #-}

-- | A right fold over a list whose accumulator has an unlifted type (such as
-- 'IntToIntMap').
--
-- Because @b@ is unlifted, the result for the rest of the list is evaluated
-- before @step@ is applied. The whole spine is therefore traversed, unlike the
-- lazy 'foldr'.
unliftedFoldr :: forall a {b :: TYPE ('BoxedRep 'Unlifted)} . (a -> b -> b) -> b -> [a] -> b
unliftedFoldr step initial = loop
  where
    loop xs = case xs of
      []         -> initial
      (x : rest) -> step x (loop rest)