module Utility.HashJoin (
    nubById
  , nubByIdSinglePass
  , hashClusterIdNub
  , clusterByHash
  , hashJoin
  ) where

import Control.Monad ( forM_, void )
import Control.Monad.ST ( ST, runST )
import Data.Foldable ( foldrM )

import qualified Data.HashTable.ST.Cuckoo as HT

-------------------------------------
--- Hash join / clustering / nub
-------------------------------------


-- | Remove duplicate elements from a list, identifying each element by the
-- 'Int' key that @h@ assigns to it.
--
-- PRECONDITION: @(h x == h y) => x == y@. Elements sharing a key are
-- treated as interchangeable, so whichever one 'HT.insert' leaves in the
-- table is an acceptable representative.
--
-- The result order is whatever order 'HT.foldM' traverses the table in;
-- it is not related to the input order. A singleton input is returned
-- unchanged without building a table.
nubById :: (a -> Int) -> [a] -> [a]
nubById _ [x] = [x]
nubById h ls = runST $ do
    -- 101 is the initial size hint handed to the table.
    table <- HT.newSized 101
    forM_ ls $ \x -> HT.insert table (h x) x
    HT.foldM keepValue [] table
  where
    -- Collect only the stored element, discarding its key.
    keepValue acc (_, v) = pure (v : acc)

-- | Remove duplicate elements in a single traversal of the input,
-- identifying each element by the 'Int' key that @h@ assigns to it.
--
-- The first element encountered for each key is kept. Survivors are
-- accumulated by consing, so they come out in the reverse of their
-- first-occurrence order. A singleton input is returned unchanged.
nubByIdSinglePass :: forall a. (a -> Int) -> [a] -> [a]
nubByIdSinglePass _ [x] = [x]
nubByIdSinglePass h ls = runST (HT.new >>= loop ls [])
  where
    -- Walk the input, keeping each element whose key is not yet in @seen@.
    loop :: [a] -> [a] -> HT.HashTable s Int Bool -> ST s [a]
    loop []       kept _    = pure kept
    loop (y : ys) kept seen = do
      wasSeen <- HT.mutate seen (h y) markSeen
      if wasSeen
        then loop ys kept seen
        else loop ys (y : kept) seen

    -- Always (re)record the key; report whether it was already present.
    markSeen :: Maybe Bool -> (Maybe Bool, Bool)
    markSeen Nothing  = (Just True, False)
    markSeen (Just _) = (Just True, True)


-- | Update function for 'HT.mutate' that prepends @v@ to the list stored
-- under a key. When the key is absent, it stores the singleton @[v]@.
-- It always returns 'Just', so the entry is never removed, and it
-- produces no extra result (@()@).
maybeAddToHt :: v -> Maybe [v] -> (Maybe [v], ())
maybeAddToHt v existing =
  case existing of
    Nothing   -> (Just [v], ())
    Just rest -> (Just (v : rest), ())

-- | Deduplicate a list by the key from @hNub@ and cluster the survivors by
-- the key from @hCluster@, in one traversal of the input.
--
-- An element is skipped when an earlier element had the same @hNub@ key.
-- Every other element is prepended to the cluster for its @hCluster@ key,
-- so each cluster lists its members in the reverse of their input order.
-- The order of the clusters themselves comes from 'HT.foldM' over the
-- cluster table. A singleton input yields a single singleton cluster.
--
-- NOTE: in testing, this was observed to be slower than running
-- 'clusterByHash' and 'nubByIdSinglePass' separately. The cause has not
-- been established.
hashClusterIdNub :: (a -> Int) -> (a -> Int) -> [a] -> [[a]]
hashClusterIdNub _ _ [x] = [[x]]
hashClusterIdNub hCluster hNub ls = runST $ do
    clusters <- HT.new
    seen     <- HT.new
    let -- Always (re)record the nub key; report whether it was already there.
        markSeen Nothing  = (Just True, False)
        markSeen (Just _) = (Just True, True)
        -- Add @x@ to its cluster only the first time its nub key shows up.
        visit x = do
          duplicate <- HT.mutate seen (hNub x) markSeen
          if duplicate
            then pure ()
            else HT.mutate clusters (hCluster x) (maybeAddToHt x)
    mapM_ visit ls
    HT.foldM (\acc (_, members) -> pure (members : acc)) [] clusters

-- | Group the elements of a list into clusters of elements that share the
-- same 'Int' key under @h@.
--
-- Each element is prepended to its cluster, so members appear in the
-- reverse of their input order. The order of the clusters comes from
-- 'HT.foldM' over the table. An empty input yields no clusters.
clusterByHash :: (a -> Int) -> [a] -> [[a]]
clusterByHash h ls = runST $ do
    table <- HT.new
    let addToCluster x = HT.mutate table (h x) (maybeAddToHt x)
    forM_ ls addToCluster
    HT.foldM (\acc (_, cluster) -> pure (cluster : acc)) [] table

-- | Equi-join two lists on the 'Int' key computed by @h@.
--
-- For every @x@ in @l1@ and every @y@ in @l2@ with @h x == h y@, the
-- result contains @j x y@. @l2@ is loaded into a hash table and @l1@
-- probes it. Results follow the order of @l1@. For a single @x@, its
-- matches appear in the reverse of their order in @l2@, because
-- 'maybeAddToHt' prepends. Elements of @l1@ with no match contribute
-- nothing.
hashJoin :: (a -> Int) -> (a -> a -> b) -> [a] -> [a] -> [b]
hashJoin h j l1 l2 = runST $ do
    byKey <- HT.new
    forM_ l2 $ \y -> HT.mutate byKey (h y) (maybeAddToHt y)
    let -- Prepend the joined rows for @x@ (if any) onto the rows built so far.
        probe x rest = do
          hit <- HT.lookup byKey (h x)
          case hit of
            Nothing      -> pure rest
            Just matches -> pure (map (j x) matches ++ rest)
    foldrM probe [] l1