-- | Hspec tests for "Data.Algorithm.TSNE.Internals".
--
-- The tests run the internal t-SNE building blocks against a small fixed
-- data set ('testInput') and check that the results have the expected shape
-- and pass the library's own validity checks.
module Data.Algorithm.TSNE.InternalsSpec (main, spec) where

import Data.Either (isRight)

import Data.Default (def)
import Test.Hspec

import Data.Algorithm.TSNE.Checks
import Data.Algorithm.TSNE.Internals
import Data.Algorithm.TSNE.Types
import Data.Algorithm.TSNE.Utils

-- | `main` is here so that this module can be run from GHCi on its own. It is
-- not needed for automatic spec discovery.
main :: IO ()
main = hspec spec

-- | Shorthand for 'undefined'.
--
-- NOTE(review): not referenced anywhere in this module and not exported;
-- looks like a development leftover -- confirm and remove.
u :: a
u = undefined

-- | The spec proper: one @describe@ group per function under test.
spec :: Spec
spec = do
    -- Number of samples in the fixture; most expected shapes depend on it.
    let n = inputSize testInput

    describe "testInput" $ do
        it "is right shape" $ do
            testInput `shouldSatisfy` has2DShape (64, 20)
        it "is valid" $ do
            inputIsValid testInput `shouldBe` Right ()
        it "has right size" $ do
            inputSize testInput `shouldBe` 20
        it "has right value size" $ do
            inputValueSize testInput `shouldBe` 64

    describe "initSolution3D" $ do
        it "is right shape" $
            initSolution3D n >>= (`shouldSatisfy` has2DShape (n, 3))

    describe "initState" $ do
        it "is valid state" $
            initState n >>= (`shouldSatisfy` isRight . isValidStateForInput testInput)

    describe "neighbourProbabilities" $ do
        it "is right shape" $ do
            testNeighbourProbs `shouldSatisfy` has2DShape (n, n)

    describe "qdist" $ do
        it "is right shape" $ do
            s <- initSolution3D n
            qdist s `shouldSatisfy` has2DShape (n, n)

    describe "qdist'" $ do
        it "is right shape" $ do
            s <- initSolution3D n
            qdist' s `shouldSatisfy` has2DShape (n, n)

    describe "gradients" $ do
        it "is right shape" $ do
            s <- initState n
            gradients testNeighbourProbs s `shouldSatisfy` has2DShape (n, 3)

    describe "stepTSNE" $ do
        -- A single step from a fresh state must still yield a state that is
        -- valid for the input it was computed from.
        it "works" $ do
            s <- initState n
            stepTSNE def testInput testNeighbourProbs s
                `shouldSatisfy` isRight . isValidStateForInput testInput

-- | Test fixture: the first 20 digits from the Python sklearn digits dataset.
--
-- 20 samples, one per row, each holding 64 values (both figures are asserted
-- by the @testInput@ group in 'spec').
testInput :: TSNEInput
testInput =
    [ [0.0,0.0,5.0,13.0,9.0,1.0,0.0,0.0,0.0,0.0,13.0,15.0,10.0,15.0,5.0,0.0,0.0,3.0,15.0,2.0,0.0,11.0,8.0,0.0,0.0,4.0,12.0,0.0,0.0,8.0,8.0,0.0,0.0,5.0,8.0,0.0,0.0,9.0,8.0,0.0,0.0,4.0,11.0,0.0,1.0,12.0,7.0,0.0,0.0,2.0,14.0,5.0,10.0,12.0,0.0,0.0,0.0,0.0,6.0,13.0,10.0,0.0,0.0,0.0]
    , [0.0,0.0,0.0,12.0,13.0,5.0,0.0,0.0,0.0,0.0,0.0,11.0,16.0,9.0,0.0,0.0,0.0,0.0,3.0,15.0,16.0,6.0,0.0,0.0,0.0,7.0,15.0,16.0,16.0,2.0,0.0,0.0,0.0,0.0,1.0,16.0,16.0,3.0,0.0,0.0,0.0,0.0,1.0,16.0,16.0,6.0,0.0,0.0,0.0,0.0,1.0,16.0,16.0,6.0,0.0,0.0,0.0,0.0,0.0,11.0,16.0,10.0,0.0,0.0]
    , [0.0,0.0,0.0,4.0,15.0,12.0,0.0,0.0,0.0,0.0,3.0,16.0,15.0,14.0,0.0,0.0,0.0,0.0,8.0,13.0,8.0,16.0,0.0,0.0,0.0,0.0,1.0,6.0,15.0,11.0,0.0,0.0,0.0,1.0,8.0,13.0,15.0,1.0,0.0,0.0,0.0,9.0,16.0,16.0,5.0,0.0,0.0,0.0,0.0,3.0,13.0,16.0,16.0,11.0,5.0,0.0,0.0,0.0,0.0,3.0,11.0,16.0,9.0,0.0]
    , [0.0,0.0,7.0,15.0,13.0,1.0,0.0,0.0,0.0,8.0,13.0,6.0,15.0,4.0,0.0,0.0,0.0,2.0,1.0,13.0,13.0,0.0,0.0,0.0,0.0,0.0,2.0,15.0,11.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,12.0,12.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,10.0,8.0,0.0,0.0,0.0,8.0,4.0,5.0,14.0,9.0,0.0,0.0,0.0,7.0,13.0,13.0,9.0,0.0,0.0]
    , [0.0,0.0,0.0,1.0,11.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,13.0,6.0,2.0,2.0,0.0,0.0,0.0,7.0,15.0,0.0,9.0,8.0,0.0,0.0,5.0,16.0,10.0,0.0,16.0,6.0,0.0,0.0,4.0,15.0,16.0,13.0,16.0,1.0,0.0,0.0,0.0,0.0,3.0,15.0,10.0,0.0,0.0,0.0,0.0,0.0,2.0,16.0,4.0,0.0,0.0]
    , [0.0,0.0,12.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,14.0,16.0,16.0,14.0,0.0,0.0,0.0,0.0,13.0,16.0,15.0,10.0,1.0,0.0,0.0,0.0,11.0,16.0,16.0,7.0,0.0,0.0,0.0,0.0,0.0,4.0,7.0,16.0,7.0,0.0,0.0,0.0,0.0,0.0,4.0,16.0,9.0,0.0,0.0,0.0,5.0,4.0,12.0,16.0,4.0,0.0,0.0,0.0,9.0,16.0,16.0,10.0,0.0,0.0]
    , [0.0,0.0,0.0,12.0,13.0,0.0,0.0,0.0,0.0,0.0,5.0,16.0,8.0,0.0,0.0,0.0,0.0,0.0,13.0,16.0,3.0,0.0,0.0,0.0,0.0,0.0,14.0,13.0,0.0,0.0,0.0,0.0,0.0,0.0,15.0,12.0,7.0,2.0,0.0,0.0,0.0,0.0,13.0,16.0,13.0,16.0,3.0,0.0,0.0,0.0,7.0,16.0,11.0,15.0,8.0,0.0,0.0,0.0,1.0,9.0,15.0,11.0,3.0,0.0]
    , [0.0,0.0,7.0,8.0,13.0,16.0,15.0,1.0,0.0,0.0,7.0,7.0,4.0,11.0,12.0,0.0,0.0,0.0,0.0,0.0,8.0,13.0,1.0,0.0,0.0,4.0,8.0,8.0,15.0,15.0,6.0,0.0,0.0,2.0,11.0,15.0,15.0,4.0,0.0,0.0,0.0,0.0,0.0,16.0,5.0,0.0,0.0,0.0,0.0,0.0,9.0,15.0,1.0,0.0,0.0,0.0,0.0,0.0,13.0,5.0,0.0,0.0,0.0,0.0]
    , [0.0,0.0,9.0,14.0,8.0,1.0,0.0,0.0,0.0,0.0,12.0,14.0,14.0,12.0,0.0,0.0,0.0,0.0,9.0,10.0,0.0,15.0,4.0,0.0,0.0,0.0,3.0,16.0,12.0,14.0,2.0,0.0,0.0,0.0,4.0,16.0,16.0,2.0,0.0,0.0,0.0,3.0,16.0,8.0,10.0,13.0,2.0,0.0,0.0,1.0,15.0,1.0,3.0,16.0,8.0,0.0,0.0,0.0,11.0,16.0,15.0,11.0,1.0,0.0]
    , [0.0,0.0,11.0,12.0,0.0,0.0,0.0,0.0,0.0,2.0,16.0,16.0,16.0,13.0,0.0,0.0,0.0,3.0,16.0,12.0,10.0,14.0,0.0,0.0,0.0,1.0,16.0,1.0,12.0,15.0,0.0,0.0,0.0,0.0,13.0,16.0,9.0,15.0,2.0,0.0,0.0,0.0,0.0,3.0,0.0,9.0,11.0,0.0,0.0,0.0,0.0,0.0,9.0,15.0,4.0,0.0,0.0,0.0,9.0,12.0,13.0,3.0,0.0,0.0]
    , [0.0,0.0,1.0,9.0,15.0,11.0,0.0,0.0,0.0,0.0,11.0,16.0,8.0,14.0,6.0,0.0,0.0,2.0,16.0,10.0,0.0,9.0,9.0,0.0,0.0,1.0,16.0,4.0,0.0,8.0,8.0,0.0,0.0,4.0,16.0,4.0,0.0,8.0,8.0,0.0,0.0,1.0,16.0,5.0,1.0,11.0,3.0,0.0,0.0,0.0,12.0,12.0,10.0,10.0,0.0,0.0,0.0,0.0,1.0,10.0,13.0,3.0,0.0,0.0]
    , [0.0,0.0,0.0,0.0,14.0,13.0,1.0,0.0,0.0,0.0,0.0,5.0,16.0,16.0,2.0,0.0,0.0,0.0,0.0,14.0,16.0,12.0,0.0,0.0,0.0,1.0,10.0,16.0,16.0,12.0,0.0,0.0,0.0,3.0,12.0,14.0,16.0,9.0,0.0,0.0,0.0,0.0,0.0,5.0,16.0,15.0,0.0,0.0,0.0,0.0,0.0,4.0,16.0,14.0,0.0,0.0,0.0,0.0,0.0,1.0,13.0,16.0,1.0,0.0]
    , [0.0,0.0,5.0,12.0,1.0,0.0,0.0,0.0,0.0,0.0,15.0,14.0,7.0,0.0,0.0,0.0,0.0,0.0,13.0,1.0,12.0,0.0,0.0,0.0,0.0,2.0,10.0,0.0,14.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,16.0,1.0,0.0,0.0,0.0,0.0,0.0,6.0,15.0,0.0,0.0,0.0,0.0,0.0,9.0,16.0,15.0,9.0,8.0,2.0,0.0,0.0,3.0,11.0,8.0,13.0,12.0,4.0]
    , [0.0,2.0,9.0,15.0,14.0,9.0,3.0,0.0,0.0,4.0,13.0,8.0,9.0,16.0,8.0,0.0,0.0,0.0,0.0,6.0,14.0,15.0,3.0,0.0,0.0,0.0,0.0,11.0,14.0,2.0,0.0,0.0,0.0,0.0,0.0,2.0,15.0,11.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,15.0,4.0,0.0,0.0,1.0,5.0,6.0,13.0,16.0,6.0,0.0,0.0,2.0,12.0,12.0,13.0,11.0,0.0,0.0]
    , [0.0,0.0,0.0,8.0,15.0,1.0,0.0,0.0,0.0,0.0,1.0,14.0,13.0,1.0,1.0,0.0,0.0,0.0,10.0,15.0,3.0,15.0,11.0,0.0,0.0,7.0,16.0,7.0,1.0,16.0,8.0,0.0,0.0,9.0,16.0,13.0,14.0,16.0,5.0,0.0,0.0,1.0,10.0,15.0,16.0,14.0,0.0,0.0,0.0,0.0,0.0,1.0,16.0,10.0,0.0,0.0,0.0,0.0,0.0,10.0,15.0,4.0,0.0,0.0]
    , [0.0,5.0,12.0,13.0,16.0,16.0,2.0,0.0,0.0,11.0,16.0,15.0,8.0,4.0,0.0,0.0,0.0,8.0,14.0,11.0,1.0,0.0,0.0,0.0,0.0,8.0,16.0,16.0,14.0,0.0,0.0,0.0,0.0,1.0,6.0,6.0,16.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,16.0,3.0,0.0,0.0,0.0,1.0,5.0,15.0,13.0,0.0,0.0,0.0,0.0,4.0,15.0,16.0,2.0,0.0,0.0,0.0]
    , [0.0,0.0,0.0,8.0,15.0,1.0,0.0,0.0,0.0,0.0,0.0,12.0,14.0,0.0,0.0,0.0,0.0,0.0,3.0,16.0,7.0,0.0,0.0,0.0,0.0,0.0,6.0,16.0,2.0,0.0,0.0,0.0,0.0,0.0,7.0,16.0,16.0,13.0,5.0,0.0,0.0,0.0,15.0,16.0,9.0,9.0,14.0,0.0,0.0,0.0,3.0,14.0,9.0,2.0,16.0,2.0,0.0,0.0,0.0,7.0,15.0,16.0,11.0,0.0]
    , [0.0,0.0,1.0,8.0,15.0,10.0,0.0,0.0,0.0,3.0,13.0,15.0,14.0,14.0,0.0,0.0,0.0,5.0,10.0,0.0,10.0,12.0,0.0,0.0,0.0,0.0,3.0,5.0,15.0,10.0,2.0,0.0,0.0,0.0,16.0,16.0,16.0,16.0,12.0,0.0,0.0,1.0,8.0,12.0,14.0,8.0,3.0,0.0,0.0,0.0,0.0,10.0,13.0,0.0,0.0,0.0,0.0,0.0,0.0,11.0,9.0,0.0,0.0,0.0]
    , [0.0,0.0,10.0,7.0,13.0,9.0,0.0,0.0,0.0,0.0,9.0,10.0,12.0,15.0,2.0,0.0,0.0,0.0,4.0,11.0,10.0,11.0,0.0,0.0,0.0,0.0,1.0,16.0,10.0,1.0,0.0,0.0,0.0,0.0,12.0,13.0,4.0,0.0,0.0,0.0,0.0,0.0,12.0,1.0,12.0,0.0,0.0,0.0,0.0,1.0,10.0,2.0,14.0,0.0,0.0,0.0,0.0,0.0,11.0,14.0,5.0,0.0,0.0,0.0]
    , [0.0,0.0,6.0,14.0,4.0,0.0,0.0,0.0,0.0,0.0,11.0,16.0,10.0,0.0,0.0,0.0,0.0,0.0,8.0,14.0,16.0,2.0,0.0,0.0,0.0,0.0,1.0,12.0,12.0,11.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,11.0,0.0,0.0,0.0,1.0,4.0,4.0,7.0,16.0,2.0,0.0,0.0,7.0,16.0,16.0,13.0,11.0,1.0]
    ]

-- | Neighbour probabilities of 'testInput' under the default options,
-- computed once and shared by every test that needs them.
testNeighbourProbs :: [[Probability]]
testNeighbourProbs = neighbourProbabilities def testInput