-- | A static k-d map: a map from multi-dimensional points to values that is
-- built once from a list of point/value pairs and can then be queried for the
-- nearest neighbor, all points within a radius, the k nearest neighbors, and
-- all points inside a per-axis range. The module also exports 'runTests',
-- which runs its QuickCheck properties.
module Data.KdMap.Static
(
PointAsListFn
, SquaredDistanceFn
, KdMap
, buildKdMap
, buildKdMapWithDistFn
, nearestNeighbor
, pointsInRadius
, kNearestNeighbors
, pointsInRange
, assocs
, points
, values
, size
, foldrKdMap
, defaultDistSqrFn
, runTests
) where
import Control.DeepSeq
import Control.DeepSeq.Generics (genericRnf)
import GHC.Generics
import Control.Applicative
import Data.Foldable
import Data.Function
import qualified Data.List as L
import Data.Maybe
import Data.Ord
import qualified Data.PQueue.Prio.Max as Q
import Data.Traversable
import Test.QuickCheck
import Data.Point2d
-- | Binary tree that backs a 'KdMap'.
--
-- A 'TreeNode' holds one point/value pair plus the coordinate that was used
-- to split the remaining points between its two subtrees; 'Empty' marks an
-- absent subtree.
data TreeNode a p v
  = TreeNode
      { _treeLeft  :: TreeNode a p v -- ^ subtree on the "less or equal" side
      , _treePoint :: (p, v)         -- ^ the point stored here and its value
      , _axisValue :: a              -- ^ coordinate this node splits on
      , _treeRight :: TreeNode a p v -- ^ subtree on the "greater" side
      }
  | Empty
  deriving Generic

-- | Full evaluation is derived generically from the constructor layout.
instance (NFData a, NFData p, NFData v) => NFData (TreeNode a p v) where
  rnf = genericRnf
-- | Apply a function to every value in the tree. Points, axis values and
-- the shape of the tree are left untouched.
mapTreeNode :: (v1 -> v2) -> TreeNode a p v1 -> TreeNode a p v2
mapTreeNode f = walk
  where
    walk Empty = Empty
    walk (TreeNode l (key, val) splitVal r) =
      TreeNode (walk l) (key, f val) splitVal (walk r)
-- | Converts a point of type @p@ into the list of its coordinates of type
-- @a@. The tree cycles through this list, one coordinate per tree level.
type PointAsListFn a p = p -> [a]
-- | Computes the squared distance between two points.
type SquaredDistanceFn a p = p -> p -> a
-- | A k-d tree mapping points of type @p@ (with coordinates of type @a@) to
-- values of type @v@, bundled with the functions needed to query it.
data KdMap a p v = KdMap
  { _pointAsList :: PointAsListFn a p     -- ^ coordinate extractor for points
  , _distSqr     :: SquaredDistanceFn a p -- ^ squared distance used by queries
  , _rootNode    :: TreeNode a p v        -- ^ root of the underlying tree
  , _size        :: Int                   -- ^ number of stored points
  } deriving Generic

-- | Full evaluation is derived generically from the record layout.
instance (NFData a, NFData p, NFData v) => NFData (KdMap a p v) where
  rnf = genericRnf
-- | Mapping transforms the stored values only; the points, the distance
-- function and the recorded size are carried over unchanged.
instance Functor (KdMap a p) where
  fmap f (KdMap asList dist root n) = KdMap asList dist (mapTreeNode f root) n
-- | Right fold over the point/value pairs of a tree, visiting the left
-- subtree, then the node itself, then the right subtree.
foldrTreeNode :: ((p, v) -> b -> b) -> b -> TreeNode a p v -> b
foldrTreeNode f = walk
  where
    walk acc Empty = acc
    walk acc (TreeNode l entry _ r) = walk (f entry (walk acc r)) l
-- | Right fold over all point/value pairs stored in the map, in the order
-- given by 'foldrTreeNode' on the root node.
foldrKdMap :: ((p, v) -> b -> b) -> b -> KdMap a p v -> b
foldrKdMap f z kdMap = foldrTreeNode f z (_rootNode kdMap)
-- | Folding a 'KdMap' sees only the values; the points are dropped.
instance Foldable (KdMap a p) where
  foldr f = foldrKdMap (\entry acc -> f (snd entry) acc)
-- | Traverse the values of a tree with an effectful function. Effects run
-- for the left subtree first, then the node's own value, then the right
-- subtree; points and axis values are preserved.
traverseTreeNode :: Applicative f => (b -> f c) -> TreeNode a p b -> f (TreeNode a p c)
traverseTreeNode f = walk
  where
    walk Empty = pure Empty
    walk (TreeNode l (key, val) splitVal r) =
      rebuild <$> walk l <*> f val <*> walk r
      where
        rebuild l' val' r' = TreeNode l' (key, val') splitVal r'
-- | Traversal touches only the values; every other field of the map is
-- copied into the result as is.
instance Traversable (KdMap a p) where
  traverse f (KdMap asList dist root n) =
    (\root' -> KdMap asList dist root' n) <$> traverseTreeNode f root
-- | Select the element that would sit at zero-based index @k@ if the list
-- were sorted with the given comparison, without sorting the whole list.
--
-- Calls 'error' when the list is empty, which also happens when @k@ is
-- outside @[0, length xs - 1]@.
quickselect :: (b -> b -> Ordering) -> Int -> [b] -> b
quickselect cmp = go
  where
    go _ [] = error "quickselect must be called on a non-empty list."
    go k (x : xs)
      | k < l     = go k ys
        -- Skip the @l@ smaller elements and the pivot itself.
      | k > l     = go (k - l - 1) zs
      | otherwise = x
      where
        -- @ys@: elements strictly less than the pivot; @zs@: all the rest.
        (ys, zs) = L.partition ((== LT) . (`cmp` x)) xs
        l = length ys
-- | Build a 'KdMap' from a non-empty list of point/value pairs, using the
-- supplied coordinate extractor and squared-distance function.
--
-- At every level the points are split around the median (found with
-- 'quickselect') of the current axis: smaller coordinates go left, larger
-- ones go right, and points tying with the median go left once the median
-- point itself has been chosen. Axes are consumed in a cycle.
--
-- Calls 'error' when given an empty list.
buildKdMapWithDistFn :: Real a => PointAsListFn a p ->
                                  SquaredDistanceFn a p ->
                                  [(p, v)] ->
                                  KdMap a p v
buildKdMapWithDistFn _ _ [] = error "KdMap must be built with a non-empty list."
buildKdMapWithDistFn pointAsList distSqr dataPoints =
  KdMap
    { _pointAsList = pointAsList
    , _distSqr     = distSqr
    , _rootNode    = buildSubtree taggedPoints
    , _size        = length dataPoints
    }
  where
    -- Pair every data point with an endless cycle of its coordinates; the
    -- head of that cycle is always the coordinate for the current level.
    taggedPoints = zip (map (cycle . pointAsList . fst) dataPoints) dataPoints

    buildSubtree [] = Empty
    buildSubtree tagged =
      TreeNode
        { _treeLeft  = buildSubtree lesser
        , _treePoint = fromJust pivot
        , _axisValue = splitVal
        , _treeRight = buildSubtree greater
        }
      where
        (splitVal : _, _) =
          quickselect (comparing (head . fst)) (length tagged `div` 2) tagged

        (lesser, pivot, greater) = L.foldr place ([], Nothing, []) tagged

        -- Route one point to the left list, the right list, or the pivot
        -- slot, dropping the coordinate that has just been consumed.
        place ([], _) _ = error "buildKdMap.f: no empty lists allowed!"
        place (coord : rest, pt) (lo, mid, hi)
          | coord < splitVal = ((rest, pt) : lo, mid, hi)
          | coord > splitVal = (lo, mid, (rest, pt) : hi)
          | otherwise =
              case mid of
                Nothing -> (lo, Just pt, hi)
                Just _  -> ((rest, pt) : lo, mid, hi)
-- | Squared Euclidean distance between two points: the sum of the squared
-- coordinate-wise differences, with coordinates obtained from the given
-- 'PointAsListFn'. If the coordinate lists differ in length, the extra
-- coordinates of the longer one are ignored (a consequence of 'zipWith').
defaultDistSqrFn :: Num a => PointAsListFn a p -> SquaredDistanceFn a p
defaultDistSqrFn pointAsList k1 k2 =
  L.sum $ map (^ (2 :: Int)) $ zipWith (-) (pointAsList k1) (pointAsList k2)
-- | Build a 'KdMap' that measures distance with 'defaultDistSqrFn' applied
-- to the given coordinate extractor. See 'buildKdMapWithDistFn'.
buildKdMap :: Real a => PointAsListFn a p -> [(p, v)] -> KdMap a p v
buildKdMap pointAsList =
  buildKdMapWithDistFn pointAsList distSqr
  where
    distSqr = defaultDistSqrFn pointAsList
-- | List every point/value pair of a tree: left subtree first, then the
-- node's own pair, then the right subtree.
assocsInternal :: TreeNode a p v -> [(p, v)]
assocsInternal tree = collect tree []
  where
    -- @rest@ is the already-built tail that follows this subtree's output.
    collect Empty rest = rest
    collect (TreeNode l entry _ r) rest = collect l (entry : collect r rest)
-- | All point/value pairs stored in the map, in tree traversal order.
assocs :: KdMap a p v -> [(p, v)]
assocs kdMap = assocsInternal (_rootNode kdMap)
-- | All points stored in the map, in the same order as 'assocs'.
points :: KdMap a p v -> [p]
points kdMap = fst <$> assocs kdMap
-- | All values stored in the map, in the same order as 'assocs'.
values :: KdMap a p v -> [v]
values kdMap = snd <$> assocs kdMap
-- | Find the stored point (together with its value) that is closest to
-- @query@ according to the map's squared-distance function.
--
-- Calls 'error' if the map's tree is empty.
nearestNeighbor :: Real a => KdMap a p v -> p -> (p, v)
nearestNeighbor (KdMap _ _ Empty _) _ =
  error "nearestNeighbor: why is there an empty KdMap?"
nearestNeighbor (KdMap pointAsList distSqr t@(TreeNode _ root _ _) _) query =
  fst $ go (root, distSqr (fst root) query) (cycle $ pointAsList query) t
  where
    -- Keep the candidate with the strictly smaller squared distance; on a
    -- tie the second argument wins.
    better x1@(_, dist1) x2@(_, dist2)
      | dist1 < dist2 = x1
      | otherwise     = x2

    go _ [] _ = error "nearestNeighbor.go: no empty lists allowed!"
    go bestSoFar _ Empty = bestSoFar
    go bestSoFar
       (queryAxisValue : qvs)
       (TreeNode left (nodeK, nodeV) nodeAxisVal right)
      | queryAxisValue <= nodeAxisVal = nearestInTree left right
      | otherwise                     = nearestInTree right left
      where
        currDist = distSqr query nodeK
        bestAfterNode = better ((nodeK, nodeV), currDist) bestSoFar

        -- Search the side of the split containing the query first, then
        -- the other side only if it could still hold something closer.
        nearestInTree onsideSubtree offsideSubtree =
          let bestAfterOnside = go bestAfterNode qvs onsideSubtree
              -- The squared gap to the splitting value is a lower bound on
              -- the squared distance to anything on the far side.
              checkOffsideSubtree =
                (queryAxisValue - nodeAxisVal) ^ (2 :: Int) < snd bestAfterOnside
          in if checkOffsideSubtree
               then go bestAfterOnside qvs offsideSubtree
               else bestAfterOnside
-- | All point/value pairs whose squared distance to @query@ is at most
-- @radius * radius@, i.e. points on the boundary are included. The order of
-- the result is unspecified.
pointsInRadius :: Real a => KdMap a p v
                         -> a -- ^ radius (not squared)
                         -> p -- ^ query point
                         -> [(p, v)]
pointsInRadius (KdMap pointAsList distSqr t _) radius query =
  go (cycle $ pointAsList query) t []
  where
    go [] _ _ = error "pointsInRadius.go: no empty lists allowed!"
    go _ Empty acc = acc
    go (queryAxisValue : qvs) (TreeNode left (k, v) nodeAxisVal right) acc =
      let onTheLeft = queryAxisValue <= nodeAxisVal
          accAfterOnside = if onTheLeft
                             then go qvs left acc
                             else go qvs right acc
          -- Inclusive comparison: the left subtree may hold points whose
          -- coordinate equals the splitting value, and those can lie
          -- exactly @radius@ away, which the membership test below accepts.
          accAfterOffside = if abs (queryAxisValue - nodeAxisVal) <= radius
                              then if onTheLeft
                                     then go qvs right accAfterOnside
                                     else go qvs left accAfterOnside
                              else accAfterOnside
          accAfterCurrent = if distSqr k query <= radius * radius
                              then (k, v) : accAfterOffside
                              else accAfterOffside
      in accAfterCurrent
-- | Up to @numNeighbors@ stored point/value pairs closest to @query@
-- according to the map's squared-distance function.
--
-- The result is the reversed list form of a max-priority queue keyed by
-- squared distance; NOTE(review): this relies on @Q.toList@ yielding keys in
-- descending order so that the nearest point comes first — confirm against
-- the pqueue version in use.
--
-- A non-positive @numNeighbors@ yields the empty list.
kNearestNeighbors :: Real a => KdMap a p v -> Int -> p -> [(p, v)]
kNearestNeighbors (KdMap pointAsList distSqr t _) numNeighbors query
  -- Without this guard the bounded insert would ask an empty queue for its
  -- maximum.
  | numNeighbors <= 0 = []
  | otherwise =
      reverse $ map snd $ Q.toList $ go (cycle $ pointAsList query) Q.empty t
  where
    -- Insert while there is room; once full, replace the current farthest
    -- entry only when the new one is strictly closer.
    insertBounded queue dist x
      | Q.size queue < numNeighbors = Q.insert dist x queue
      | dist < fst (Q.findMax queue) = Q.insert dist x $ Q.deleteMax queue
      | otherwise = queue

    go [] _ _ = error "kNearestNeighbors.go: no empty lists allowed!"
    go _ q Empty = q
    go (queryAxisValue : qvs) q (TreeNode left (k, v) nodeAxisVal right)
      | queryAxisValue <= nodeAxisVal = kNearest q' left right
      | otherwise                     = kNearest q' right left
      where
        q' = insertBounded q (distSqr k query) (k, v)

        -- Search the side containing the query first; visit the far side
        -- only if the queue is not yet full or the splitting plane is
        -- closer than the current farthest candidate.
        kNearest queue onsideSubtree offsideSubtree =
          let queue' = go qvs queue onsideSubtree
              checkOffsideTree =
                Q.size queue' < numNeighbors ||
                (queryAxisValue - nodeAxisVal) ^ (2 :: Int) < fst (Q.findMax queue')
          in if checkOffsideTree
               then go qvs queue' offsideSubtree
               else queue'
-- | All point/value pairs whose every coordinate lies between the matching
-- coordinates of @lowers@ and @uppers@ (both ends inclusive). The order of
-- the result is unspecified.
pointsInRange :: Real a => KdMap a p v
                        -> p -- ^ lower bounds, one per axis
                        -> p -- ^ upper bounds, one per axis
                        -> [(p, v)]
pointsInRange (KdMap pointAsList _ t _) lowers uppers =
  go (cycle lowerVals `zip` cycle upperVals) t []
  where
    lowerVals = pointAsList lowers
    upperVals = pointAsList uppers

    within lo x hi = lo <= x && x <= hi

    go [] _ _ = error "neighborsInRange.go: no empty lists allowed!"
    go _ Empty acc = acc
    go ((lower, upper) : nextBounds) (TreeNode left p nodeAxisVal right) acc
      | insideBox = p : afterRight
      | otherwise = afterRight
      where
        -- The left subtree is only relevant when the range reaches down to
        -- the splitting value ...
        afterLeft
          | lower <= nodeAxisVal = go nextBounds left acc
          | otherwise            = acc
        -- ... and the right subtree only when it extends above it.
        afterRight
          | upper > nodeAxisVal = go nextBounds right afterLeft
          | otherwise           = afterLeft
        insideBox =
          L.and $ zipWith3 within lowerVals (pointAsList $ fst p) upperVals
-- | Number of points in the map, as recorded when it was built.
size :: KdMap a p v -> Int
size kdMap = _size kdMap
-- | Test helper: tag each point with its zero-based position in the list so
-- that it can serve as a distinct value.
testElements :: [p] -> [(p, Int)]
testElements = (`zip` [0 ..])
-- | Test helper: check the k-d tree invariant. For the given axis, every
-- point in the left subtree must have a coordinate no greater than the
-- node's axis value and every point in the right subtree a strictly greater
-- one; the same must hold recursively for the next axis.
isTreeValid :: Real a => PointAsListFn a p -> Int -> TreeNode a p v -> Bool
isTreeValid _ _ Empty = True
isTreeValid pointAsList axis (TreeNode l (k, _) nodeAxisVal r) =
  L.and
    [ L.all (<= nodeAxisVal) (axisValuesOf l)
    , L.all (> nodeAxisVal) (axisValuesOf r)
    , isTreeValid pointAsList nextAxis l
    , isTreeValid pointAsList nextAxis r
    ]
  where
    -- Coordinates, along the current axis, of every point in a subtree.
    axisValuesOf subtree =
      map ((!! axis) . pointAsList . fst) (assocsInternal subtree)
    nextAxis = (axis + 1) `mod` length (pointAsList k)
-- | Build a map from the given points and check that its tree satisfies
-- 'isTreeValid', starting from axis 0.
checkValidTree :: Real a => PointAsListFn a p -> [p] -> Bool
checkValidTree pointAsList ps = isTreeValid pointAsList 0 root
  where
    root = _rootNode (buildKdMap pointAsList (testElements ps))
-- | Any non-empty list of 2-d points builds into a valid tree.
prop_validTree :: Property
prop_validTree =
  forAll (listOf1 arbitrary) (\ps -> checkValidTree pointAsList2d ps)
-- | The built map holds exactly the point/value pairs it was built from,
-- compared as sorted lists.
checkElements :: (Ord p, Real a) => PointAsListFn a p -> [p] -> Bool
checkElements pointAsList ps = ((==) `on` L.sort) (assocs kdt) expected
  where
    expected = testElements ps
    kdt = buildKdMap pointAsList expected
-- | Building a map from any non-empty list of 2-d points keeps all of them.
prop_sameElements :: Property
prop_sameElements =
  forAll (listOf1 arbitrary) (\ps -> checkElements pointAsList2d ps)
-- | The size recorded in the built map equals the number of input points.
checkNumElements :: Real a => PointAsListFn a p -> [p] -> Bool
checkNumElements pointAsList ps = _size kdt == length ps
  where
    kdt = buildKdMap pointAsList (testElements ps)
-- | The recorded size is right for any non-empty list of 2-d points.
prop_validNumElements :: Property
prop_validNumElements =
  forAll (listOf1 arbitrary) (\ps -> checkNumElements pointAsList2d ps)
-- | Reference implementation: find the nearest neighbor by scanning the
-- whole list and minimising 'defaultDistSqrFn' to the query.
nearestNeighborLinear :: Real a => PointAsListFn a p -> [(p, v)] -> p -> (p, v)
nearestNeighborLinear pointAsList xs query = L.minimumBy (comparing distToQuery) xs
  where
    distToQuery = defaultDistSqrFn pointAsList query . fst
-- | The tree-based nearest neighbor agrees with the linear scan.
checkNearestEqualToLinear :: (Eq p, Real a) => PointAsListFn a p -> ([p], p) -> Bool
checkNearestEqualToLinear pointAsList (ps, query) = viaTree == viaScan
  where
    elems   = testElements ps
    viaTree = nearestNeighbor (buildKdMap pointAsList elems) query
    viaScan = nearestNeighborLinear pointAsList elems query
-- | For any query and any non-empty list of 2-d points, 'nearestNeighbor'
-- matches the linear reference.
prop_nearestEqualToLinear :: Point2d -> Property
prop_nearestEqualToLinear query =
  forAll (listOf1 arbitrary) agreesWithScan
  where
    agreesWithScan xs = checkNearestEqualToLinear pointAsList2d (xs, query)
-- | Reference implementation: keep every pair whose squared distance
-- ('defaultDistSqrFn') to the query is at most @radius * radius@.
pointsInRadiusLinear :: Real a => PointAsListFn a p -> [(p, v)] -> p -> a -> [(p, v)]
pointsInRadiusLinear pointAsList xs query radius = filter closeEnough xs
  where
    closeEnough entry =
      defaultDistSqrFn pointAsList query (fst entry) <= radius * radius
-- | 'pointsInRadius' returns the same pairs as the linear scan, compared as
-- sorted lists.
checkInRadiusEqualToLinear :: (Ord p, Real a) => PointAsListFn a p -> a -> ([p], p) -> Bool
checkInRadiusEqualToLinear pointAsList radius (ps, query) =
  ((==) `on` L.sort) kdtNear linearNear
  where
    elems      = testElements ps
    kdtNear    = pointsInRadius (buildKdMap pointAsList elems) radius query
    linearNear = pointsInRadiusLinear pointAsList elems query radius
-- | For any query, any non-empty list of 2-d points and any radius drawn
-- from @[0, 1000]@, 'pointsInRadius' matches the linear reference.
prop_inRadiusEqualToLinear :: Point2d -> Property
prop_inRadiusEqualToLinear query =
  forAll (listOf1 arbitrary) withPoints
  where
    withPoints xs =
      forAll (choose (0.0, 1000.0)) $ \radius ->
        checkInRadiusEqualToLinear pointAsList2d radius (xs, query)
-- | Reference implementation: sort all pairs by 'defaultDistSqrFn' to the
-- query and keep the first @k@.
kNearestNeighborsLinear :: Real a => PointAsListFn a p -> [(p, v)] -> p -> Int -> [(p, v)]
kNearestNeighborsLinear pointAsList xs query k = take k byDistance
  where
    byDistance = L.sortBy (comparing distToQuery) xs
    distToQuery = defaultDistSqrFn pointAsList query . fst
-- | 'kNearestNeighbors' returns exactly the list produced by the linear
-- reference, in the same order.
checkKNearestEqualToLinear :: (Ord p, Real a) => PointAsListFn a p -> Int -> ([p], p) -> Bool
checkKNearestEqualToLinear pointAsList k (xs, query) = kdtKNear == linearKNear
  where
    elems       = testElements xs
    kdtKNear    = kNearestNeighbors (buildKdMap pointAsList elems) k query
    linearKNear = kNearestNeighborsLinear pointAsList elems query k
-- | For any query, any non-empty list of 2-d points and any @k@ between 1
-- and the number of points, 'kNearestNeighbors' matches the linear
-- reference.
prop_kNearestEqualToLinear :: Point2d -> Property
prop_kNearestEqualToLinear query =
  forAll (listOf1 arbitrary) withPoints
  where
    withPoints xs =
      forAll (choose (1, length xs)) $ \k ->
        checkKNearestEqualToLinear pointAsList2d k (xs, query)
-- | Asking for as many neighbors as there are points yields them in
-- non-decreasing order of squared distance to the query. Trivially true for
-- an empty point list.
checkKNearestSorted :: (Eq p, Real a) => PointAsListFn a p -> ([p], p) -> Bool
checkKNearestSorted _ ([], _) = True
checkKNearestSorted pointAsList (ps, query) =
  kNearestDists == L.sort kNearestDists
  where
    kdt = buildKdMap pointAsList (testElements ps)
    neighbors = kNearestNeighbors kdt (length ps) query
    kNearestDists = map (defaultDistSqrFn pointAsList query . fst) neighbors
-- | For any query and any non-empty list of 2-d points, the full
-- k-nearest-neighbor result is sorted by distance.
prop_kNearestSorted :: Point2d -> Property
prop_kNearestSorted query =
  forAll (listOf1 arbitrary) sortedFor
  where
    sortedFor xs = checkKNearestSorted pointAsList2d (xs, query)
-- | Reference implementation: keep every pair whose point has each
-- coordinate between the matching coordinates of @lowers@ and @uppers@,
-- both ends inclusive.
rangeLinear :: Real a => PointAsListFn a p -> [(p, v)] -> p -> p -> [(p, v)]
rangeLinear pointAsList xs lowers uppers = filter insideBox xs
  where
    lowerVals = pointAsList lowers
    upperVals = pointAsList uppers
    between x lo hi = lo <= x && x <= hi
    insideBox (p, _) =
      L.and (zipWith3 between (pointAsList p) lowerVals upperVals)
-- | 'pointsInRange' matches the linear reference whenever the point list is
-- non-empty and every lower bound is strictly below its upper bound; all
-- other inputs pass trivially.
prop_rangeEqualToLinear :: ([Point2d], Point2d, Point2d) -> Bool
prop_rangeEqualToLinear (xs, lowers, uppers)
  | null xs || not boundsOrdered = True
  | otherwise = ((==) `on` L.sort) linear kdtPoints
  where
    boundsOrdered =
      L.and $ zipWith (<) (pointAsList2d lowers) (pointAsList2d uppers)
    elems = testElements xs
    linear = rangeLinear pointAsList2d elems lowers uppers
    kdtPoints = pointsInRange (buildKdMap pointAsList2d elems) lowers uppers
-- | Prepend a point that shares the first 'Point2d' field of the head of
-- the list (with the second field increased by one) and check that the
-- built map still holds every element.
prop_equalAxisValueSameElems :: Property
prop_equalAxisValueSameElems =
  forAll (listOf1 arbitrary) withSharedFirstField
  where
    -- 'listOf1' never produces an empty list, so the pattern always matches.
    withSharedFirstField xs@(Point2d x y : _) =
      checkElements pointAsList2d (Point2d x (y + 1) : xs)
-- | Prepend a point that shares the first 'Point2d' field of the head of
-- the list (with the second field increased by one) and check that
-- 'nearestNeighbor' still agrees with the linear reference.
prop_equalAxisValueEqualToLinear :: Point2d -> Property
prop_equalAxisValueEqualToLinear query =
  forAll (listOf1 arbitrary) withSharedFirstField
  where
    -- 'listOf1' never produces an empty list, so the pattern always matches.
    withSharedFirstField xs@(Point2d x y : _) =
      checkNearestEqualToLinear pointAsList2d (Point2d x (y + 1) : xs, query)
-- Empty top-level Template Haskell splice.
-- NOTE(review): presumably required so that the properties defined above are
-- in scope for the 'quickCheckAll' splice below — confirm against the
-- QuickCheck documentation. Keep it after the last property.
return []
-- | Run the QuickCheck properties collected by 'quickCheckAll' and return
-- its overall result.
runTests :: IO Bool
runTests = $quickCheckAll