-- |
-- Module      : ConClusion.Numeric.Statistics
-- Description : Statistical Functions
-- Copyright   : Phillip Seeber, 2021
-- License     : AGPL-3
-- Maintainer  : phillip.seeber@googlemail.com
-- Stability   : experimental
-- Portability : POSIX, Windows
module ConClusion.Numeric.Statistics
  ( -- * PCA
    PCA (..),
    pca,

    -- * Variance
    normalise,
    meanDeviation,
    covariance,

    -- * Distance Metrics
    DistFn,
    lpNorm,
    manhattan,
    euclidean,
    mahalanobis,

    -- * Cluster Algorithms
    Clusters,

    -- ** DBScan
    DistanceInvalidException (..),
    dbscan,

    -- ** Hierarchical Cluster Analysis
    Dendrogram,
    JoinStrat (..),
    hca,
    cutDendroAt,
  )
where

import ConClusion.Numeric.Data hiding (normalise)
import Data.Aeson hiding (Array)
import Data.Complex
import qualified Data.IntSet as IntSet
import Data.Massiv.Array as Massiv
import Data.Massiv.Array.Unsafe as Massiv
import qualified Data.PSQueue as PQ
import qualified Numeric.LinearAlgebra as LA
import RIO hiding (Vector)
import System.IO.Unsafe (unsafePerformIO)

----------------------------------------------------------------------------------------------------
-- Others/Helpers

-- | Solves the eigenvalue problem of a square matrix with hmatrix's 'LA.eig' and converts the
-- resulting eigenvalues and eigenvectors back into massiv arrays.
--
-- Throws an 'IndexException' if the input matrix is not square.
{-# SCC eig #-}
eig ::
  ( Mutable r1 Ix1 (Complex Double),
    Mutable r2 Ix1 (Complex Double),
    LA.Field e,
    Manifest r3 Ix1 e,
    Resize r3 Ix2,
    Load r3 Ix2 e,
    MonadThrow m
  ) =>
  Matrix r3 e ->
  m (Vector r1 (Complex Double), Matrix r2 (Complex Double))
eig covM
  | nRows == nCols = pure . bimap vecH2M matH2M . LA.eig $ matM2H covM
  | otherwise = throwM $ IndexException "eigenvalue problems can only be solved for square matrix"
  where
    Sz (nRows :. nCols) = size covM

-- | Sort eigenvalues and eigenvectors by eigenvalue in descending order (largest eigenvalue
-- first). Eigenvalues are compared by their signed value through 'Ord', not by absolute value.
-- The eigenvectors are the columns of the input matrix and are permuted together with their
-- eigenvalues.
--
-- Throws an 'IndexException' if the eigenvector matrix is not square, or if its number of columns
-- differs from the number of eigenvalues.
{-# SCC eigSort #-}
eigSort ::
  ( Load r2 Ix2 e,
    MonadThrow m,
    Source r1 Ix1 e,
    Source r2 Ix2 e,
    Mutable r1 Ix1 e,
    Mutable r2 Ix2 e,
    Unbox e,
    Ord e
  ) =>
  (Vector r1 e, Matrix r2 e) ->
  m (Vector r1 e, Matrix r2 e)
eigSort (vec, mat)
  | nRows /= nCols = throwM $ IndexException "matrix of the eigenvectors is not a square matrix"
  | nCols /= nVals = throwM $ IndexException "different number of eigenvalues and eigenvectors"
  | otherwise = pure (compute valsDesc, compute vecsDesc)
  where
    Sz (nRows :. nCols) = size mat
    Sz nVals = size vec
    -- Tag every eigenvalue with its original position, then sort the pairs ascendingly.
    sortedPairs = quicksort . compute @U $ Massiv.zip vec (makeArrayLinear @D Seq (Sz nVals) id)
    valsAsc = compute @U (Massiv.map fst sortedPairs)
    perm = compute @U (Massiv.map snd sortedPairs)
    -- Column j of the reordered matrix is the eigenvector that belonged to the j-th smallest value.
    vecsAsc = backpermute' (Sz (nRows :. nCols)) (\(i :. j) -> i :. (perm ! j)) mat
    -- Reversing the innermost dimension turns ascending into descending order (for the matrix
    -- this reverses the columns).
    valsDesc = reverse' (Dim 1) valsAsc
    vecsDesc = reverse' (Dim 1) vecsAsc

----------------------------------------------------------------------------------------------------
-- Principal Component Analysis

-- | Result of a principal component analysis as produced by 'pca': the input, the intermediate
-- and transformed matrices, quality measures, and the eigen decomposition of the covariance
-- matrix.
data PCA = PCA
  { x :: Matrix U Double
  -- ^ Original feature matrix.
  , x' :: Matrix U Double
  -- ^ Feature matrix in mean deviation form (in 'pca' additionally row-scaled by 'normalise').
  , y :: Matrix U Double
  -- ^ Transformed data.
  , a :: Matrix U Double
  -- ^ Transformation matrix to transform the feature matrix into the PCA result matrix.
  , mse :: Double
  -- ^ Mean squared error introduced by the PCA.
  , remaining :: Double
  -- ^ Percentage of the behaviour captured in the remaining dimensions.
  , allEigenValues :: Vector U Double
  -- ^ All eigenvalues from the diagonalisation of the covariance matrix.
  , pcaEigenValues :: Vector U Double
  -- ^ Eigenvalues that were kept for the PCA.
  , allEigenVecs :: Matrix U Double
  -- ^ All eigenvectors from the diagonalisation of the covariance matrix.
  , pcaEigenVecs :: Matrix U Double
  -- ^ Eigenvectors that were kept for the PCA.
  }

-- | Transform the input values with a transformation matrix \(\mathbf{A}\), where \(\mathbf{A}\) is
-- constructed from the eigenvectors associated to the largest eigenvalues. The first columns of the
-- eigenvector matrix (one per kept dimension) are extracted and transposed to form \(\mathbf{A}\).
-- The result is \(\mathbf{Y} = \mathbf{A} \mathbf{X}\).
--
-- Throws an 'IndexException' in any of these cases:
--
-- * the eigenvector matrix is not square;
-- * the requested number of dimensions is outside \([1, n]\), where \(n\) is the number of
--   eigenvectors;
-- * the eigenvector matrix and the feature matrix have different numbers of rows.
--
-- Keeping all \(n\) dimensions is allowed.
{-# SCC transformToPCABasis #-}
transformToPCABasis ::
  ( Source (R r) Ix2 e,
    Extract r Ix2 e,
    Mutable r Ix2 e,
    Numeric r e,
    MonadThrow m
  ) =>
  -- | Number of dimensions to keep from PCA.
  Int ->
  -- | Matrix of the eigenvectors, sorted descendingly by eigenvalues, where the eigenvectors are
  -- the columns of the matrix.
  Matrix r e ->
  -- | Feature matrix in mean deviation form.
  Matrix r e ->
  -- | Input data transformed by PCA to lower dimensions, and the transformation matrix
  -- \(\mathbf{A}\).
  m (Matrix r e, Matrix r e)
transformToPCABasis nDim eigenVecMat featureMat
  | mE /= nE = throwM $ IndexException "the matrix of the eigenvectors must be a quadratic matrix"
  | nDim <= 0 = throwM $ IndexException "the number of dimensions of the PCA is smaller than or equal to zero"
  -- Only strictly more than nE dimensions is invalid: extracting all nE columns is still in bounds.
  | nDim > nE = throwM $ IndexException "more than the possible amount of dimensions has been selected"
  | mE /= mF = throwM $ IndexException "eigenvector matrix and feature matrix have mismatching dimensions"
  | otherwise = do
    -- The leading nDim eigenvectors (columns), transposed so that each becomes a row of A.
    matA <- compute . transpose <$> extractM (0 :. 0) (Sz $ mE :. nDim) eigenVecMat
    pcaData <- matA .><. featureMat
    return (pcaData, matA)
  where
    Sz (mE :. nE) = size eigenVecMat
    Sz (mF :. _nF) = size featureMat

-- | Performs a PCA on the feature matrix \(\mathbf{X}\) by solving the eigenproblem of its
-- covariance matrix. Only the raw feature matrix has to be supplied; this function handles the
-- rest:
--
-- * the conversion to mean deviation form, followed by row-wise scaling with 'normalise';
-- * the construction of the covariance matrix;
-- * the eigenvalue problem.
--
-- Failures of the underlying steps (e.g. an invalid number of target dimensions or mismatching
-- matrix sizes) are raised through 'MonadThrow'.
{-# SCC pca #-}
pca ::
  ( Numeric r Double,
    Mutable r Ix2 Double,
    Manifest r Ix1 Double,
    Source (R r) Ix2 Double,
    Extract r Ix2 Double,
    MonadThrow m
  ) =>
  -- | Dimensionality after the PCA transformation.
  Int ->
  -- | \(m \times n\) feature matrix \(\mathbf{X}\), with \(m\) different measurements (rows) in
  -- \(n\) different trials (columns).
  Matrix r Double ->
  m PCA
pca dim featMat = do
  -- Centre (and row-scale) the features, then build their covariance matrix.
  let centred = normalise (meanDeviation featMat)
      covMat = covariance centred

  -- Diagonalise the covariance matrix, keep only the real parts and order the eigenpairs by
  -- descending eigenvalue.
  (valsC :: Vector U (Complex Double), vecsC :: Matrix U (Complex Double)) <- eig covMat
  let valsR = compute @U (Massiv.map realPart valsC)
      vecsR = compute (Massiv.map realPart vecsC)
  (sortedVals, sortedVecs) <- eigSort (valsR, vecsR)

  -- Project the centred features onto the leading eigenvectors: Y and the transformation A.
  (projected, transMat) <- transformToPCABasis dim sortedVecs centred

  -- Map Y back with A^T and compare with the centred features to obtain the squared error,
  -- summed over all entries and divided by the number of columns (trials).
  reconstructed <- compute (transpose transMat) .><. projected
  residual <- centred .-. reconstructed
  let mseVal = Massiv.sum (Massiv.map (** 2) residual) / fromIntegral nCols

  -- Keep only the eigenpairs spanning the reduced space.
  keptVals <- extractM 0 (Sz dim) sortedVals
  keptVecs <- extractM (0 :. 0) (Sz (nRows :. dim)) sortedVecs

  -- Share (in percent) of the eigenvalue sum retained by the kept eigenvalues.
  let remainingPct = 100 * (Massiv.sum keptVals / Massiv.sum sortedVals)

  return
    PCA
      { x = compute featMat,
        x' = compute centred,
        y = compute projected,
        a = compute transMat,
        mse = mseVal,
        remaining = remainingPct,
        allEigenValues = sortedVals,
        pcaEigenValues = compute keptVals,
        allEigenVecs = compute sortedVecs,
        pcaEigenVecs = compute keptVecs
      }
  where
    Sz (nRows :. nCols) = size featMat

----------------------------------------------------------------------------------------------------
-- Variance

-- | Brings the feature matrix into mean deviation form. For every row, the mean of that row's
-- entries (taken over all columns) is subtracted from each entry of the row.
{-# SCC meanDeviation #-}
meanDeviation ::
  ( Source r Ix2 e,
    Fractional e,
    Unbox e,
    Numeric r e,
    Mutable r Ix2 e
  ) =>
  Matrix r e ->
  Matrix r e
meanDeviation mat = mat !-! compute meanMat
  where
    Sz (_ :. nCols) = Massiv.size mat
    -- Sum along the inner (column) dimension of every row, scaled by 1/n.
    rowMean = Massiv.foldlInner (+) 0 mat .* (1 / fromIntegral nCols)
    -- Repeat each row mean across all columns so it can be subtracted elementwise.
    meanMat = expandInner (Sz nCols) const (compute @U rowMean)

-- | Obtains the covariance matrix \(\mathbf{C_X}\) from the feature matrix \(\mathbf{X}\).
-- \[
--   \mathbf{C_X} \equiv \frac{1}{n - 1} \mathbf{X} \mathbf{X}^T
-- \]
-- where \(n\) is the number of columns in the matrix.
--
-- The feature matrix should be in mean deviation form, see 'meanDeviation'. For a single column
-- (\(n = 1\)) the prefactor is a division by zero.
{-# SCC covariance #-}
covariance :: (Numeric r e, Mutable r Ix2 e, Fractional e) => Matrix r e -> Matrix r e
covariance x = prefactor *. gram
  where
    Sz (_ :. n) = size x
    prefactor = 1 / (fromIntegral n - 1)
    -- X X^T; the transpose is materialised because the product needs a manifest matrix.
    gram = x !><! compute (transpose x)

-- | Normalise each value so that the maximum absolute value in each row becomes one. This divides
-- every row by its largest absolute entry.
--
-- A row whose maximum absolute value is exactly zero (an all-zero row) is returned unchanged.
-- Without this, it would become 0 * (1 / 0) = NaN, which would then poison any covariance or
-- eigenvalue computation that follows.
{-# SCC normalise #-}
normalise ::
  ( Ord e,
    Unbox e,
    Numeric r e,
    Fractional e,
    Source r Ix2 e,
    Mutable r Ix2 e
  ) =>
  Array r Ix2 e ->
  Array r Ix2 e
normalise mat = invMax !*! mat
  where
    Sz (_ :. n) = size mat
    -- Largest absolute value of every row.
    maxPerRow = compute @U . foldlInner max 0 . Massiv.map abs $ mat
    -- Reciprocal of the row maximum; an all-zero row keeps the neutral factor 1.
    safeRecip mx = if mx == 0 then 1 else 1 / mx
    -- Broadcast the per-row factor across all columns of that row.
    invMax = compute . Massiv.map safeRecip . expandInner @Ix2 (Sz n) const $ maxPerRow

----------------------------------------------------------------------------------------------------
-- Distance Measures

-- | A distance function: turns a matrix of points into a matrix of distances. Presumably the
-- points are the column vectors of the input, as documented for 'buildDistMat' — TODO confirm
-- for each metric.
type DistFn rep el = Matrix rep el -> Matrix rep el

-- | Generic builder for pairwise distance matrices. For every pair of points, the components of
-- the two points are combined one by one with the zip function. The combined components are then
-- reduced with a left fold to a single value, which becomes the entry for that pair.
buildDistMat ::
  (Mutable r Ix2 e) =>
  -- | Combines the \(i\)-th components of two points,
  -- \( f(\mathbf{a}_i, \mathbf{b}_i) = \mathbf{c}_i \). Typically built around @(-)@.
  (e -> e -> a) ->
  -- | Left fold that reduces the combined components \(\mathbf{c}\) to one distance \(d\).
  (a -> a -> a) ->
  -- | Initial accumulator of the fold.
  a ->
  -- | \(m \times n\) matrix holding \(n\) points of dimension \(m\) as its columns.
  Matrix r e ->
  -- | Delayed \(n \times n\) matrix of pairwise distances.
  Matrix D a
buildDistMat zipFn foldFn acc mat = foldlInner foldFn acc componentPairs
  where
    Sz (_ :. nPoints) = size mat
    -- Replicate the point matrix nPoints times along a new outer axis and rearrange the axes.
    -- For a fixed pair of outer indices, the innermost axis then runs over the coordinates of
    -- one point of the pair.
    firstOfPair = transposeOuter . expandOuter @Ix3 (Sz nPoints) const $ mat
    -- The same data with the two outer axes swapped, giving the other point of each pair.
    secondOfPair = transposeInner firstOfPair
    -- Combine coordinates pairwise; folding the innermost axis afterwards yields the distance.
    componentPairs = Massiv.zipWith zipFn firstOfPair secondOfPair

-- | The \(L_p\) norm between two vectors. Generalisation of Manhattan and Euclidean distances.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \left( \sum \limits_{i=1}^n \lvert \mathbf{a}_i - \mathbf{b}_i \rvert ^p \right) ^ \frac{1}{p}
-- \]
--
-- The \(p\)-th powers of the component differences are summed first. The \(p\)-th root is taken
-- exactly once, on the total.
--
-- @p@ is expected to be at least 1. A negative @p@ makes '^' fail with a "Negative exponent"
-- error, and @p = 0@ does not give a meaningful distance.
{-# SCC lpNorm #-}
lpNorm :: (Mutable r Ix2 e, Floating e) => Int -> DistFn r e
lpNorm p = compute . Massiv.map root . buildDistMat powDiff (+) 0
  where
    -- \( \lvert a_i - b_i \rvert ^p \) for a single component.
    powDiff a b = abs (a - b) ^ p

    -- The root must only be applied to the completed sum. Applying it inside the fold, as in
    -- @\\acc x -> (acc + x) ** (1 / p)@, takes nested roots of partial sums. That is wrong for
    -- every p /= 1 in more than one dimension.
    root s = s ** (1 / fromIntegral p)

-- | Manhattan (taxicab) distance: the sum of absolute coordinate differences, i.e. 'lpNorm' with
-- \(p = 1\).
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sum \limits_{i=1}^n \lvert \mathbf{a}_i - \mathbf{b}_i \rvert
-- \]
{-# SCC manhattan #-}
manhattan :: (Mutable r Ix2 e, Floating e) => DistFn r e
manhattan points = lpNorm 1 points

-- | Euclidean (straight line) distance, i.e. 'lpNorm' with \(p = 2\).
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sqrt{\sum \limits_{i=1}^n (\mathbf{a}_i - \mathbf{b}_i)^2}
-- \]
{-# SCC euclidean #-}
euclidean :: (Mutable r Ix2 e, Floating e) => DistFn r e
euclidean points = lpNorm 2 points

-- | Mahalanobis distance between all pairs of points. Unlike the Euclidean distance, it weights
-- the coordinate differences with the inverse covariance of the data. It is therefore suited to
-- data whose axes are correlated and/or differently scaled.
-- \[
--   d(\mathbf{a}, \mathbf{b}) = \sqrt{(\mathbf{a} - \mathbf{b})^T \mathbf{S}^{-1} (\mathbf{a} - \mathbf{b})}
-- \]
-- where \(\mathbf{S}\) is the covariance matrix of the points.
--
-- NOTE(review): \(\mathbf{S}\) is inverted with 'LA.inv'. For a singular covariance matrix (e.g.
-- fewer points than dimensions) the result is not meaningful; confirm the desired behaviour.
{-# SCC mahalanobis #-}
mahalanobis :: (Unbox e, Numeric r e, Mutable r Ix2 e, Mutable r Ix1 e, LA.Field e) => DistFn r e
mahalanobis points = compute . Massiv.map sqrt $ quadForms
  where
    Sz (_ :. n) = size points

    -- Rank-3 views of the points. For a fixed pair of point indices (outer two axes), the
    -- innermost axis runs over the coordinates of the two points of that pair.
    a = transposeOuter . expandOuter @Ix3 (Sz n) const $ points
    b = transposeInner a

    -- Coordinate differences of all point pairs, materialised once so that the per-pair slices
    -- below are cheap.
    abDiff = compute @U $ a !-! b

    -- Squared distance \((\mathbf{a}-\mathbf{b})^T \mathbf{S}^{-1} (\mathbf{a}-\mathbf{b})\) for
    -- every pair. Generated directly as a delayed array (same parallel strategy as before) rather
    -- than by mapping over a manifest array of indices.
    quadForms =
      makeArray @D Par (Sz (n :. n)) $ \(x :. y) ->
        let ab = compute @U $ (abDiff !> x) !> y
         in (ab ><! covInv) !.! ab

    -- Covariance matrix of the mean-free points and its inverse (computed via hmatrix).
    cov = covariance . meanDeviation $ points
    covInv = matH2M . LA.inv . matM2H $ cov

----------------------------------------------------------------------------------------------------
-- DBScan

-- | Exception carrying an invalid search distance. 'dbscan' throws it when its search radius
-- \(\epsilon\) is not strictly positive.
newtype DistanceInvalidException e
  = DistanceInvalidException e
  deriving (Eq, Show)

instance (Show e, Typeable e) => Exception (DistanceInvalidException e)

-- | A collection of clusters. Each element is the set of indices of the points (columns of the
-- data matrix) that belong to one cluster.
type Clusters = Array B Ix1 IntSet

-- | Density-based clustering in the spirit of DBScan.
--
-- 1. For every point, collect the indices of all points whose distance is at most \(\epsilon\).
-- 2. Merge these neighbourhoods repeatedly while any two of them, at different indices, share a
--    point; identical sets are de-duplicated along the way.
-- 3. Keep only the clusters with at least the requested number of members.
--
-- Throws a 'SizeException' if there are no points or the minimal cluster size is below 1. Throws a
-- 'DistanceInvalidException' if \(\epsilon \le 0\).
--
-- NOTE(review): unlike textbook DBScan, no distinction between core and border points is made.
-- Every neighbourhood takes part in the merging.
{-# SCC dbscan #-}
dbscan ::
  ( MonadThrow m,
    Ord e,
    Num e,
    Typeable e,
    Show e,
    Source r Ix2 e
  ) =>
  -- | Distance function used to build the distance matrix of all points.
  DistFn r e ->
  -- | Minimal number of members a cluster needs in order to be reported.
  Int ->
  -- | Search radius \(\epsilon\).
  e ->
  -- | \(m \times n\) matrix whose \(n\) columns are the \(m\)-dimensional data points.
  Matrix r e ->
  -- | Clusters, as sets of point (column) indices.
  m Clusters
dbscan distFn nPoints epsilon points
  | isEmpty points = throwM $ SizeEmptyException (Sz 0 :: Sz1)
  | nPoints < 1 = throwM $ SizeNegativeException (Sz nPoints)
  | epsilon <= 0 = throwM $ DistanceInvalidException epsilon
  | otherwise = return . compute . sfilter bigEnough $ joinOverlapping initialClusters
  where
    -- Pairwise distances of all points, in the metric chosen by the caller.
    distMat = distFn points

    -- One neighbourhood per row of the distance matrix.
    initialClusters :: Clusters
    initialClusters = compute @B $ ifoldlInner collectNeighbours mempty distMat

    -- Adds the column index to the accumulated neighbourhood if the distance is within epsilon.
    {-# SCC collectNeighbours #-}
    collectNeighbours (_ :. col) reached d
      | d <= epsilon = IntSet.insert col reached
      | otherwise = reached

    -- Only clusters with at least nPoints members survive.
    bigEnough cl = IntSet.size cl >= nPoints

    -- Pairwise comparison matrix of all clusters under a binary predicate.
    compareSets :: (IntSet -> IntSet -> Bool) -> Clusters -> Matrix D Bool
    compareSets fn clVec = Massiv.zipWith fn rowWise colWise
      where
        rowWise = expandOuter @Ix2 (size clVec) const clVec
        colWise = transpose rowWise

    -- True where two clusters share at least one point. A non-empty set overlaps itself.
    overlap :: Clusters -> Matrix D Bool
    overlap = compareSets (\s t -> not (IntSet.disjoint s t))

    -- Does any cluster overlap with a cluster at a different index?
    anyOtherOverlap :: Clusters -> Bool
    anyOtherOverlap = Massiv.any id . imap offDiagonal . overlap
      where
        offDiagonal (i :. j) v = i /= j && v

    -- True where two clusters are identical. Every set is identical to itself.
    same :: Clusters -> Matrix D Bool
    same = compareSets (==)

    -- Merge overlapping clusters until no two clusters at different indices overlap.
    {-# SCC joinOverlapping #-}
    joinOverlapping :: Clusters -> Clusters
    joinOverlapping clVec
      | anyOtherOverlap clVec = joinOverlapping nonRedundant
      | otherwise = clVec
      where
        ovlpMat = compute @U . overlap $ clVec

        -- Entry i: union of all clusters that overlap cluster i. The number of clusters is kept;
        -- redundant copies are removed below.
        merged = ifoldlInner absorb mempty ovlpMat
        absorb (_ :. j) acc hit = if hit then (clVec ! j) <> acc else acc

        -- Strict upper triangle of the equality matrix of the merged sets. Entry (i, j) is True
        -- iff i < j and the merged sets i and j are identical.
        duplicates = compute @U . imap (\(i :. j) v -> i < j && v) . same . compute @B $ merged

        -- Drop every merged set that has an identical copy at a later index, so exactly one
        -- copy of each distinct set remains.
        nonRedundant =
          compute @B . sifilter (\i _ -> not (Massiv.any id (duplicates !> i))) $ merged

----------------------------------------------------------------------------------------------------
-- Hierarchical Cluster Analysis

-- | A single node of a 'Dendrogram': a cluster of points together with its distance value.
data DendroNode e = DendroNode
  { -- | Distance value of this node ('cutDendroAt' compares it against the cut distance).
    distance :: e,
    -- | Indices of the points in the cluster represented by this node.
    cluster :: IntSet
  }
  deriving (Eq, Show, Generic)

instance (ToJSON e) => ToJSON (DendroNode e)

instance (FromJSON e) => FromJSON (DendroNode e)

-- | A dendrogram, stored as a binary tree of 'DendroNode's.
newtype Dendrogram e = Dendrogram
  { unDendro :: BinTree (DendroNode e)
  }
  deriving (Eq, Show, Generic)

instance FromJSON e => FromJSON (Dendrogram e)

instance ToJSON e => ToJSON (Dendrogram e)

-- | Working state of the bottom-up construction of a 'Dendrogram': a boxed vector of partial
-- dendrograms. Internal helper type; not part of the public API.
type DendroAcc e = Array B Ix1 (Dendrogram e)

-- | Mutable counterpart of 'DendroAcc': a boxed mutable vector of partial dendrograms tied to the
-- state token of the monad @m@.
type DendroAccM m e = MArray (PrimState m) B Int (Dendrogram e)

-- | Cut a 'Dendrogram' at the distance @dist@ and obtain all clusters from it.
--
-- The tree is traversed with 'takeLeafyBranchesWhile', using the predicate @distance >= dist@ on
-- each node. The 'cluster' set of every node returned by that traversal becomes one entry of the
-- resulting 'Clusters'.
cutDendroAt :: Ord e => Dendrogram e -> e -> Clusters
cutDendroAt dendro dist = compute @B (Massiv.map cluster cutNodes)
  where
    -- Nodes selected by the cut, materialised as a boxed vector.
    cutNodes = compute @B . takeLeafyBranchesWhile atOrAboveCut . unDendro $ dendro
    -- Whether a node's distance is at least the cut distance.
    atOrAboveCut (DendroNode {distance}) = distance >= dist

-- | A strategy/distance measure for clusters. It selects the Lance-Williams coefficients
-- \(\alpha_A, \alpha_B, \beta, \gamma\) that 'lanceWilliams' uses to compute the distance of a
-- joined cluster \(A \cup B\) to another cluster \(C\).
data JoinStrat e
  = -- | Single linkage.
    SingleLinkage
  | -- | Complete linkage.
    CompleteLinkage
  | -- | Median linkage (WPGMC).
    Median
  | -- | Unweighted average linkage (UPGMA).
    UPGMA
  | -- | Weighted average linkage (WPGMA).
    WPGMA
  | -- | Centroid linkage (UPGMC).
    Centroid
  | -- | Ward's method.
    Ward
  | -- | Flexible beta method with the given \(\beta\):
    -- \(\alpha_A = \alpha_B = (1 - \beta) / 2\) and \(\gamma = 0\).
    LWFB e
  | -- | Explicitly given Lance-Williams coefficients, in the order
    -- \(\alpha_A\), \(\alpha_B\), \(\beta\), \(\gamma\).
    LW e e e e
  deriving (Eq, Show)

-- | Lance-Williams formula to update distances. When clusters \(A\) and \(B\) are joined, the
-- distance of the new cluster to another cluster \(C\) is
--
-- \[
--   d(A \cup B, C) = \alpha_A \, d(A, C) + \alpha_B \, d(B, C) + \beta \, d(A, B)
--     + \gamma \left| d(A, C) - d(B, C) \right|
-- \]
--
-- The coefficients are selected by the 'JoinStrat'. For 'Ward' they are
-- \(\alpha_A = \frac{n_A + n_C}{n_A + n_B + n_C}\),
-- \(\alpha_B = \frac{n_B + n_C}{n_A + n_B + n_C}\),
-- \(\beta = -\frac{n_C}{n_A + n_B + n_C}\) and \(\gamma = 0\).
{-# SCC lanceWilliams #-}
lanceWilliams ::
  Fractional e =>
  -- | How to calculate distance between clusters of points.
  JoinStrat e ->
  -- | Number of points in cluster \(A\).
  Int ->
  -- | Number of points in cluster \(B\).
  Int ->
  -- | Number of points in cluster \(C\).
  Int ->
  -- | \(d(A, B)\)
  e ->
  -- | \(d(A, C)\)
  e ->
  -- | \(d(B, C)\)
  e ->
  -- | Updated distance \(d(A \cup B, C)\)
  e
lanceWilliams js nA nB nC dAB dAC dBC =
  alpha1 * dAC + alpha2 * dBC + beta * dAB + gamma * abs (dAC - dBC)
  where
    nA' = fromIntegral nA
    nB' = fromIntegral nB
    nC' = fromIntegral nC
    -- Total number of points in A, B and C; only needed by Ward's method.
    nABC = nA' + nB' + nC'
    (alpha1, alpha2, beta, gamma) = case js of
      SingleLinkage -> (1 / 2, 1 / 2, 0, -1 / 2)
      CompleteLinkage -> (1 / 2, 1 / 2, 0, 1 / 2)
      Median -> (1 / 2, 1 / 2, -1 / 4, 0)
      UPGMA -> (nA' / (nA' + nB'), nB' / (nA' + nB'), 0, 0)
      WPGMA -> (1 / 2, 1 / 2, 0, 0)
      Centroid ->
        ( nA' / (nA' + nB'),
          nB' / (nA' + nB'),
          -(nA' * nB') / ((nA' + nB') ^ (2 :: Int)),
          0
        )
      -- alpha_B uses the size of B and beta only the size of C (previously both were
      -- computed as (nA + nC) / nABC, which is not Ward's update rule).
      Ward -> ((nA' + nC') / nABC, (nB' + nC') / nABC, -nC' / nABC, 0)
      LWFB b -> ((1 - b) / 2, (1 - b) / 2, b, 0)
      LW a b c d -> (a, b, c, d)

----------------------------------------------------------------------------------------------------
-- Müllner Generic Hierarchical Clustering

-- | A neighbourlist. Entry @i@ holds the minimal distance of cluster @i@ to any other cluster,
-- paired with the index of that other cluster.
type Neighbourlist r e = Array r Ix1 (e, Ix1)

-- | A distance matrix.
type DistanceMatrix r e = Array r Ix2 e

-- | Performance improved hierarchical clustering algorithm. @GENERIC_LINKAGE@ from figure 3,
-- <https://arxiv.org/pdf/1109.2378.pdf>.
--
-- The matrix @points@ is read as features × points, i.e. each column is one point to cluster.
-- Fails with a 'SizeEmptyException' in @m@ if @points@ is empty. A single point yields the
-- dendrogram consisting only of its own leaf.
{-# SCC hca #-}
hca ::
  ( MonadThrow m,
    Mutable r Ix1 e,
    Mutable r Ix2 e,
    Mutable r Ix1 (e, Ix1),
    Manifest (R r) Ix1 e,
    OuterSlice r Ix2 e,
    Ord e,
    Unbox e,
    Fractional e
  ) =>
  DistFn r e ->
  JoinStrat e ->
  Matrix r e ->
  m (Dendrogram e)
hca distFn joinStrat points
  -- Use throwM so that the failure is reported through m (e.g. Left/Nothing) instead of as an
  -- imprecise pure exception.
  | Massiv.isEmpty points = throwM $ SizeEmptyException (Sz nPoints)
  -- A single point cannot be joined with anything; its dendrogram is just its own leaf.
  | nPoints == 1 = return (leaf 0)
  | otherwise = do
      -- The distance matrix from the points.
      let distMat = distFn points

      -- Initial vector of nearest neighbour to each point.
      nNghbr <- nearestNeighbours distMat

      let -- Initial priority queue of points, keyed by point index and prioritised by the distance
          -- to the point's nearest neighbour.
          pq = PQ.fromList . Massiv.toList . Massiv.imap (\k (d, _) -> k PQ.:-> d) $ nNghbr
          -- Set of points not joined yet. Initially all points.
          s = IntSet.fromDistinctAscList [0 .. nPoints - 1]
          -- Initial dendrogram accumulator. Every point is its own cluster.
          dendroAcc = makeArray @B @Ix1 Par (Sz nPoints) leaf

      -- The agglomeration mutates its working arrays in place. 'thaw' copies the immutable inputs
      -- first, so all mutation is confined to this single IO action and the result is observably
      -- pure. Running everything in one 'unsafePerformIO' keeps the mutable arrays from escaping
      -- as separately shareable thunks.
      return . unsafePerformIO $ do
        distMatM <- thaw distMat
        nNghbrM <- thaw nNghbr
        dendroAccM <- thaw dendroAcc
        agglomerate joinStrat distMatM nNghbrM pq s dendroAccM
  where
    Sz (_mFeatures :. nPoints) = size points
    -- Dendrogram of the single, not yet joined point @p@.
    leaf p = Dendrogram . Leaf $ DendroNode {distance = 0, cluster = IntSet.singleton p}

-- | Agglomerative clustering by the improved generic linkage algorithm. This is the main loop
-- recursion L 10-43.
--
-- Each step joins the two closest active clusters @a@ and @b@ into @b@ and continues with the
-- reduced set of active clusters. Once a single cluster is left, its dendrogram is read from the
-- accumulator at index @b@.
{-# SCC agglomerate #-}
agglomerate ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    PrimState m ~ RealWorld,
    Mutable r Ix2 e,
    OuterSlice r Ix2 e,
    Manifest (R r) Ix1 e,
    Mutable r Ix1 (e, Ix1),
    Fractional e,
    Ord e
  ) =>
  -- | Join strategy for clusters and therefore how to calculate cluster-cluster distances.
  JoinStrat e ->
  -- | Distance matrix.
  MArray (PrimState m) r Ix2 e ->
  -- | Nearest neighbour (distance and index) for each cluster.
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  -- | Priority queue with the distances as priorities and the cluster index as keys.
  PQ.PSQ Ix1 e ->
  -- | The set \(S\) of cluster indices that are still active, i.e. not yet joined into another
  -- cluster.
  IntSet ->
  -- | Accumulator of the dendrogram, indexed by cluster. Its size does not change; the joined
  -- cluster is stored at index @b@.
  DendroAccM m e ->
  -- | The final dendrogram, after all clusters have been joined.
  m (Dendrogram e)
agglomerate joinStrat distMat nNghbr pq s dendroAcc
  -- throwM reports the failure through m rather than as an imprecise pure exception.
  | IntSet.null s = throwM $ IndexException "No clusters left. This must never happen."
  | otherwise = do
      -- Obtain candidates for the two clusters to join and the minimal distance in the priority
      -- queue.
      candidates <- getJoinCandidates nNghbr pq

      -- If the distance between a and b is not the minimal distance that the priority queue has
      -- found, the neighbour list is outdated and must be recalculated.
      (a, b, delta, nNghbrU1, pqU1) <- recalculateNghbr candidates s distMat nNghbr pq

      -- Remove the minimal element from the priority queue and join clusters a and b. The
      -- accumulator keeps its size; its entry at b is updated with the joined cluster.
      (newS, pqU2, newAcc) <- joinClusters a b delta s pqU1 dendroAcc

      -- Update the distance matrix in the row and column of b but not at (b,b) and not at (a,b)
      -- and (b,a).
      newDistMat <- updateDistMat joinStrat a b newS distMat newAcc

      -- Redirect neighbours to b, if they previously pointed to a.
      nNghbrU2 <- redirectNeighbours a b newS newDistMat nNghbrU1

      -- Update the neighbourlist and priority queue with the new distances to b. As in Müllner's
      -- L 34-42, only the reduced set (without a) is considered, consistent with the two steps
      -- above.
      (newNNghbr, newPQ) <- updateBNeighbour b newS newDistMat nNghbrU2 pqU2

      -- If the problem has been reduced to a single cluster the algorithm is done and the final
      -- dendrogram can be obtained from the accumulator at index b. Otherwise join further.
      if IntSet.size newS == 1
        then newAcc `readM` b
        else agglomerate joinStrat newDistMat newNNghbr newPQ newS newAcc

-- | Obtain candidates for the clusters to join (L 11-13).
--
-- * @a@ is the key with the minimal priority in the priority queue.
-- * @b@ is the neighbour index stored for @a@ in the neighbourlist.
-- * The returned distance is the priority of @a@.
--
-- Throws an 'IndexException' if the priority queue is empty.
{-# SCC getJoinCandidates #-}
getJoinCandidates ::
  ( MonadThrow m,
    PrimMonad m,
    Mutable r Ix1 (e, Ix1),
    Ord e
  ) =>
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.PSQ Ix1 e ->
  m (Ix1, Ix1, e)
getJoinCandidates nNghbr pq = case PQ.findMin pq of
  Nothing -> throwM $ IndexException "Empty priority queue"
  Just (a PQ.:-> d) -> do
    (_, b) <- readM nNghbr a
    return (a, b, d)

-- | Müllner's correction loop (L 14-20).
--
-- The candidates @a@ and @b@ come with a minimal distance @d@ from the priority queue. If @d@
-- differs from the stored distance between @a@ and @b@, the nearest neighbour of @a@ is outdated,
-- and each round then does the following:
--
-- 1. Search the nearest neighbour of @a@ again among the not-yet-merged clusters in the set.
-- 2. Update the priority queue at key @a@ with the new distance.
-- 3. Draw new candidates.
--
-- The loop ends when the distances agree. It returns the final candidates and their distance,
-- together with the (possibly updated) neighbour list and priority queue.
{-# SCC recalculateNghbr #-}
recalculateNghbr ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    PrimState m ~ RealWorld,
    OuterSlice r Ix2 e,
    Manifest (R r) Ix1 e,
    Mutable r Ix1 (e, Ix1),
    Mutable r Ix2 e,
    Ord e
  ) =>
  (Ix1, Ix1, e) ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.PSQ Ix1 e ->
  m (Ix1, Ix1, e, MArray (PrimState m) r Ix1 (e, Ix1), PQ.PSQ Ix1 e)
recalculateNghbr (cA, cB, d) s distMat nNghbr pq = do
  storedDist <- readM distMat (cA :. cB)
  if d == storedDist
    then -- The candidates are consistent with the distance matrix; nothing to correct.
      return (cA, cB, d, nNghbr, pq)
    else do
      -- Search the nearest neighbour of cA again, only among clusters that were not merged yet.
      rowA <- unsafeFreeze Par =<< searchRow cA s distMat
      nearestA@(distA, _) <- minimumM rowA
      writeM nNghbr cA nearestA

      -- Reprioritise cA with its new nearest-neighbour distance and draw fresh candidates.
      let pq' = PQ.adjust (const distA) cA pq
      candidates' <- getJoinCandidates nNghbr pq'
      recalculateNghbr candidates' s distMat nNghbr pq'

-- | Joins the selected clusters \(A\) and \(B\). The merged cluster \(A \cup B\) is stored at index b
-- of the dendrogram accumulator, as a new node at distance d whose children are the trees of \(A\)
-- and \(B\). The entry at index a is left in place so that the accumulator never shrinks; a is only
-- removed from the set of active cluster indices. The minimum binding of the priority queue is
-- removed, which the caller presumably guarantees to be the binding of a.
-- L 21-24
{-# SCC joinClusters #-}
joinClusters ::
  ( MonadThrow m,
    PrimMonad m,
    Ord e
  ) =>
  Ix1 ->
  Ix1 ->
  e ->
  IntSet ->
  PQ.PSQ Ix1 e ->
  DendroAccM m e ->
  m (IntSet, PQ.PSQ Ix1 e, DendroAccM m e)
joinClusters a b d s pq acc = do
  treeA <- unDendro <$> readM acc a
  modifyM_ acc (\clB -> pure $ mergeInto treeA (unDendro clB)) b
  pure (IntSet.delete a s, PQ.deleteMin pq, acc)
  where
    -- Parent node of two subtrees; its member set is the union of both children's members.
    mergeInto treeA treeB =
      Dendrogram $
        Node
          ( DendroNode
              { distance = d,
                cluster = members treeA <> members treeB
              }
          )
          treeA
          treeB
    members = cluster . root

-- | Update the distance matrix with a Lance-Williams update in the rows and columns of cluster b.
-- For every active cluster x other than b, both \(d(x, b)\) and \(d(b, x)\) are replaced by
-- 'lanceWilliams' applied to \(d(a, b)\), \(d(a, x)\) and the respective old entry, together with
-- the sizes of clusters a, b and x as read from the dendrogram accumulator. The matrix is modified
-- in place and returned.
--
-- Throws a 'SizeMismatchException' if the dimension of the distance matrix does not match the size
-- of the dendrogram accumulator, or if the distance matrix is not square.
-- L 25-27
{-# SCC updateDistMat #-}
updateDistMat ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    Mutable r Ix2 e,
    Fractional e
  ) =>
  JoinStrat e ->
  Ix1 ->
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  DendroAccM m e ->
  m (MArray (PrimState m) r Ix2 e)
updateDistMat js a b s distMat dendroAcc
  -- There must be exactly one dendrogram entry per row of the distance matrix.
  | nDM /= nCl = throwM $ SizeMismatchException (Sz nDM) (Sz nCl)
  | mDM /= nDM = throwM $ SizeMismatchException (Sz mDM) (Sz nDM)
  | otherwise = do
    dAB <- distMat `readM` (a :. b)
    nA <- clSize a
    nB <- clSize b
    forIO_ ixV $ \ix -> do
      dAX <- distMat `readM` (a :. ix)
      nX <- clSize ix
      -- Both entries of the pair are updated, each from its own old value.
      let lwUpdate dBX = return $ lanceWilliams js nA nB nX dAB dAX dBX
      modifyM_ distMat lwUpdate (ix :. b)
      modifyM_ distMat lwUpdate (b :. ix)
    return distMat
  where
    Sz (mDM :. nDM) = msize distMat
    Sz nCl = msize dendroAcc
    -- All active clusters except b itself.
    ixV = Massiv.fromList @U Par . IntSet.toAscList . IntSet.delete b $ s
    -- Number of points in the cluster currently stored at index i of the accumulator.
    clSize i = IntSet.size . cluster . root . unDendro <$> dendroAcc `readM` i

-- | Updates the neighbour list. Every active cluster x with an index below a whose recorded nearest
-- neighbour is a gets redirected, without further checks, to the union of a and b, which now lives
-- at index b; the stored distance of such an entry becomes \(d(x, b)\). All other entries are left
-- unchanged. The priority queue is not touched here. The neighbour list is modified in place and
-- returned.
-- L 28-32
{-# SCC redirectNeighbours #-}
redirectNeighbours ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    Mutable r Ix1 (e, Ix1),
    Mutable r Ix2 e
  ) =>
  Ix1 ->
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  m (MArray (PrimState m) r Ix1 (e, Ix1))
redirectNeighbours a b s distMat nNghbr = do
  forIO_ belowA $ \x ->
    modifyM_
      nNghbr
      ( \entry@(_, nghbrX) ->
          if nghbrX /= a
            then pure entry
            else (\dXB -> (dXB, b)) <$> readM distMat (x :. b)
      )
      x
  pure nNghbr
  where
    -- Active cluster indices strictly smaller than a.
    belowA = compute @U . sfilter (< a) . Massiv.fromList @U Par . IntSet.toAscList $ s

-- | Updates the list of nearest neighbours for all combinations that might have changed by
-- recalculation with the joined cluster AB at index b. For every active cluster x with an index
-- below b: if \(d(x, b)\) is smaller than the priority currently stored for x in the priority
-- queue, b becomes the nearest neighbour of x and the priority of x is lowered to \(d(x, b)\).
--
-- Throws an 'IndexException' if such an x has no binding in the priority queue.
-- L 33-38 (label inferred from the neighbouring steps; TODO confirm)
{-# SCC updateWithNewBDists #-}
updateWithNewBDists ::
  ( MonadThrow m,
    MonadUnliftIO m,
    PrimMonad m,
    Mutable r Ix2 e,
    Mutable r Ix1 (e, Ix1),
    Ord e
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.PSQ Ix1 e ->
  m (MArray (PrimState m) r Ix1 (e, Ix1), PQ.PSQ Ix1 e)
updateWithNewBDists b s distMat nNghbr pq = do
  pqT <- newTVarIO pq
  forIO_ ixV $ \ix -> do
    dBX <- distMat `readM` (ix :. b)
    -- Each iteration only ever changes the binding of its own key ix, so looking that binding up
    -- in the current queue outside of a transaction is safe.
    currentPQ <- readTVarIO pqT
    minDistX <- case PQ.lookup ix currentPQ of
      Nothing -> throwM $ IndexException "Empty priority queue."
      Just v -> return v
    when (dBX < minDistX) $ do
      writeM nNghbr ix (dBX, b)
      -- Adjust the queue as it is now instead of writing back an earlier snapshot, so that updates
      -- of other keys made in between are not lost.
      atomically . modifyTVar' pqT $ PQ.adjust (const dBX) ix
  newPQ <- readTVarIO pqT
  return (nNghbr, newPQ)
  where
    -- Active cluster indices strictly smaller than b.
    ixV = compute @U . Massiv.sfilter (< b) . Massiv.fromList @U Par . IntSet.toAscList $ s

-- | Updates the list of nearest neighbours and the priority queue at key b. The nearest neighbour
-- of the merged cluster at index b is searched among the active clusters with a higher index
-- (see 'searchRow'), stored in the neighbour list at index b, and the priority of b is set to that
-- distance. If b has no entry in the neighbour list (b is not below its size), nothing changes.
-- The minimum search presumably throws if no active cluster with a higher index than b exists.
-- L 39-40
{-# SCC updateBNeighbour #-}
updateBNeighbour ::
  ( MonadThrow m,
    PrimMonad m,
    MonadUnliftIO m,
    Mutable r Ix1 (e, Ix1),
    Mutable r Ix2 e,
    Ord e
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  MArray (PrimState m) r Ix1 (e, Ix1) ->
  PQ.PSQ Ix1 e ->
  m (MArray (PrimState m) r Ix1 (e, Ix1), PQ.PSQ Ix1 e)
updateBNeighbour b s distMat nNghbr pq
  | b >= nNeighbours = return (nNghbr, pq)
  | otherwise = do
    rowAB <- searchRow b s distMat >>= unsafeFreeze Par
    newNeighbourB@(distB, _) <- minimumM rowAB
    writeM nNghbr b newNeighbourB
    -- The queue is keyed by the owner of a neighbour-list entry, so the entry that just changed is
    -- the one of b. Adjusting the key of b's new neighbour instead would leave b's priority stale
    -- and overwrite the neighbour's priority with an unrelated distance.
    let newPQ = PQ.adjust (const distB) b pq
    return (nNghbr, newPQ)
  where
    Sz nNeighbours = msize nNghbr

-- | Find the nearest neighbour for each point from a distance matrix. For each point it stores the
-- minimum distance and the index of the other point, that is the nearest neighbour but at a higher
-- index. The last row is skipped, as no point with a higher index exists for it; the result
-- therefore has one element less than the matrix has rows.
--
-- Throws an 'IndexException' if the distance matrix is not square or is empty. Failures of the
-- per-row minimum search are propagated through @m@.
{-# SCC nearestNeighbours #-}
nearestNeighbours ::
  ( MonadThrow m,
    Mutable r Ix1 e,
    Mutable r Ix1 (e, Ix1),
    OuterSlice r Ix2 e,
    Source (R r) Ix1 e,
    Ord e,
    Unbox e
  ) =>
  Matrix r e ->
  m (Vector r (e, Ix1))
nearestNeighbours distMat
  | m /= n = throwM $ IndexException "Distance matrix is not square"
  | m == 0 = throwM $ IndexException "Distance matrix is empty"
  | otherwise =
    -- The row search runs directly in m rather than through unsafePerformIO inside a pure map, so
    -- a failure is reported via MonadThrow at this call instead of escaping later as an imprecise
    -- exception while the result is evaluated.
    Massiv.imapM (\i row -> minDistAtVec i . compute @U $ row) . init $ rows
  where
    Sz (m :. n) = size distMat
    rows = compute @B . outerSlices $ distMat

-- | Make a search row for distances. Takes row x of a distance matrix and pairs each entry with its
-- column index, keeping only the columns that are still part of the available points (members of
-- s) and that lie strictly above x. A minimum or maximum search can be performed on the resulting
-- vector to obtain a valid pair of distance and index.
searchRow ::
  ( PrimMonad m,
    MonadThrow m,
    MonadUnliftIO m,
    Mutable r Ix2 e,
    Mutable r Ix1 (e, Ix1)
  ) =>
  Ix1 ->
  IntSet ->
  MArray (PrimState m) r Ix2 e ->
  m (MArray (PrimState m) r Ix1 (e, Ix1))
searchRow x s dm =
  makeMArray Par (size columns) $ \i -> do
    col <- columns !? i
    dist <- readM dm (x :. col)
    pure (dist, col)
  where
    -- Available column indices strictly greater than x, in ascending order.
    columns :: Vector U Ix1
    columns = compute @U . sfilter (> x) . Massiv.fromList @U Par . IntSet.toAscList $ s