-- Export list of the module, grouped by purpose.
module Data.Set.BKTree
(
-- The tree type.
BKTree
-- The class of element types, with its distance function.
,Metric(..)
-- Emptiness test, element count, and the empty tree.
,null,size,empty
-- Construction.
,fromList,singleton
,insert
-- Membership: exact, and within a given distance.
,member,memberDistance
,delete
,union,unions
-- Listing elements: all of them, or those within a given distance.
,elems,elemsDistance
-- Nearest-element query.
,closest
#ifdef DEBUG
-- Test driver, exported only when DEBUG is defined.
,runTests
#endif
)where
import qualified Data.IntMap as M
import qualified Data.List as L hiding (null)
import Prelude hiding (null)
import Data.Array.IArray (Array,array,listArray,(!),assocs)
import Data.Array.Unboxed (UArray)
#ifdef DEBUG
import qualified Prelude
import Test.QuickCheck
import Text.Printf
import System.Exit
#endif
-- | Types with an integer-valued distance function.
--
-- The tree operations in this module locate elements through @distance@:
-- @member@ treats a distance of zero as "found", and the range queries
-- prune subtrees by comparing distances.
-- NOTE(review): that pruning presumably relies on @distance@ obeying the
-- metric laws (zero exactly for equal values, symmetry, triangle
-- inequality); nothing here checks them, so confirm for new instances.
class Eq a => Metric a where
  -- | Distance between two values.
  distance :: a -> a -> Int
-- | Machine integers are compared by the absolute value of their difference.
instance Metric Int where
  distance i j = abs (i - j)
-- | Unbounded integers are compared by the absolute value of their
-- difference; the result is narrowed to @Int@ with @fromInteger@.
instance Metric Integer where
  distance i j = fromInteger (abs (i - j))
-- | Characters are compared by the absolute difference of their
-- @fromEnum@ values.
instance Metric Char where
  distance i j = abs (fromEnum i - fromEnum j)
-- | Edit distance between two lists, counting one for every insertion,
-- deletion, or substitution of a single element.
--
-- The table of prefix distances is computed one row per element of the
-- first list; only the previous row is kept, as an unboxed array indexed
-- by positions @1..length ys@ of the second list.
hirschberg :: Eq a => [a] -> [a] -> Int
-- An empty second list would give the row array an empty index range, so
-- it is answered directly.
hirschberg xs [] = length xs
hirschberg xs ys = finalRow ! lys
  where
    lys = length ys

    -- Row for the empty prefix of xs: turning nothing into the first j
    -- elements of ys takes j insertions.
    firstRow :: UArray Int Int
    firstRow = listArray (1, lys) [1 .. lys]

    finalRow :: UArray Int Int
    finalRow =
      L.foldl'
        ( \prevRow (i, xi) ->
            array (1, lys) $
              snd $
                -- The accumulator starts at column 0 of the table, where
                -- the previous row holds i - 1 and the current row holds i.
                L.mapAccumL (step xi) (i - 1, i) (zip (assocs prevRow) ys)
        )
        firstRow
        (zip [1 ..] xs)

    -- One table cell. The accumulator carries the diagonal neighbour
    -- (previous row, previous column) and the left neighbour (current row,
    -- previous column); @up@ is the cell directly above.
    step xi (diag, left) ((j, up), yj) =
      let cell =
            minimum
              [ diag + (if xi == yj then 0 else 1)
              , up + 1
              , left + 1
              ]
       in ((up, cell), (j, cell))
-- | Lists are compared by their edit distance, as computed by @hirschberg@.
instance Eq a => Metric [a] where
  distance xs ys = hirschberg xs ys
-- | A tree of elements arranged by their mutual @distance@.
--
-- * @Node x n kids@ holds the element @x@, the number @n@ of elements in
--   this subtree (@x@ included), and the child subtrees. @insert@ files an
--   element into the child whose key is its distance to @x@.
--
-- * @Empty@ is the tree without elements.
--
-- A @Show@ instance is derived only when DEBUG is defined.
data BKTree a = Node a !Int (M.IntMap (BKTree a))
              | Empty
#ifdef DEBUG
                deriving Show
#endif
-- | Test whether the tree holds no elements.
null :: BKTree a -> Bool
null tree = case tree of
  Empty -> True
  Node{} -> False
-- | Number of elements in the tree, read from the count cached in the
-- root node.
size :: BKTree a -> Int
size tree = case tree of
  Node _ count _ -> count
  Empty -> 0
-- | The tree containing no elements.
empty :: BKTree a
empty = Empty
-- | The tree containing exactly the given element.
singleton :: a -> BKTree a
singleton x = Node x 1 noChildren
  where
    noChildren = M.empty
-- | Add an element to the tree.
--
-- The element descends from the root: at every node it moves into the
-- child keyed by its distance to that node, and becomes a new leaf where
-- no such child exists yet. An element equal to one already stored is
-- added again rather than merged.
insert :: Metric a => a -> BKTree a -> BKTree a
insert x tree = case tree of
  Empty -> leaf
  Node pivot count kids ->
    Node pivot (count + 1) (M.insertWith descend (distance x pivot) leaf kids)
  where
    leaf = Node x 1 M.empty
    -- insertWith passes the new value first and the stored subtree second;
    -- the new leaf is ignored and the element is pushed into the subtree.
    descend _ subtree = insert x subtree
-- | Test whether the tree holds an element at distance zero from the
-- given one.
--
-- The search follows a single path: at each node only the child keyed by
-- the distance between the query and that node is inspected.
member :: Metric a => a -> BKTree a -> Bool
member x tree = case tree of
  Empty -> False
  Node pivot _ kids ->
    let d = distance x pivot
     in d == 0 || maybe False (member x) (M.lookup d kids)
-- | @memberDistance n a tree@ tests whether the tree holds an element
-- whose distance to @a@ is at most @n@.
memberDistance :: Metric a => Int -> a -> BKTree a -> Bool
memberDistance _ _ Empty = False
memberDistance n a (Node b _ kids)
  | d <= n = True
  | otherwise = any (memberDistance n a) (M.elems candidates)
  where
    d = distance a b
    -- By the triangle inequality an element within n of the query lies
    -- between d - n and d + n from this node, so only children with keys in
    -- that closed range are searched. M.split drops its pivot key, hence
    -- the pivots one step outside the range.
    (_, fromLow) = M.split (d - n - 1) kids
    (candidates, _) = M.split (d + n + 1) fromLow
-- | Remove one occurrence of an element from the tree; a tree that does
-- not hold the element is returned with the same elements.
--
-- When the element sits at a node, the subtrees below that node are merged
-- back together with @unions@. Otherwise the deletion continues in the
-- child keyed by the distance between the element and the node, and the
-- cached size is recomputed from the children.
delete :: Metric a => a -> BKTree a -> BKTree a
delete _ Empty = Empty
delete a (Node b _ kids)
  | d == 0 = unions (M.elems kids)
  | otherwise = Node b sz kids'
  where
    d = distance a b
    kids' = M.update (keepNonEmpty . delete a) d kids
    sz = 1 + sum (L.map size (M.elems kids'))
    -- A child emptied by the deletion is dropped from the map instead of
    -- being kept as an Empty entry.
    keepNonEmpty Empty = Nothing
    keepNonEmpty subtree = Just subtree
-- | All elements of the tree: a node comes before the elements below it,
-- and children are visited in ascending key order.
elems :: BKTree a -> [a]
elems tree = case tree of
  Empty -> []
  Node x _ kids -> x : [y | kid <- M.elems kids, y <- elems kid]
-- | @elemsDistance n a tree@ lists the elements of the tree whose distance
-- to @a@ is at most @n@.
elemsDistance :: Metric a => Int -> a -> BKTree a -> [a]
elemsDistance _ _ Empty = []
elemsDistance n a (Node b _ kids)
  | d <= n = b : below
  | otherwise = below
  where
    d = distance a b
    below = concatMap (elemsDistance n a) (M.elems candidates)
    -- By the triangle inequality an element within n of the query lies
    -- between d - n and d + n from this node, so only children with keys in
    -- that closed range are searched. M.split drops its pivot key, hence
    -- the pivots one step outside the range.
    (_, fromLow) = M.split (d - n - 1) kids
    (candidates, _) = M.split (d + n + 1) fromLow
-- | Build a tree from a list of elements. Every list item is handed to
-- @constructTree@ as an element that brings no further items with it.
fromList :: Metric a => [a] -> BKTree a
fromList = constructTree asLeaf
  where
    asLeaf x = Just (x, [])
-- | Merge a list of trees into one tree. Every non-empty tree contributes
-- its root element and hands its child subtrees back to @constructTree@
-- for redistribution; empty trees are skipped.
unions :: Metric a => [BKTree a] -> BKTree a
unions = constructTree peel
  where
    peel tree = case tree of
      Node x _ kids -> Just (x, M.elems kids)
      Empty -> Nothing
-- | Shared worker behind @fromList@ and @unions@.
--
-- @extract@ turns an item into its element together with further items
-- that still have to be placed, or into @Nothing@ when the item carries no
-- element. The first item that yields an element becomes the root; all
-- remaining items are bucketed by the distance of their elements to the
-- root, and each bucket is built recursively.
constructTree :: Metric a => (b -> Maybe (a, [b])) -> [b] -> BKTree a
constructTree _ [] = Empty
constructTree extract (item : items) =
  case extract item of
    Nothing -> constructTree extract items
    Just (pivot, extra) ->
      let -- Items without an element fail the pattern and are dropped.
          keyed =
            [ (distance pivot p, [m])
            | m <- items ++ extra
            , Just (p, _) <- [extract m]
            ]
          -- fromListWith prepends later items, so each bucket is reversed
          -- to restore the order in which the items were listed.
          buckets = M.fromListWith (++) keyed
          kids = M.map (constructTree extract . reverse) buckets
       in Node pivot (1 + sum (map size (M.elems kids))) kids
-- | Merge two trees; the two-argument form of @unions@.
union :: Metric a => BKTree a -> BKTree a -> BKTree a
union left right = unions (left : [right])
-- | @closest a tree@ returns an element of the tree at minimal distance
-- from @a@, paired with that distance, or @Nothing@ for the empty tree.
-- The search starts with the root element as the first candidate.
closest :: Metric a => a -> BKTree a -> Maybe (a, Int)
closest a tree = case tree of
  Empty -> Nothing
  Node root _ _ -> Just (closeLoop a (root, distance a root) tree)
-- | Worker for @closest@: searches a subtree for an element closer to the
-- query than the candidate found so far.
--
-- Arguments are the query, the current candidate paired with its distance
-- to the query, and the subtree to search. The result is the best
-- candidate after the subtree has been searched. A candidate is replaced
-- only by a strictly closer element.
closeLoop :: Metric a => a -> (a, Int) -> BKTree a -> (a, Int)
closeLoop _ candidate Empty = candidate
closeLoop a candidate@(_, d) (Node x _ kids) =
  L.foldl' (closeLoop a) best (M.elems reachable)
  where
    j = distance a x
    -- This node itself takes over as candidate when it is strictly closer.
    best = if j < d then (x, j) else candidate
    radius = snd best
    -- An element in the child keyed k lies exactly k from this node, so by
    -- the triangle inequality it is at least |j - k| from the query. It can
    -- beat the candidate only when |j - k| < radius, that is for keys
    -- strictly between j - radius and j + radius. M.split drops its pivot
    -- key, which gives exactly that open range.
    (_, aboveLow) = M.split (j - radius) kids
    (reachable, _) = M.split (j + radius) aboveLow
-- | @on rel f@ applies @f@ to both arguments and combines the results
-- with @rel@.
on :: (b -> b -> c) -> (a -> b) -> a -> a -> c
on rel f = \x y -> f x `rel` f y
#ifdef DEBUG
-- | Reference edit distance used by the tests: fills the complete table of
-- prefix distances in a lazily evaluated boxed array and reads off the
-- entry for the two full lists.
levenshtein :: Eq a => [a] -> [a] -> Int
levenshtein xs ys = dist ! (lxs, lys)
  where
    lxs = length xs
    lys = length ys

    -- Entry (x, y) for x, y >= 1, from its diagonal, upper, and left
    -- neighbours; cx and cy are the list elements at those positions.
    cell x y cx cy =
      minimum
        [ dist ! (x - 1, y - 1) + (if cx == cy then 0 else 1)
        , dist ! (x - 1, y) + 1
        , dist ! (x, y - 1) + 1
        ]

    -- Entry (x, y) is the distance between the first x elements of xs and
    -- the first y elements of ys; row 0 and column 0 are the distances to
    -- the empty prefix.
    dist :: Array (Int, Int) Int
    dist =
      array
        ((0, 0), (lxs, lys))
        ( [((0, 0), 0)]
            ++ [((x, 0), x) | x <- [1 .. lxs]]
            ++ [((0, y), y) | y <- [1 .. lys]]
            ++ [ ((x, y), cell x y cx cy)
               | (x, cx) <- zip [1 ..] xs
               , (y, cy) <- zip [1 ..] ys
               ]
        )
-- | The list @distance@ agrees with the reference @levenshtein@.
prop_levenshtein :: [Int] -> [Int] -> Bool
prop_levenshtein xs ys = levenshtein xs ys == distance xs ys
-- | Two runs of zeros are as far apart as their lengths are under the
-- @Int@ metric.
prop_levenshteinRepeat :: NonZero (NonNegative Int) -> NonZero (NonNegative Int) -> Bool
prop_levenshteinRepeat (NonZero (NonNegative n)) (NonZero (NonNegative m)) =
  distance n m == distance (zeros n) (zeros m)
  where
    zeros k = replicate k (0 :: Int)
-- | For a generated list of the same length, the distance equals that
-- length while the lists differ at every position, or is strictly
-- smaller than it.
prop_levenshteinLength :: [Int] -> Property
prop_levenshteinLength xs =
  forAll (vectorOf len arbitrary) $ \ys ->
    let d = distance xs ys
     in (d == len && differEverywhere ys) || d < len
  where
    len = length xs
    differEverywhere ys = not (or (zipWith (==) xs ys))
-- | Meaning of a tree for the tests: its elements in sorted order.
sem :: BKTree Int -> [Int]
sem = L.sort . elems
-- | Build a tree from the list, apply the tree transformation, and return
-- the sorted elements of the result.
trans :: Metric a => (BKTree a -> BKTree Int) -> [a] -> [Int]
trans f = sem . f . fromList
-- | Nothing is a member of the empty tree.
prop_empty :: Int -> Bool
prop_empty n = not (n `member` empty)
-- | A tree built from a list is empty exactly when the list is.
prop_null :: [Int] -> Bool
prop_null xs = Prelude.null xs == null (fromList xs)
-- | A tree built from a one-element list lists exactly that element.
prop_singleton :: Int -> Bool
prop_singleton n = [n] == elems (fromList [n])
-- | Building a tree keeps every element of the list, duplicates included.
prop_fromList :: [Int] -> Bool
prop_fromList xs = L.sort xs == sem (fromList xs)
-- | Inserting into a tree adds exactly the inserted element.
prop_insert :: Int -> [Int] -> Bool
prop_insert n xs = L.sort (n : xs) == trans (insert n) xs
-- | Tree membership agrees with list membership.
prop_member :: Int -> [Int] -> Bool
prop_member n xs = (n `L.elem` xs) == (n `member` fromList xs)
-- | @memberDistance@ agrees with a linear scan of the list. The radius is
-- reduced modulo 5, and the expected answer is collected so the test
-- output shows how often each outcome was exercised.
prop_memberDistance :: Int -> Int -> [Int] -> Property
prop_memberDistance dist n xs =
  collect expected (memberDistance d n (fromList xs) == expected)
  where
    d = dist `mod` 5
    expected = L.any (\e -> distance n e <= d) xs
-- | Deleting from a tree removes one occurrence of the element, just like
-- dropping its first occurrence from the list.
prop_delete :: Int -> [Int] -> Bool
prop_delete n xs = trans (delete n) xs == L.sort (withoutFirst xs)
  where
    withoutFirst ys = case break (== n) ys of
      (before, _ : after) -> before ++ after
      (before, []) -> before
-- | @elems@ returns the elements the tree was built from, in some order.
prop_elems :: [Int] -> Bool
prop_elems xs = L.sort xs == L.sort (elems (fromList xs))
-- | @elemsDistance@ agrees with filtering the list by distance. The
-- radius is reduced modulo 5.
prop_elemsDistance :: Int -> Int -> [Int] -> Bool
prop_elemsDistance dist n xs = L.sort found == L.sort expected
  where
    d = dist `mod` 5
    found = elemsDistance d n (fromList xs)
    expected = [e | e <- xs, distance n e <= d]
-- | The union of several trees holds the elements of all source lists.
prop_unions :: [[Int]] -> Bool
prop_unions xss = L.sort (concat xss) == sem (unions (map fromList xss))
-- | The union of two trees holds the elements of both source lists.
prop_union :: [Int] -> [Int] -> Bool
prop_union xs ys = L.sort (xs ++ ys) == sem (fromList xs `union` fromList ys)
-- | @closest@ finds nothing exactly for the empty list, and otherwise
-- reports the minimal distance found by a linear scan.
prop_closest :: Int -> [Int] -> Bool
prop_closest n xs = case closest n (fromList xs) of
  Nothing -> xs == []
  Just (_, d) -> d == minimum (map (distance n) xs)
-- | Deleting an element right after inserting it restores the original
-- elements.
prop_insertDelete :: Int -> [Int] -> Bool
prop_insertDelete n xs = L.sort xs == trans (delete n . insert n) xs
-- | The empty tree has size zero.
prop_sizeEmpty :: Bool
prop_sizeEmpty = 0 == size empty
-- | A tree built from a list is as large as the list is long.
prop_sizeFromList :: [Int] -> Bool
prop_sizeFromList xs = length xs == size (fromList xs)
-- | Inserting an element grows the size by one.
prop_sizeSucc :: Int -> [Int] -> Bool
prop_sizeSucc n xs = 1 + size before == size (insert n before)
  where
    before = fromList xs
-- | Deleting an element shrinks the size by one when the element was a
-- member, and leaves the size unchanged otherwise.
prop_sizeDelete :: Int -> [Int] -> Bool
prop_sizeDelete n xs = size (delete n tree) == size tree - removed
  where
    tree = fromList xs
    removed = if n `member` tree then 1 else 0
-- | The size of a union is the sum of the sizes of its two parts.
prop_sizeUnion :: [Int] -> [Int] -> Bool
prop_sizeUnion xs ys = size treeXs + size treeYs == size (union treeXs treeYs)
  where
    treeXs = fromList xs
    treeYs = fromList ys
-- | The size of a union of many trees is the sum of their sizes.
prop_sizeUnions :: [[Int]] -> Bool
prop_sizeUnions xss = sum (map size parts) == size (unions parts)
  where
    parts = map fromList xss
-- Table of all properties, each paired with the label that runTests
-- prints before running it. The second component is the action that
-- checks the property; runTests binds its result and exits with failure
-- when that result is False.
tests = [("empty", quickCheck' prop_empty)
        ,("null", quickCheck' prop_null)
        ,("singleton", quickCheck' prop_singleton)
        ,("fromList", quickCheck' prop_fromList)
        ,("insert", quickCheck' prop_insert)
        ,("member", quickCheck' prop_member)
        ,("memberDistance", quickCheck' prop_memberDistance)
        ,("delete", quickCheck' prop_delete)
        ,("elems", quickCheck' prop_elems)
        ,("elemsDistance", quickCheck' prop_elemsDistance)
        ,("unions", quickCheck' prop_unions)
        ,("union", quickCheck' prop_union)
        ,("closest", quickCheck' prop_closest)
        -- Properties about the cached element count.
        ,("size/empty", quickCheck' prop_sizeEmpty)
        ,("size/fromList", quickCheck' prop_sizeFromList)
        ,("size/succ", quickCheck' prop_sizeSucc)
        ,("size/delete", quickCheck' prop_sizeDelete)
        ,("size/union", quickCheck' prop_sizeUnion)
        ,("size/unions", quickCheck' prop_sizeUnions)
        ,("insert/delete", quickCheck' prop_insertDelete)
        -- Properties about the list edit distance.
        ,("levenshtein", quickCheck' prop_levenshtein)
        ,("levenshtein repeat",quickCheck' prop_levenshteinRepeat)
        ,("levenshtein length",quickCheck' prop_levenshteinLength)
        ]
-- | Run every entry of @tests@ in order, printing its label first. The
-- process exits with a failure status at the first property that does not
-- hold.
runTests :: IO ()
runTests = mapM_ check tests
  where
    check (label, action) = do
      printf "%-25s :" label
      passed <- action
      if passed
        then return ()
        else exitFailure
#endif