module Data.QuadTree.Internal where
import Control.Lens.Type (Lens')
import Control.Lens.Setter (over, set)
import Control.Lens.Getter (view)
import Data.List (find, sortBy)
import Data.Function (on)
import Data.Composition ((.:))
-- | An @(x, y)@ coordinate pair. 'outOfBounds' checks the first component
-- against 'treeLength' and the second against 'treeWidth'.
type Location = (Int, Int)
-- | A 'Quadrant' hierarchy bundled with cached size information.
data QuadTree a = Wrapper
  { wrappedTree :: Quadrant a
    -- ^ Root quadrant holding the elements.
  , treeLength :: Int
    -- ^ Extent along the x axis (upper bound for the first 'Location'
    -- component in 'outOfBounds').
  , treeWidth :: Int
    -- ^ Extent along the y axis (upper bound for the second 'Location'
    -- component in 'outOfBounds').
  , treeDepth :: Int
    -- ^ Number of levels 'atLocation' descends before it expects a leaf.
  }
  deriving (Show, Read, Eq)
-- | Maps over every stored element by mapping over the wrapped quadrant.
-- The cached dimensions are left untouched, and no fusing is performed
-- (use 'tmap' for a mapping that re-fuses afterwards).
instance Functor QuadTree where
  fmap fn tree = onQuads (fmap fn) tree
-- | Folding is defined through 'foldTree', which expands the tree's tiles
-- into a flat list of elements before folding.
instance Foldable QuadTree where
  foldr step acc tree = foldTree step acc tree
-- | One cell of the tree: either a single value covering the whole cell,
-- or a split into four sub-quadrants. The four 'Node' fields are the
-- targets of the lenses '_a', '_b', '_c' and '_d', in that order.
data Quadrant a
  = Leaf a
  | Node (Quadrant a) (Quadrant a) (Quadrant a) (Quadrant a)
  deriving (Show, Read, Eq)
-- | Applies the function to every leaf value, preserving the shape of the
-- quadrant. Equal results in sibling leaves are not merged here.
instance Functor Quadrant where
  fmap fn quad = case quad of
    Leaf x -> Leaf (fn x)
    Node a b c d -> Node (fn <$> a) (fn <$> b) (fn <$> c) (fn <$> d)
-- | Lens onto the first sub-quadrant of a 'Quadrant'.
--
-- On a 'Node' the first child is focused and the rebuilt node is passed
-- through 'fuse'. On a leaf the leaf itself is the focus: if the new value
-- equals the old leaf nothing changes, otherwise the leaf is split into a
-- 'Node' with the new value in the first slot and the old leaf in the rest.
_a :: forall a. Eq a => Lens' (Quadrant a) (Quadrant a)
_a f (Node a b c d) = (\a' -> fuse (Node a' b c d)) <$> f a
_a f leaf = rebuild <$> f leaf
  where
    rebuild :: Quadrant a -> Quadrant a
    rebuild new
      | new == leaf = leaf
      | otherwise = Node new leaf leaf leaf
-- | Lens onto the second sub-quadrant of a 'Quadrant'.
--
-- On a 'Node' the second child is focused and the rebuilt node is passed
-- through 'fuse'. On a leaf the leaf itself is the focus: if the new value
-- equals the old leaf nothing changes, otherwise the leaf is split into a
-- 'Node' with the new value in the second slot and the old leaf in the rest.
_b :: forall a. Eq a => Lens' (Quadrant a) (Quadrant a)
_b f (Node a b c d) = (\b' -> fuse (Node a b' c d)) <$> f b
_b f leaf = rebuild <$> f leaf
  where
    rebuild :: Quadrant a -> Quadrant a
    rebuild new
      | new == leaf = leaf
      | otherwise = Node leaf new leaf leaf
-- | Lens onto the third sub-quadrant of a 'Quadrant'.
--
-- On a 'Node' the third child is focused and the rebuilt node is passed
-- through 'fuse'. On a leaf the leaf itself is the focus: if the new value
-- equals the old leaf nothing changes, otherwise the leaf is split into a
-- 'Node' with the new value in the third slot and the old leaf in the rest.
_c :: forall a. Eq a => Lens' (Quadrant a) (Quadrant a)
_c f (Node a b c d) = (\c' -> fuse (Node a b c' d)) <$> f c
_c f leaf = rebuild <$> f leaf
  where
    rebuild :: Quadrant a -> Quadrant a
    rebuild new
      | new == leaf = leaf
      | otherwise = Node leaf leaf new leaf
-- | Lens onto the fourth sub-quadrant of a 'Quadrant'.
--
-- On a 'Node' the fourth child is focused and the rebuilt node is passed
-- through 'fuse'. On a leaf the leaf itself is the focus: if the new value
-- equals the old leaf nothing changes, otherwise the leaf is split into a
-- 'Node' with the new value in the fourth slot and the old leaf in the rest.
_d :: forall a. Eq a => Lens' (Quadrant a) (Quadrant a)
_d f (Node a b c d) = (\d' -> fuse (Node a b c d')) <$> f d
_d f leaf = rebuild <$> f leaf
  where
    rebuild :: Quadrant a -> Quadrant a
    rebuild new
      | new == leaf = leaf
      | otherwise = Node leaf leaf leaf new
-- | Lens onto the value held by a 'Leaf'.
--
-- Partial: applied to a 'Node' it calls 'error'. 'atLocation' only reaches
-- this lens after descending 'treeDepth' levels.
_leaf :: Lens' (Quadrant a) a
_leaf f quad = case quad of
  Leaf value -> Leaf <$> f value
  _ -> error "Wrapped tree is deeper than cached tree depth."
-- | Lens onto the root 'Quadrant' of a 'QuadTree'; the cached dimensions
-- and depth are carried over unchanged.
_wrappedTree :: Lens' (QuadTree a) (Quadrant a)
_wrappedTree f qt = fmap replaceRoot (f (wrappedTree qt))
  where
    replaceRoot root = qt { wrappedTree = root }
-- | Identity-like lens that guards a location: it passes the tree through
-- unchanged when the index is inside the tree and calls 'error' when
-- 'outOfBounds' reports the index as outside.
verifyLocation :: Location -> Lens' (QuadTree a) (QuadTree a)
verifyLocation index f qt =
  if index `outOfBounds` qt
    then error "Location index out of QuadTree bounds."
    else f qt
-- | Lens onto the element stored at a 'Location' of the tree.
--
-- The index is first checked by 'verifyLocation' (which calls 'error' for
-- an out-of-bounds index), then shifted by 'offsetIndex', and finally used
-- to walk 'treeDepth' levels down the wrapped quadrant.
atLocation :: forall a. Eq a => Location -> Lens' (QuadTree a) a
atLocation index fn qt =
  ( verifyLocation index
      . _wrappedTree
      . go (offsetIndex qt index) (treeDepth qt)
  ) fn qt
  where
    -- Descend @n@ more levels towards the cell containing @(x, y)@, where
    -- the coordinates are relative to the current quadrant's origin.
    go :: Eq a => Location -> Int -> Lens' (Quadrant a) a
    go _ 0 = _leaf
    go (x, y) n
      | y < mid = if x < mid then _a . recurse else _b . recurse
      | otherwise = if x < mid then _c . recurse else _d . recurse
      where
        -- Explicit signature keeps the sub-lens polymorphic in its functor.
        recurse :: Lens' (Quadrant a) a
        -- Re-base the coordinates onto the chosen child and go one level
        -- down.
        recurse = go (x `mod` mid, y `mod` mid) (n - 1)
        -- Side length of each child quadrant at this level.
        mid = 2 ^ (n - 1)
-- | Reads the element at the given location through 'atLocation'.
-- Calls 'error' (via 'verifyLocation') when the location is out of bounds.
getLocation :: Eq a => Location -> QuadTree a -> a
getLocation index tree = view (atLocation index) tree
-- | Replaces the element at the given location through 'atLocation'.
-- Calls 'error' (via 'verifyLocation') when the location is out of bounds.
setLocation :: Eq a => Location -> a -> QuadTree a -> QuadTree a
setLocation index value tree = set (atLocation index) value tree
-- | Applies a function to the element at the given location through
-- 'atLocation'. Calls 'error' (via 'verifyLocation') when the location is
-- out of bounds.
mapLocation :: Eq a => Location -> (a -> a) -> QuadTree a -> QuadTree a
mapLocation index fn tree = over (atLocation index) fn tree
-- | 'True' when the location lies outside the tree, i.e. when either
-- coordinate is negative, @x@ is not below 'treeLength', or @y@ is not
-- below 'treeWidth'.
outOfBounds :: Location -> QuadTree a -> Bool
outOfBounds (x, y) tree = not inside
  where
    inside =
      x >= 0
        && y >= 0
        && x < treeLength tree
        && y < treeWidth tree
-- | The tree's dimensions as @('treeLength', 'treeWidth')@.
treeDimensions :: QuadTree a -> (Int, Int)
treeDimensions = (,) <$> treeLength <*> treeWidth
-- | Shifts a location by the tree's 'offsets', component-wise.
offsetIndex :: QuadTree a -> Location -> Location
offsetIndex tree (x, y) =
  let (dx, dy) = offsets tree
   in (x + dx, y + dy)
-- | Per-axis offset between the tree's own coordinates and the coordinates
-- of the underlying @2 ^ 'treeDepth'@ square.
--
-- Each component is half (rounded down) of the space left over once the
-- tree's extent on that axis is subtracted from the square's side length.
offsets :: QuadTree a -> (Int, Int)
offsets tree = (centre treeLength, centre treeWidth)
  where
    -- Side length of the square the wrapped quadrant covers.
    side = 2 ^ treeDepth tree
    -- Half of the unused space along the axis measured by @extent@.
    centre extent = (side - extent tree) `div` 2
-- | Collapses a 'Node' whose four children are leaves holding equal values
-- into a single 'Leaf'. Any other quadrant is returned unchanged. Only the
-- top level is inspected; children are not fused recursively.
fuse :: Eq a => Quadrant a -> Quadrant a
fuse quad = case quad of
  Node (Leaf a) (Leaf b) (Leaf c) (Leaf d)
    | allEqual [a, b, c, d] -> Leaf a
  _ -> quad
-- | 'True' when every pair of adjacent elements compares equal. Empty and
-- single-element lists are trivially all-equal.
allEqual :: Eq a => [a] -> Bool
allEqual xs = and (zipWith (==) xs (drop 1 xs))
-- | Applies a transformation to the wrapped quadrant, keeping the cached
-- dimensions and depth of the tree as they were.
onQuads :: (Quadrant a -> Quadrant b) -> QuadTree a -> QuadTree b
onQuads fn tree = tree { wrappedTree = transformed }
  where
    transformed = fn (wrappedTree tree)
-- | Fuses the whole tree bottom-up: children are fused first, then 'fuse'
-- is applied to the rebuilt node, so uniform subtrees collapse into leaves.
fuseTree :: Eq a => QuadTree a -> QuadTree a
fuseTree = onQuads collapse
  where
    collapse :: Eq a => Quadrant a -> Quadrant a
    collapse quad = case quad of
      Node a b c d ->
        fuse (Node (collapse a) (collapse b) (collapse c) (collapse d))
      _ -> quad
-- | Maps a function over the tree like 'fmap' and then re-fuses the result
-- with 'fuseTree'.
tmap :: Eq b => (a -> b) -> QuadTree a -> QuadTree b
tmap fn tree = fuseTree (fmap fn tree)
-- | A rectangle given as @(xl, yt, xr, yb)@. 'inRegion' treats all four
-- bounds as inclusive.
type Region = (Int, Int, Int, Int)
-- | A value paired with the 'Region' it covers.
type Tile a = (a, Region)
-- | Right fold over every element of the tree. The tree is first broken
-- into tiles ('tile'), each tile is expanded into as many copies of its
-- value as its region's area ('expand'), and the resulting list is folded.
foldTree :: (a -> b -> b) -> b -> QuadTree a -> b
foldTree fn z tree = foldr fn z (expand (tile tree))
-- | Flattens a list of tiles into a list of values, repeating each tile's
-- value 'regionArea' times.
expand :: [Tile a] -> [a]
expand = foldr unroll []
  where
    unroll :: Tile a -> [a] -> [a]
    unroll (a, r) rest = replicate (regionArea r) a ++ rest
-- | Lists the tree's leaves as tiles, in the order 'foldTiles' visits them.
tile :: QuadTree a -> [Tile a]
tile tree = foldTiles (:) [] tree
-- | Right fold over the leaves of the tree, presenting each leaf as a
-- 'Tile': its value together with the part of its region that overlaps the
-- tree's 'boundaries', translated back into tree coordinates by
-- subtracting the tree's 'offsets'.
--
-- Sub-quadrants of a 'Node' are combined in the order a, b, c, d, so with
-- @(:)@ and @[]@ the tiles of the first sub-quadrant come first.
foldTiles :: forall a b. (Tile a -> b -> b) -> b -> QuadTree a -> b
foldTiles fn z tree = go (treeRegion tree) (wrappedTree tree) z
  where
    go :: Region -> Quadrant a -> b -> b
    go r (Leaf a) = fn (a, normalizedIntersection)
      where
        -- Undo the offset so the region is expressed in tree coordinates.
        normalizedIntersection =
          ( interXl - xOffset, interYt - yOffset
          , interXr - xOffset, interYb - yOffset )
        (interXl, interYt, interXr, interYb) = treeIntersection r
    go (xl, yt, xr, yb) (Node a b c d) =
      go (xl, yt, midx, midy) a
        . go (midx + 1, yt, xr, midy) b
        . go (xl, midy + 1, midx, yb) c
        . go (midx + 1, midy + 1, xr, yb) d
      where
        -- Last column / row belonging to the left / top half.
        midx = (xr + xl) `div` 2
        midy = (yt + yb) `div` 2
    (xOffset, yOffset) = offsets tree
    -- Clips a quadrant region to the area the tree actually occupies.
    treeIntersection = regionIntersection $ boundaries tree
-- | The region covered by the wrapped quadrant: a square with origin
-- @(0, 0)@ and side length @2 ^ 'treeDepth'@, given with inclusive bounds.
treeRegion :: QuadTree a -> Region
treeRegion tree = (0, 0, limit, limit)
  where
    -- Largest valid coordinate on either axis.
    limit = 2 ^ treeDepth tree - 1
-- | The region the tree occupies inside the wrapped quadrant's square:
-- the tree's first and last valid locations, each shifted by
-- 'offsetIndex'. Bounds are inclusive.
boundaries :: QuadTree a -> Region
boundaries tree = (left, top, right, bottom)
  where
    (left, top) = offsetIndex tree (0, 0)
    (right, bottom) =
      offsetIndex tree (treeLength tree - 1, treeWidth tree - 1)
-- | Component-wise intersection of two regions: the larger of the two
-- left/top bounds and the smaller of the two right/bottom bounds. No check
-- is made that the regions overlap, so the result may have its right bound
-- below its left bound (or bottom below top).
regionIntersection :: Region -> Region -> Region
regionIntersection (left1, top1, right1, bottom1)
                   (left2, top2, right2, bottom2) =
  ( left1 `max` left2
  , top1 `max` top2
  , right1 `min` right2
  , bottom1 `min` bottom2
  )
-- | Number of cells in a region with inclusive bounds.
--
-- A region whose right bound lies before its left bound (or bottom before
-- top) is empty; 'regionIntersection' produces such regions for inputs
-- that do not overlap. Each extent is clamped at zero so that two negative
-- extents cannot multiply into a positive area.
regionArea :: Region -> Int
regionArea (xl, yt, xr, yb) = extent xl xr * extent yt yb
  where
    -- Count of cells between two inclusive bounds, never negative.
    extent lo hi = max 0 (hi + 1 - lo)
-- | 'True' when the location lies within the region, all four bounds
-- inclusive.
inRegion :: Location -> Region -> Bool
inRegion (x, y) (xl, yt, xr, yb) =
  and [xl <= x, x <= xr, yt <= y, y <= yb]
-- | All elements of the tree that satisfy the predicate. The predicate is
-- tested once per tile before the tiles are expanded into elements.
filterTree :: (a -> Bool) -> QuadTree a -> [a]
filterTree fn tree = expand (filterTiles fn (tile tree))
-- | All elements of the tree, ordered by the comparison function. Sorting
-- is done on the tiles before they are expanded into elements.
sortTreeBy :: (a -> a -> Ordering) -> QuadTree a -> [a]
sortTreeBy fn tree = expand (sortTilesBy fn (tile tree))
-- | Keeps the tiles whose value satisfies the predicate, preserving order.
filterTiles :: (a -> Bool) -> [Tile a] -> [Tile a]
filterTiles fn = filter (\(a, _) -> fn a)
-- | Sorts tiles by comparing their values with the given function; the
-- regions play no part in the ordering.
sortTilesBy :: (a -> a -> Ordering) -> [Tile a] -> [Tile a]
sortTilesBy fn tiles = sortBy byValue tiles
  where
    byValue left right = fn (fst left) (fst right)
-- | Builds a tree of the given @(length, width)@ in which every location
-- holds the same value. The tree starts as a single 'Leaf', and its depth
-- is taken from 'smallestDepth'.
--
-- Calls 'error' when either dimension is zero or negative.
makeTree :: (Int, Int) -> a -> QuadTree a
makeTree (x, y) a =
  if x <= 0 || y <= 0
    then error "Invalid dimensions for tree."
    else
      Wrapper
        { wrappedTree = Leaf a
        , treeLength = x
        , treeWidth = y
        , treeDepth = smallestDepth (x, y)
        }
-- | The smallest exponent @d@ such that @2 ^ d@ is at least as large as
-- the bigger of the two dimensions.
smallestDepth :: (Int, Int) -> Int
smallestDepth (x, y) = length tooSmall
  where
    -- Powers of two, starting at @2 ^ 0@, that are still below the target;
    -- their count is the index of the first power that is big enough.
    tooSmall = takeWhile (< max x y) (iterate (* 2) 1)
-- | Renders the tree as text, one character per location, using the given
-- function to turn an element into a character. Locations are read row by
-- row (y outermost, x innermost) and 'breakString' ends each row of
-- 'treeLength' characters with a newline.
showTree :: Eq a => (a -> Char) -> QuadTree a -> String
showTree printer tree = breakString (treeLength tree) string
  where
    string = map printer grid
    -- Every valid location: indices run from 0 to one below the extent.
    grid =
      [ getLocation (x, y) tree
      | y <- [0 .. treeWidth tree - 1]
      , x <- [0 .. treeLength tree - 1]
      ]
-- | Splits a string into chunks of @n@ characters and terminates every
-- chunk (including the last, possibly shorter one) with a newline. The
-- empty string stays empty.
--
-- @n@ must be positive: for a non-positive @n@, 'splitAt' yields an empty
-- chunk and the recursion makes no progress on a non-empty input.
breakString :: Int -> String -> String
breakString _ [] = []
breakString n xs =
  let (row, rest) = splitAt n xs
   in row ++ '\n' : breakString n rest
-- | Writes the 'showTree' rendering of the tree to standard output.
printTree :: Eq a => (a -> Char) -> QuadTree a -> IO ()
printTree printer tree = putStr (showTree printer tree)