module Data.Set.BKTree
(
BKTree
,Metric(..)
,null,empty
,fromList,singleton
,insert
,member,memberDistance
,delete
,union,unions
,elems,elemsDistance
,closest
#ifdef DEBUG
,runTests
#endif
)where
import qualified Data.IntMap as M
import qualified Data.List as L hiding (null)
import Prelude hiding (null)
#ifdef DEBUG
import qualified Prelude
import Test.QuickCheck
import Text.Printf
#endif
-- | A Burkhard-Keller tree: a multiset of elements organised by a discrete
-- 'Metric'.
--
-- Invariant (established by 'insert' and 'fromList'): every element stored in
-- the subtree found under key @k@ of a node's map is at 'distance' @k@ from
-- that node's element. Elements at distance 0 from a node (e.g. duplicates)
-- therefore live under that node's key 0.
data BKTree a = Node a (M.IntMap (BKTree a))
                -- ^ An element and its children, keyed by distance from it.
              | Empty
                -- ^ The tree with no elements.
#ifdef DEBUG
  deriving Show
#endif
-- | Types equipped with an integer-valued distance function.
--
-- Instances are expected to satisfy the metric laws:
--
-- * @distance x y >= 0@
-- * @distance x y == 0@ exactly when @x == y@
-- * @distance x y == distance y x@
-- * @distance x z <= distance x y + distance y z@
--
-- 'member' treats distance 0 as a match, and 'memberDistance',
-- 'elemsDistance' and 'closest' skip subtrees using the triangle inequality,
-- so an instance that breaks these laws can make lookups miss elements.
class Eq a => Metric a where
  distance :: a -> a -> Int

-- | Absolute difference. When the true difference does not fit in an 'Int'
-- (e.g. @distance maxBound minBound@) the result saturates at 'maxBound'
-- instead of wrapping to a negative number.
instance Metric Int where
  distance i j
    | diff < 0 = maxBound -- the subtraction wrapped around
    | otherwise = diff
    where
      -- Subtracting the smaller from the larger is non-negative unless it
      -- overflows, which makes the overflow check above sufficient.
      diff = max i j - min i j

-- | Absolute difference, saturating at @maxBound :: Int@ rather than letting
-- 'fromInteger' wrap large differences into arbitrary (possibly negative)
-- values. A distance capped at a constant is still a metric.
instance Metric Integer where
  distance i j = fromInteger (min (toInteger (maxBound :: Int)) (abs (i - j)))

-- | Absolute difference of code points.
instance Metric Char where
  distance i j = abs (fromEnum i - fromEnum j)
-- | 'True' for the empty tree, 'False' for any tree that holds an element.
null :: BKTree a -> Bool
null tree = case tree of
  Empty   -> True
  Node {} -> False
-- | The tree with no elements.
empty :: BKTree a
empty = Empty

-- | A tree holding exactly one element and no children.
singleton :: a -> BKTree a
singleton = (`Node` M.empty)
-- | Add an element to the tree. Nothing is deduplicated: the element is
-- placed under the child key equal to its distance from each node on the way
-- down, so an element at distance 0 from a node goes under that node's key 0.
insert :: Metric a => a -> BKTree a -> BKTree a
insert new tree = case tree of
  Empty -> Node new M.empty
  Node pivot children ->
    let key   = distance new pivot
        -- Start a fresh leaf if the slot is free, otherwise descend into it.
        place = Just . maybe (Node new M.empty) (insert new)
    in Node pivot (M.alter place key children)
-- | Test whether the tree holds an element at distance 0 from the query.
-- Only the single child whose key equals the query's distance from the
-- current node is searched.
member :: Metric a => a -> BKTree a -> Bool
member _ Empty = False
member x (Node pivot children) =
  dist == 0 || maybe False (member x) (M.lookup dist children)
  where
    dist = distance x pivot
-- | Test whether some element of the tree lies within distance @n@
-- (inclusive) of the query.
memberDistance :: Metric a => Int -> a -> BKTree a -> Bool
memberDistance _ _ Empty = False
memberDistance n a (Node b children)
  | d <= n = True
  | otherwise = any (memberDistance n a) (M.elems nearby)
  where
    d = distance a b
    -- By the triangle inequality, an element y stored under key k satisfies
    -- distance a y >= |d - k|, so only keys in [d - n, d + n] can match.
    -- 'M.split' excludes its pivot, hence the -1/+1.
    (_, aboveLo) = M.split (d - n - 1) children
    nearby
      | n > maxBound - d - 1 = aboveLo -- d + n + 1 would overflow Int
      | otherwise = fst (M.split (d + n + 1) aboveLo)
-- | Remove one element at distance 0 from the query, if the tree holds one.
--
-- The search follows the same single path as 'member'. When a node's element
-- is the match, that node is replaced by the 'unions' of its children. Any
-- further copies of the element sit under its key 0, so they survive as
-- ordinary members of the rebuilt subtree.
delete :: Metric a => a -> BKTree a -> BKTree a
delete _ Empty = Empty
delete a (Node b children)
  | d == 0 = unions (M.elems children)
  | otherwise = Node b (M.update (nonEmpty . delete a) d children)
  where
    d = distance a b
    -- Drop a child entry once its subtree has been emptied, instead of
    -- leaving an 'Empty' placeholder in the map.
    nonEmpty Empty = Nothing
    nonEmpty t = Just t
-- | All elements of the tree in pre-order. A node's element comes first,
-- followed by the contents of its children in ascending key order.
elems :: BKTree a -> [a]
elems tree = go tree []
  where
    -- go t rest == elements of t, then rest (avoids repeated (++)).
    go Empty rest = rest
    go (Node x children) rest = x : foldr go rest (M.elems children)
-- | All elements within distance @n@ (inclusive) of the query, in the same
-- pre-order as 'elems'.
elemsDistance :: Metric a => Int -> a -> BKTree a -> [a]
elemsDistance _ _ Empty = []
elemsDistance n a (Node b children) =
  (if d <= n then (b :) else id)
    (concatMap (elemsDistance n a) (M.elems nearby))
  where
    d = distance a b
    -- By the triangle inequality, an element y stored under key k satisfies
    -- distance a y >= |d - k|, so only keys in [d - n, d + n] can match.
    -- 'M.split' excludes its pivot, hence the -1/+1.
    (_, aboveLo) = M.split (d - n - 1) children
    nearby
      | n > maxBound - d - 1 = aboveLo -- d + n + 1 would overflow Int
      | otherwise = fst (M.split (d + n + 1) aboveLo)
-- | Build a tree from a list. The first element becomes the root. The
-- remaining elements are bucketed by their distance to it, and each bucket is
-- built recursively with its elements kept in input order.
fromList :: Metric a => [a] -> BKTree a
fromList [] = Empty
fromList (root : rest) = Node root (M.map (fromList . reverse) buckets)
  where
    -- fromListWith (++) prepends each newcomer, so a bucket accumulates in
    -- reverse input order; it is flipped back above.
    buckets = M.fromListWith (++) [ (distance root x, [x]) | x <- rest ]
-- | Merge a list of trees. The root of the first non-empty tree stays the
-- root. Its existing children and every other non-empty tree are grouped by
-- key: an existing child keeps its key, and another tree is keyed by its
-- root's distance from the pivot. Each group is merged recursively.
unions :: Metric a => [BKTree a] -> BKTree a
unions [] = Empty
unions (Empty : rest) = unions rest
unions (Node pivot children : rest) = Node pivot (M.map (unions . reverse) buckets)
  where
    -- Existing children first, then the other trees in list order; empty
    -- trees are skipped by the pattern in the comprehension.
    tagged = M.toList children
          ++ [ (distance pivot x, t) | t@(Node x _) <- rest ]
    -- Accumulated in reverse (prepending), restored by 'reverse' above.
    buckets = M.fromListWith (++) [ (k, [t]) | (k, t) <- tagged ]

-- | Merge two trees; the root of the first non-empty one stays the root.
union :: Metric a => BKTree a -> BKTree a -> BKTree a
union left right = unions (left : [right])
-- | Find an element nearest to the query, together with its distance.
-- Returns 'Nothing' only for the empty tree. Among equally close elements,
-- the one reached first by the depth-first search (children in ascending key
-- order) is returned.
closest :: Metric a => a -> BKTree a -> Maybe (a, Int)
closest _ Empty = Nothing
closest a tree@(Node b _) = Just (closeLoop a (b, distance a b) tree)

-- | Branch-and-bound step for 'closest'. It first improves the best
-- (element, distance) pair with the current node, then descends only into
-- children that could hold something at least as close. The improved best is
-- threaded through the children left to right.
closeLoop :: Metric a => a -> (a, Int) -> BKTree a -> (a, Int)
closeLoop _ best Empty = best
closeLoop a best@(_, bestDist) (Node x children) =
  L.foldl' (closeLoop a) best' (M.elems nearby)
  where
    j = distance a x
    -- Strict '<' keeps the earlier element on ties.
    best' = if j < bestDist then (x, j) else best
    r = snd best'
    -- By the triangle inequality, an element y under key k (so distance x y
    -- == k) satisfies distance a y >= |j - k|. Only keys in [j - r, j + r]
    -- can therefore beat or match the current best r.
    -- 'M.split' excludes its pivot, hence the -1/+1.
    (_, aboveLo) = M.split (j - r - 1) children
    nearby
      | r > maxBound - j - 1 = aboveLo -- j + r + 1 would overflow Int
      | otherwise = fst (M.split (j + r + 1) aboveLo)
-- | Apply a binary function after projecting both arguments through @f@
-- (a local copy of the combinator from "Data.Function").
on :: (b -> b -> c) -> (a -> b) -> a -> a -> c
(rel `on` f) x y = rel (f x) (f y)
#ifdef DEBUG
-- | Sorted contents of a tree, a canonical form for comparing with lists.
sem :: Ord a => BKTree a -> [a]
sem = L.sort . elems

-- | Build a tree from a list, apply a tree transformation, and return the
-- sorted contents of the result.
trans :: (Metric a, Ord a) => (BKTree a -> BKTree a) -> [a] -> [a]
trans f = sem . f . fromList
-- | Nothing is a member of the empty tree.
prop_empty :: Int -> Bool
prop_empty n = not (member n empty)

-- | 'null' agrees with emptiness of the source list.
prop_null :: [Int] -> Bool
prop_null xs = null (fromList xs) == Prelude.null xs

-- | 'singleton' holds exactly its element. The original property only
-- exercised 'fromList'; both constructions are checked now.
prop_singleton :: Int -> Bool
prop_singleton n = elems (singleton n) == [n] && elems (fromList [n]) == [n]

-- | Inserting adds exactly one copy of the element.
prop_insert :: Int -> [Int] -> Bool
prop_insert n xs = trans (insert n) xs == L.sort (n : xs)

-- | 'member' agrees with list membership.
prop_member :: Int -> [Int] -> Bool
prop_member n xs = member n (fromList xs) == L.elem n xs
-- | 'memberDistance' agrees with a linear scan, using small radii (0..4).
-- The distribution of positive/negative cases is reported via 'collect'.
prop_memberDistance :: Int -> Int -> [Int] -> Property
prop_memberDistance dist n xs =
  collect expected (memberDistance radius n (fromList xs) == expected)
  where
    radius = dist `mod` 5
    expected = L.any (\e -> distance n e <= radius) xs
-- | Deleting removes exactly one occurrence (the first, as 'L.delete' does),
-- or nothing if the element is absent.
prop_delete :: Int -> [Int] -> Bool
prop_delete n xs = trans (delete n) xs == L.sort (L.delete n xs)
-- | A tree built from a list holds exactly that list's elements.
prop_elems :: [Int] -> Bool
prop_elems xs = sem (fromList xs) == L.sort xs

-- | 'elemsDistance' agrees with filtering the list, using small radii (0..4).
prop_elemsDistance :: Int -> Int -> [Int] -> Bool
prop_elemsDistance dist n xs =
  L.sort (elemsDistance radius n (fromList xs)) == L.sort (filter near xs)
  where
    radius = dist `mod` 5
    near e = distance n e <= radius

-- | 'unions' keeps every element of every input tree.
prop_unions :: [[Int]] -> Bool
prop_unions xss = sem (unions (map fromList xss)) == L.sort (concat xss)

-- | 'union' keeps every element of both input trees.
prop_union :: [Int] -> [Int] -> Bool
prop_union xs ys = sem (fromList xs `union` fromList ys) == L.sort (xs ++ ys)
-- | 'closest' returns 'Nothing' exactly for an empty input. Otherwise it
-- returns an element of the input, reports that element's true distance, and
-- no input element is strictly closer. This avoids the partial 'minimum' and
-- also checks the returned element, not just the distance.
prop_closest :: Int -> [Int] -> Bool
prop_closest n xs =
  case closest n (fromList xs) of
    Nothing -> Prelude.null xs
    Just (e, d) ->
      e `L.elem` xs
        && distance n e == d
        && all (\x -> distance n x >= d) xs
-- | Inserting and then deleting the same element restores the original
-- contents.
prop_insertDelete :: Int -> [Int] -> Bool
prop_insertDelete n xs = sem (delete n (insert n (fromList xs))) == L.sort xs
-- | Every property, paired with the label printed before its QuickCheck run.
tests :: [(String, IO ())]
tests =
  [ ("empty", quickCheck prop_empty)
  , ("null", quickCheck prop_null)
  , ("singleton", quickCheck prop_singleton)
  , ("insert", quickCheck prop_insert)
  , ("member", quickCheck prop_member)
  , ("memberDistance", quickCheck prop_memberDistance)
  , ("delete", quickCheck prop_delete)
  , ("elems", quickCheck prop_elems)
  , ("elemsDistance", quickCheck prop_elemsDistance)
  , ("unions", quickCheck prop_unions)
  , ("union", quickCheck prop_union)
  , ("closest", quickCheck prop_closest)
  , ("insert/delete", quickCheck prop_insertDelete)
  ]

-- | Run each entry of 'tests' in order. The label is printed left-aligned in
-- a 25-column field, followed by QuickCheck's report.
runTests :: IO ()
runTests = mapM_ runOne tests
  where
    runOne (label, check) = printf "%-25s :" label >> check
#endif