{-# LANGUAGE FlexibleContexts #-}
{-# options_ghc -Wno-unused-imports #-}
module Data.RPTree.Batch (
  treeBatch, forestBatch,
  -- * utils
  dataBatch
  ) where

import Control.Monad (replicateM)
import GHC.Word (Word64)

-- containers
import qualified Data.IntMap.Strict as IM (IntMap, fromList, insert, lookup, map, mapWithKey, traverseWithKey, foldlWithKey, foldrWithKey, intersectionWith)
-- splitmix-distributions
import System.Random.SplitMix.Distributions (Gen, sample, GenT, sampleT, stdNormal)
-- vector
import qualified Data.Vector as V (Vector, replicateM, fromList)
import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!), map, freeze, thaw, take, drop, unzip)
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)

import Data.RPTree.Gen (sparse, dense)
import Data.RPTree.Internal (RPTree(..), RPForest, RPT(..), create, createMulti, SVector, Inner(..), Embed(..))

-- | Populate a single random-projection tree from a dataset.
--
-- @maxDepth@ sparse projection vectors are sampled with the given seed
-- (entries drawn from 'stdNormal'), so the result is a pure function of
-- the arguments. The same vectors are used to build the tree with
-- 'create' and are stored alongside it in the returned 'RPTree'.
--
-- Assumptions on the data source:
--
-- * non-empty : contains at least one value
treeBatch :: Inner SVector v =>
             Word64 -- ^ random seed
          -> Int -- ^ max tree depth
          -> Int -- ^ min leaf size
          -> Double -- ^ nonzero density of projection vectors
          -> Int -- ^ dimension of projection vectors
          -> V.Vector (Embed v Double x) -- ^ dataset
          -> RPTree Double () (V.Vector (Embed v Double x))
treeBatch seed maxDepth minLeaf pnz dim src = RPTree rvs t
  where
    -- random projection vectors, 'maxDepth' of them
    rvs :: V.Vector (SVector Double)
    rvs = sample seed $ V.replicateM maxDepth (sparse pnz dim stdNormal)
    -- tree structure built from the dataset and the projection vectors
    t = create maxDepth minLeaf rvs src

-- | Populate a forest of random-projection trees from a dataset.
--
-- For each of the @ntrees@ trees, @maxd@ sparse projection vectors are
-- sampled (entries drawn from 'stdNormal') from a single generator run
-- with the given seed. The per-tree vector sets are keyed by tree index
-- starting at 0, handed to 'createMulti', and finally paired with the
-- resulting trees key by key.
--
-- Assumptions on the data source:
--
-- * non-empty : contains at least one value
forestBatch :: (Inner SVector v) =>
               Word64  -- ^ random seed
            -> Int -- ^ max tree depth, \(l > 1\)
            -> Int -- ^ min leaf size, \(m_{leaf} > 1\)
            -> Int -- ^ number of trees, \(n_t > 1\)
            -> Double -- ^ nonzero density of projection vectors, \(p_{nz} \in (0, 1)\)
            -> Int -- ^ dimension of projection vectors, \(d > 1\)
            -> V.Vector (Embed v Double x) -- ^ dataset
            -> RPForest Double (V.Vector (Embed v Double x))
forestBatch seed maxd minl ntrees pnz dim src =
  IM.intersectionWith RPTree rvss ts
  where
    -- projection vectors of every tree, keyed by tree index (0, 1, ..)
    rvss :: IM.IntMap (V.Vector (SVector Double))
    rvss = sample seed $ do
      rvs <- replicateM ntrees $ V.replicateM maxd (sparse pnz dim stdNormal)
      pure $ IM.fromList $ zip [0 ..] rvs
    -- one tree per set of projection vectors, under the same keys
    ts = createMulti maxd minl rvss src

-- | Batch random data points
--
-- Runs the point generator @n@ times in sequence and collects the
-- results, in generation order, into a vector. A non-positive @n@
-- yields an empty vector without running the generator.
dataBatch :: (Monad m, VG.Vector v a) =>
             Int -- ^ number of points to generate
          -> GenT m a -- ^ random point generator
          -> GenT m (v a)
dataBatch n gg = flip VG.unfoldrM 0 $ \i ->
  -- '>=' rather than '==': with '==' a negative count would never reach
  -- the stopping condition and the unfold would not terminate
  if i >= n
    then pure Nothing
    else do
      x <- gg
      pure $ Just (x, i + 1)