{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language DeriveGeneric #-}
{-# language LambdaCase #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiWayIf #-}
{-# options_ghc -Wno-unused-imports #-}
{-# options_ghc -Wno-unused-top-binds #-}
module Data.RPTree (
forest
, knn
, serialiseRPForest
, deserialiseRPForest
, recallWith
, levels, points, candidates
, Embed(..)
, RPTree, RPForest
, SVector, fromListSv, fromVectorSv
, DVector, fromListDv, fromVectorDv
, Inner(..), Scale(..)
, innerSS, innerSD, innerDD
, metricSSL2, metricSDL2
, scaleS, scaleD
, draw
, writeCsv
, randSeed, BenchConfig(..), normalSparse2
, liftC
, dataSource
, datS, datD
, sparse, dense
, normal2
) where
import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Foldable (Foldable(..), maximumBy, minimumBy)
import Data.Functor.Identity (Identity(..))
import Data.List (partition, sortBy)
import Data.Monoid (Sum(..))
import Data.Ord (comparing)
import GHC.Generics (Generic)
import GHC.Word (Word64)
import Data.Sequence (Seq, (|>))
import qualified Data.Map as M (Map, fromList, toList, foldrWithKey, insert, insertWith)
import qualified Data.Set as S (Set, fromList, intersection, insert)
import Control.DeepSeq (NFData(..))
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)
import Control.Monad.Trans.Class (MonadTrans(..))
import qualified Data.Vector as V (Vector, replicateM, fromList)
import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!), map, freeze, thaw, take, drop, unzip)
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)
import qualified Data.Vector.Storable as VS (Vector)
import qualified Data.Vector.Algorithms.Merge as V (sortBy)
import Data.RPTree.Conduit (tree, forest, ForestParams, defaultParams, dataSource, liftC)
import Data.RPTree.Gen (sparse, dense, normal2, normalSparse2)
import Data.RPTree.Internal (RPTree(..), RPForest, RPT(..), Embed(..), levels, points, Inner(..), Scale(..), scaleS, scaleD, (/.), innerDD, innerSD, innerSS, metricSSL2, metricSDL2, SVector(..), fromListSv, fromVectorSv, DVector(..), fromListDv, fromVectorDv, partitionAtMedian, Margin, getMargin, sortByVG, serialiseRPForest, deserialiseRPForest)
import Data.RPTree.Internal.Testing (BenchConfig(..), randSeed, datS, datD)
import Data.RPTree.Draw (draw, writeCsv)
-- | Approximate @k@-nearest-neighbour query against a random projection forest.
--
-- Every tree of the forest contributes its 'candidates' for the query point;
-- the per-tree results are concatenated with 'fold', each candidate is paired
-- with its distance to the query, the pairs are ordered by distance with
-- 'sortByVG', and only then is the result truncated to @k@ entries.
-- Truncating before ranking would return the first @k@ candidates in
-- whatever order the trees produced them rather than the closest ones.
--
-- NOTE(review): assumes 'sortByVG' orders ascending on the projected key —
-- confirm against its definition in "Data.RPTree.Internal".
--
-- NOTE(review): no de-duplication is performed here, so a point returned by
-- more than one tree appears once per tree in the result.
knn :: (Ord p, Inner SVector v, VU.Unbox d, Real d) =>
       (u d -> v d -> p) -- ^ distance function
    -> Int -- ^ @k@, the maximum number of neighbours to return
    -> RPForest d (V.Vector (Embed u d x)) -- ^ random projection forest
    -> v d -- ^ query point
    -> V.Vector (p, Embed u d x) -- ^ at most @k@ (distance, point) pairs, ranked by distance
knn distf k tts q = VG.take k (sortByVG fst scored)
  where
    -- union (with repetitions) of the candidates proposed by every tree
    pool = fold ((`candidates` q) <$> tts)
    -- pair each candidate with its distance to the query
    scored = VG.map (\xe -> (eEmbed xe `distf` q, xe)) pool
-- | Average recall-at-@k@ over all trees of a forest.
--
-- 'recallWith1' is evaluated on every tree with the same distance function,
-- @k@ and query point, and the per-tree values are averaged (sum divided by
-- the number of trees).
--
-- An empty forest yields @0@ instead of dividing by zero.
recallWith :: (Inner SVector v, VU.Unbox d, Fractional a1, Ord d, Ord a2, Ord x, Ord (u d), Num d) =>
              (u d -> v d -> a2) -- ^ distance function
           -> RPForest d (V.Vector (Embed u d x)) -- ^ random projection forest
           -> Int -- ^ @k@ : number of nearest neighbours to consider
           -> v d -- ^ query point
           -> a1 -- ^ mean of the per-tree recall values
recallWith distf tt k q
  | null tt = 0
  | otherwise = sum perTree / fromIntegral (length tt)
  where
    -- recall of each individual tree, keyed like the forest itself
    perTree = (\t -> recallWith1 distf t k q) <$> tt
-- | Recall-at-@k@ of a single random projection tree.
--
-- The reference answer is obtained by brute force: every point stored in the
-- tree ('points') is paired with its distance to the query, the pairs are
-- sorted by distance and the first @k@ points are kept as a set. The
-- retrieved answer is the set of 'candidates' the tree proposes for the
-- query. The result is the size of the intersection of the two sets divided
-- by @k@.
--
-- A non-positive @k@ yields @0@ instead of dividing by zero (or by a
-- negative count).
--
-- NOTE(review): the denominator is @k@ itself, so a tree holding fewer than
-- @k@ distinct points can never reach a recall of 1 — confirm this is the
-- intended definition.
recallWith1 :: (Inner SVector v, Ord d, VU.Unbox d, Fractional p, Ord a, Ord x, Ord (u d), Num d) =>
               (u d -> v d -> a) -- ^ distance function
            -> RPTree d (V.Vector (Embed u d x)) -- ^ random projection tree
            -> Int -- ^ @k@ : number of nearest neighbours to consider
            -> v d -- ^ query point
            -> p -- ^ fraction of the brute-force @k@ nearest points found among the candidates
recallWith1 distf tt k q
  | k <= 0 = 0
  | otherwise = fromIntegral (length hits) / fromIntegral k
  where
    -- every stored point with its distance to the query, closest first
    ranked = sortBy (comparing snd)
      [ (x, eEmbed x `distf` q) | x <- toList (points tt) ]
    -- brute-force reference: the k closest points
    trueNearest = S.fromList (map fst (take k ranked))
    -- what the tree actually proposes for this query
    retrieved = set (candidates tt q)
    hits = retrieved `S.intersection` trueNearest
-- | Collect the elements of any 'Foldable' container into a 'S.Set',
-- discarding duplicates.
--
-- The fold is strict ('foldl'') so the set is built incrementally instead of
-- accumulating a chain of suspended insertions.
set :: (Foldable t, Ord a) => t a -> S.Set a
set = foldl' (flip S.insert) mempty
{-# SCC candidates #-}
-- | Retrieve the contents of the leaves a query point is routed to.
--
-- The tree is descended from the root, with the level counter starting at 0
-- and incremented at each 'Bin' node. At every 'Bin' node the query is
-- projected (via 'inner') onto the vector stored at the current level index,
-- and the projection is compared against the node threshold and against the
-- two bounds returned by 'getMargin':
--
-- * projection below the threshold: follow the left subtree; if the
--   projection is farther from the first margin bound than from the second,
--   the right subtree is explored as well;
--
-- * projection above the threshold: follow the right subtree; if the
--   projection is closer to the first margin bound than to the second, the
--   left subtree is explored as well;
--
-- * projection equal to the threshold: follow the right subtree only.
--
-- When both subtrees are explored their results are combined with '<>',
-- left first. A 'Tip' returns its payload unchanged.
--
-- NOTE(review): indexing with 'VG.!' assumes there is one projection vector
-- per tree level — confirm against the tree construction code.
candidates :: (Inner SVector v, VU.Unbox d, Ord d, Num d, Semigroup xs) =>
              RPTree d xs -- ^ random projection tree
           -> v d -- ^ query point
           -> xs -- ^ combined payload of the visited leaves
candidates (RPTree rvs tt) x = go 0 tt
  where
    go _ (Tip xs) = xs
    go ixLev (Bin thr margin ltree rtree)
      | proj < thr = if dl > dr then both else lefts
      | proj > thr && dl < dr = both
      | otherwise = rights
      where
        (mglo, mghi) = getMargin margin
        -- projection of the query onto this level's vector
        proj = (rvs VG.! ixLev) `inner` x
        -- absolute distance of the projection from each margin bound
        dl = abs (mglo - proj)
        dr = abs (mghi - proj)
        i' = succ ixLev
        lefts = go i' ltree
        rights = go i' rtree
        both = lefts <> rights