{-# LANGUAGE FlexibleContexts #-}
{-# options_ghc -Wno-unused-imports #-}
module Data.RPTree.Batch (
treeBatch, forestBatch,
dataBatch
) where
import Control.Monad (replicateM)
import GHC.Word (Word64)
import qualified Data.IntMap.Strict as IM (IntMap, fromList, insert, lookup, map, mapWithKey, traverseWithKey, foldlWithKey, foldrWithKey, intersectionWith)
import System.Random.SplitMix.Distributions (Gen, sample, GenT, sampleT, stdNormal)
import qualified Data.Vector as V (Vector, replicateM, fromList)
import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!), map, freeze, thaw, take, drop, unzip)
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)
import Data.RPTree.Gen (sparse, dense)
import Data.RPTree.Internal (RPTree(..), RPForest, RPT(..), create, createMulti, SVector, Inner(..), Embed(..))
-- | Build a single random-projection tree from one batch of data.
--
-- The projection vectors are drawn with the pure 'sample' runner seeded by
-- @seed@, so the result is fully determined by the arguments. The tree
-- itself is built by 'create' and paired with those vectors in 'RPTree'.
treeBatch :: Inner SVector v =>
             Word64
             -- ^ random seed
          -> Int
             -- ^ @maxDepth@: passed to 'create' and also used as the number
             -- of projection vectors sampled (presumably one per tree
             -- level — confirm against 'create')
          -> Int
             -- ^ @minLeaf@: passed to 'create' (presumably the minimum
             -- leaf size — confirm)
          -> Double
             -- ^ @pnz@: first argument of 'sparse' (presumably the
             -- probability of a nonzero entry — confirm)
          -> Int
             -- ^ @dim@: second argument of 'sparse' (presumably the
             -- dimension of each projection vector — confirm)
          -> V.Vector (Embed v Double x)
             -- ^ data batch
          -> RPTree Double () (V.Vector (Embed v Double x))
treeBatch seed maxDepth minLeaf pnz dim src = RPTree rvs t
  where
    -- Sparse projection vectors with standard-normal nonzero entries.
    rvs :: V.Vector (SVector Double)
    rvs = sample seed $ V.replicateM maxDepth (sparse pnz dim stdNormal)
    t = create maxDepth minLeaf rvs src
-- | Build a forest of random-projection trees over one batch of data.
--
-- All projection vectors for all trees come from a single pure 'sample'
-- call seeded with @seed@, so the result is fully determined by the
-- arguments. The per-tree vector sets are keyed @0 .. ntrees - 1@, the trees
-- are built together by 'createMulti', and each tree is then paired with its
-- own vectors. Only keys present in both maps survive 'IM.intersectionWith'
-- (presumably 'createMulti' returns the same key set — confirm).
forestBatch :: (Inner SVector v) =>
               Word64
               -- ^ random seed
            -> Int
               -- ^ @maxd@: passed to 'createMulti' and also used as the
               -- number of projection vectors sampled per tree
            -> Int
               -- ^ @minl@: passed to 'createMulti' (presumably the minimum
               -- leaf size — confirm)
            -> Int
               -- ^ @ntrees@: number of projection-vector sets (trees)
            -> Double
               -- ^ @pnz@: first argument of 'sparse' (presumably the
               -- probability of a nonzero entry — confirm)
            -> Int
               -- ^ @dim@: second argument of 'sparse' (presumably the
               -- dimension of each projection vector — confirm)
            -> V.Vector (Embed v Double x)
               -- ^ data batch
            -> RPForest Double (V.Vector (Embed v Double x))
forestBatch seed maxd minl ntrees pnz dim src =
  IM.intersectionWith RPTree rvss ts
  where
    -- One set of sparse standard-normal projection vectors per tree,
    -- indexed by tree number.
    rvss :: IM.IntMap (V.Vector (SVector Double))
    rvss = sample seed $ do
      rvs <- replicateM ntrees $ V.replicateM maxd (sparse pnz dim stdNormal)
      pure $ IM.fromList (zip [0 ..] rvs)
    ts = createMulti maxd minl rvss src
-- | Run the generator @gg@ @n@ times in sequence and collect the results
-- into a vector, in the order they were drawn.
--
-- 'VG.replicateM' treats a non-positive count as zero, so @n <= 0@ yields an
-- empty vector. A counter loop that stops only when it reaches @n@ exactly
-- would never terminate for negative @n@.
dataBatch :: (Monad m, VG.Vector v a) =>
             Int
             -- ^ number of samples to draw
          -> GenT m a
             -- ^ generator for a single sample
          -> GenT m (v a)
dataBatch n gg = VG.replicateM n gg