{- |
Module      : Data.BinaryIndexedTree
Description : Binary Indexed Trees (a.k.a. Fenwick Trees)
Copyright   : (c) 2012 Maxwell Sayles.
License     : LGPL

Maintainer  : maxwellsayles@gmail.com
Stability   : stable
Portability : portable

Implements persistent binary indexed trees (or Fenwick Trees) in
/O(log n)/ for increment and lookup. Creating a tree on /k/ bits takes
/O(k)/ (that is, /O(log n)/) time and space, because the initial all-zero
tree shares its identical subtrees.

Index /i/ in the tree represents the sum of all values of indexes /j<=i/
for some array. The representable indexes are /1 .. 2^k - 1/.

The idea is that for /k/ bits, we parse the index /i/ from msb to lsb and
move into the right\/left subtree for a 0\/1 bit. For a read, we
accumulate the values in the tree where the binary representation of the
index contains a 1. (The technique is similar to binary exponentiation.)
For an increment, we should increment parent nodes in the tree whose
corresponding binary index representation is />=/ the index /i/.

/Note: I was unable to find the algorithm used here in the literature./
-}
module Data.BinaryIndexedTree
  ( BinaryIndexedTree
  , new
  , (!)
  , increment
  ) where

import Data.Bits

{-| A Binary indexed tree over the indexes @1 .. 2^k - 1@, where @k@ is the
    number of bits passed to 'new'. -}
data BinaryIndexedTree a = BinaryIndexedTree Int (Tree a)

-- A node at depth @d@ (0 at the root) is keyed on bit @k - 1 - d@ of an
-- index: a 1 bit descends into the left subtree, a 0 bit into the right.
data Tree a = Empty | Node a (Tree a) (Tree a)

{-| Construct a binary indexed tree on @k@ bits in which every value is
    zero. Takes /O(k)/ time and space: all subtrees at the same depth are
    identical, so a single copy is shared (safe, since the tree is never
    mutated). A non-positive @k@ yields an empty tree. -}
new :: Num a => Int -> BinaryIndexedTree a
new k = BinaryIndexedTree k (zeros k)
  where
    zeros depth
      | depth <= 0 = Empty
      | otherwise  = let sub = zeros (depth - 1) in Node 0 sub sub

{-| Lookup the sum of all values from index 1 to index @i@.
    Takes /O(k)/, i.e. /O(log n)/.

    Indexes @<= 0@ give 0 (an empty sum). Indexes beyond the last
    representable one (@2^k - 1@) give the sum of the whole tree. -}
(!) :: Num a => BinaryIndexedTree a -> Int -> a
(!) (BinaryIndexedTree k root) i
  | k <= 0 || i <= 0 = 0
  | otherwise        = go root (k - 1) 0
  where
    -- Above the top of the range every prefix sum is the total, which is
    -- what the all-ones path accumulates. Only forced when k > 0, so the
    -- shift amount is never negative.
    saturated = i `shiftR` k /= 0
    isSet j   = saturated || testBit i j

    go Empty _ acc = acc
    go (Node x l r) j acc
      | isSet j   = let acc' = acc + x in acc' `seq` go l (j - 1) acc'
      | otherwise = go r (j - 1) acc

{-| Increment the value at index @i@ by amount @x@. Takes /O(k)/, i.e.
    /O(log n)/, and shares every untouched subtree with the input tree.

    Only indexes @1 .. 2^k - 1@ can be stored. An index outside that range
    returns the tree unchanged (in the spirit of "Data.Sequence"'s
    @adjust@), so out-of-range updates can neither corrupt other prefix
    sums nor leave a latent pattern-match failure inside the tree. -}
increment :: Num a => Int -> a -> BinaryIndexedTree a -> BinaryIndexedTree a
increment i x tree@(BinaryIndexedTree k root)
  | k <= 0 || i <= 0 || i `shiftR` k /= 0 = tree
  | otherwise = BinaryIndexedTree k (go root (k - 1) 0)
  where
    -- @prefix@ holds the bits of @i@ already consumed (those above bit @j@).
    go Empty _ _ = Empty  -- unreachable for in-range @i@; keeps 'go' total
    go (Node y l r) j prefix
      | testBit i j =
          let prefix' = setBit prefix j
          in if prefix' == i
               -- No lower bits of @i@ remain: this node's value is read by
               -- exactly the indexes >= i sharing this prefix, so update
               -- it and stop.
               then y' `seq` Node y' l r
               else prefix' `seq` Node y (go l (j - 1) prefix') r
      -- A 0 bit: every index that reads this node (same prefix, bit j set)
      -- is > i, so it must see @x@; smaller indexes continue to the right.
      | otherwise = y' `seq` Node y' l (go r (j - 1) prefix)
      where
        y' = x + y