{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ParallelListComp #-}
-- | This module presents a basic version of the mean-shift algorithm for
-- feature-space analysis. Mean shifting is an iterative process whose
-- fixed points correspond to the modes of a kernel density estimate
-- performed with the same bandwidth (the first parameter). This can be
-- used, for example, to partition the data by determining which fixed
-- point each of the samples converges to.
--
-- Usage example:
--
-- > fixedPointE 0.001 (meanShift 0.1 points) (V.fromList [1,1,1])
--
-- More examples can be found in the Examples directory of this package.
module Math.Meanshift
  ( -- * Basic Meanshift routines
    meanShift
  , meanShiftWindow
    -- * Auxiliary functions for iterating the meanshift steps.
  , fixedPoint
  , fixedPointE
    -- * Types
  , Window
  , Support
    -- * (multidimensional) Kernel Density Estimates
  , kde
  ) where

import qualified Data.Vector.Unboxed as V
import Data.List hiding (sum)
import qualified Data.List as L
import Prelude hiding (sum)

-- | Squared Euclidean norm: the sum of the squared components. No square
-- root is taken, because the kernels below are applied to squared
-- (bandwidth-scaled) distances.
norm² :: Vector -> Double
norm² = V.sum . V.map (** 2)

-- | Profile @k(x) = exp (-x/2)@ of the normal kernel, and its derivative
-- @k'(x) = -exp (-x/2) / 2@. Both take a squared, bandwidth-scaled
-- distance as their argument.
normalKernel, normalKernel' :: Double -> Double
normalKernel x = exp (-0.5 * x)
-- The mean-shift steps normalise these values by their sum, so only the
-- shape matters there; the constant is nevertheless the true derivative.
normalKernel' x = -0.5 * exp (-0.5 * x)

-- | Multivariate kernel density estimate of the given points, using a
-- normal kernel with bandwidth @h@:
--
-- > f(x) = 1 / (n * (2*pi)^(d/2) * h^d) * sum_i exp (-|(x - x_i)/h|^2 / 2)
--
-- Here @n@ is the number of points and @d@ is their dimension, taken from
-- the first point. An empty point set carries no mass and yields @0@.
kde :: Double -> [Vector] -> (Vector -> Double)
kde _ [] _ = 0
kde h vs@(v0 : _) x = scale * L.sum (map contribution vs)
  where
    n = fi (length vs)
    d = fi (V.length v0)
    -- Normalising constant of the d-dimensional normal kernel.
    scale = 1 / (n * ((2 * π) ** (d / 2)) * (h ** d))
    contribution xi = normalKernel (norm² ((x ^- xi) ./ h))

-- | One mean-shift step for a point in a dataset. This is efficient only
-- when we cannot make an a priori estimate of which points contribute to
-- the mean shift at the given location.
--
-- The result is the kernel-weighted mean of @vs@ around @x@, i.e. the next
-- mean-shift iterate (not the shift vector itself). If no point carries
-- any weight, @x@ is returned unchanged instead of crashing or producing
-- @NaN@. This happens when @vs@ is empty, or when every point is so far
-- away that the kernel underflows to zero.
meanShift :: Double -> [Vector] -> (Vector -> Vector)
meanShift _ [] x = x
meanShift h vs@(v0 : _) x = weightedMean (V.length v0) h x vs

-- | Kernel-weighted mean of @pts@ around @x@ in @d@ dimensions, with
-- weights from 'normalKernel'' of the squared bandwidth-scaled distance.
-- This helper is shared by 'meanShift' and 'meanShiftWindow'. When the
-- total weight is zero the mean is undefined, so @x@ is returned as is.
weightedMean :: Int -> Double -> Vector -> [Vector] -> Vector
weightedMean d h x pts
  | total == 0 = x
  | otherwise = sumW d pts weights (1 / total)
  where
    weights = V.fromList (map (normalKernel' . distPerH x) pts)
    total = V.sum weights
    -- Squared Euclidean distance between two points, in units of @h@.
    distPerH :: Vector -> Vector -> Double
    distPerH !a !b = V.sum (V.zipWith (\u w -> ((u - w) / h) ^ (2 :: Int)) a b)

-- | A window returns the data points relevant to a support region.
type Window = Support -> [Vector]

-- | A support region: a centre point and a size. 'meanShiftWindow' passes
-- @2*h@ as the size, presumably a radius (the window implementation
-- decides).
type Support = (Vector, Double)

-- | Mean shift with a windowing function. Performing mean shift is more
-- efficient if we can index and calculate only those points that are in
-- the support of our kernel.
--
-- @d@ is the dimension of the points. If the window returns no points,
-- the current point is returned unchanged.
{-# INLINEABLE meanShiftWindow #-}
meanShiftWindow :: Int -> Window -> Double -> (Vector -> Vector)
meanShiftWindow d window h x = weightedMean d h x (window (x, h * 2))
-- TODO: Think this through: a 2*h support truncates the Gaussian kernel,
-- whose relative weight at that distance is still exp (-2).

-- | Find a path to the fixed point of a function. The path is the list of
-- iterates @x, f x, f (f x), ...@. It ends once an application of @f@
-- returns its argument unchanged, so the fixed point appears twice at the
-- end of the list.
{-# INLINEABLE fixedPoint #-}
fixedPoint :: Eq a => (a -> a) -> a -> [a]
fixedPoint f x = x : rest
  where
    x' = f x
    rest
      | x' /= x = fixedPoint f x'
      | otherwise = [x']

-- | Like 'fixedPoint', but stops as soon as one step moves the point by at
-- most @e@ in the L1 norm (sum of absolute component differences). The
-- final, sub-threshold iterate is included. If the steps never shrink to
-- @e@, the resulting list is infinite.
fixedPointE :: Double -> (Vector -> Vector) -> Vector -> [Vector]
fixedPointE e f x = x : rest
  where
    x' = f x
    stepSize = V.sum (V.map abs (x' ^- x))
    rest
      | stepSize > e = fixedPointE e f x'
      | otherwise = [x']

-- * Auxiliary functions, and shorthands

-- | Points and vectors are unboxed vectors of 'Double'.
type Vector = V.Vector Double

-- | Build a 'Vector' from a list.
v :: [Double] -> Vector
v = V.fromList

-- Element-wise vector arithmetic. 'V.zipWith' truncates to the shorter
-- operand.
{-# INLINE (^+) #-}
{-# INLINE (^-) #-}
{-# INLINE (^/) #-}
(^+), (^-), (^/) :: Vector -> Vector -> Vector
(^-) = V.zipWith (-)
(^+) = V.zipWith (+)
(^/) = V.zipWith (/)

-- Vector-scalar arithmetic, applied to every component.
(.+), (./), (.*) :: Vector -> Double -> Vector
a .+ b = V.map (+ b) a
a ./ b = V.map (/ b) a
a .* b = V.map (* b) a

infixl 7 ^/
infixl 6 ^+, ^-

-- | Wrap a value in a singleton list.
box :: x -> [x]
box x = [x]

-- | Build a two-element list.
box2 :: x -> x -> [x]
box2 x y = [x, y]

-- | A one-component vector.
sv :: Double -> Vector
sv = V.singleton

-- | First component of a vector (partial: fails on an empty vector).
fs :: Vector -> Double
fs = V.head

-- | Component-wise sum of the given vectors, producing @d@ components.
-- Indexing is unchecked, so every vector must have at least @d@
-- components.
{-# INLINEABLE sumD #-}
sumD :: Int -> [Vector] -> Vector
sumD d xs = V.generate d (\i -> L.sum (map (`V.unsafeIndex` i) xs))
-- | Scaled, weighted component-wise sum of vectors. For each @i < d@,
-- component @i@ of the result is the sum over @j@ of
-- @n * (es_j V.! i) * (ws V.! j)@, where @es_j@ is the @j@-th vector of
-- @es@.
--
-- @es@ and @ws@ are walked in lockstep, so @ws@ must have at least as many
-- entries as @es@. Indexing is bounds-checked.
{-# INLINEABLE sumW #-}
sumW :: Int -> [Vector] -> Vector -> Double -> Vector
sumW d es ws n = V.generate d (\i -> go i 0 es 0)
  where
    -- Strict index and accumulator: avoid building a chain of (+) thunks
    -- as long as @es@ for every component.
    go :: Int -> Int -> [Vector] -> Double -> Double
    go !i !j (x : xs) !acc = go i (j + 1) xs (acc + n * (x V.! i) * (ws V.! j))
    go _ _ [] !acc = acc

-- | Shorthand for 'pi'.
π :: Double
π = pi

-- | Convert an 'Int' to a 'Double'.
fi :: Int -> Double
fi = fromIntegral