{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.Test.NoFib.Spectral.SMVM
-- Copyright   : [2009..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Property test for sparse-matrix / dense-vector multiplication (SMVM).
-- A randomly generated sparse matrix and dense vector are multiplied by an
-- Accelerate program ('smvm'), executed through the supplied 'RunN', and the
-- result is compared against a plain list-based implementation ('smvmRef').
--

module Data.Array.Accelerate.Test.NoFib.Spectral.SMVM (

  test_smvm,

) where

import Prelude                                                      as P

import Data.Array.Accelerate                                        as A
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Test.NoFib.Base
import Data.Array.Accelerate.Test.NoFib.Config
import Data.Array.Accelerate.Test.Similar

import Hedgehog
import qualified Hedgehog.Gen                                       as Gen
import qualified Hedgehog.Range                                     as Range

import Test.Tasty
import Test.Tasty.Hedgehog


-- | Tasty group @"smvm"@ holding one SMVM property per floating-point element
-- type (half, single and double precision). Each property is wrapped with
-- 'at' for its matching test type; presumably this lets the test
-- configuration enable or skip that element type (TODO confirm against
-- "Data.Array.Accelerate.Test.NoFib.Config").
test_smvm :: RunN -> TestTree
test_smvm runN =
  testGroup "smvm"
    [ at @TestHalf   (propertyFor f16)
    , at @TestFloat  (propertyFor f32)
    , at @TestDouble (propertyFor f64)
    ]
  where
    -- Each property is named after the shown representation of element
    -- type @a@; the scoped @forall@ is what makes @eltR \@a@ refer to it.
    propertyFor
        :: forall a. (P.Num a, A.Num a, Similar a, Show a)
        => Gen a
        -> TestTree
    propertyFor genElt =
      testProperty (show (eltR @a)) (test_smvm' runN genElt)


-- | For a random sparse matrix and a dense vector whose length equals the
-- matrix's column count, the result of running 'smvm' through @runN@ must be
-- similar (in the sense of '~~~') to the reference result from 'smvmRef'.
test_smvm'
    :: (A.Num e, P.Num e, Similar e, Show e)
    => RunN
    -> Gen e
    -> Property
test_smvm' runN e =
  property $ do
    (smat, cols) <- forAll (sparseMatrix e)
    vec          <- forAll (array (Z :. cols) e)
    -- The bang forces the function produced by 'runN' to WHNF before it is
    -- applied to the generated inputs.
    let !smvmN = runN smvm
    smvmN smat vec ~~~ smvmRef smat vec


-- | Generate a random sparse matrix together with its number of columns.
--
-- The matrix has between 1 and 256 rows and between 1 and 256 columns. Each
-- row stores between 0 and @cols@ entries. Every entry pairs a column index
-- drawn from @[0, cols-1]@ with a value from the supplied generator. Indices
-- are drawn independently, so a row may name the same column more than once.
sparseMatrix :: Elt e => Gen e -> Gen (SparseMatrix e, Int)
sparseMatrix e = do
  nrows   <- Gen.int (Range.linear 1 256)
  ncols   <- Gen.int (Range.linear 1 256)
  rowLens <- array (Z :. nrows) (Gen.int (Range.linear 0 ncols))
  let nnz   = P.sum (toList rowLens)
      entry = (,) <$> Gen.int (Range.linear 0 (ncols - 1)) <*> e
  entries <- array (Z :. nnz) entry
  return ((rowLens, entries), ncols)


-- | A sparse vector stored as a flat array of @(column index, value)@ entries.
type SparseVector e = Vector (Int, e)

-- | A sparse matrix: a segment descriptor holding the number of entries in
-- each row, paired with the entries of all rows concatenated in row order.
type SparseMatrix e = (Segments Int, SparseVector e)


-- | Accelerate implementation of sparse-matrix / dense-vector multiplication.
-- Each stored value is multiplied by the vector element at its column index.
-- The products are then summed per row with a segmented fold whose initial
-- value is 0, so a row with no entries yields 0.
smvm :: A.Num a => Acc (SparseMatrix a) -> Acc (Vector a) -> Acc (Vector a)
smvm smat vec = foldSeg (+) 0 products segd
  where
    (segd, svec) = unlift smat
    (inds, vals) = A.unzip svec
    products     = A.zipWith (*) (A.gather inds vec) vals


-- | Reference implementation on ordinary Haskell lists. The flat entry list
-- is cut into rows using the segment lengths. Each row's result is the sum of
-- @value * vec[column]@ over its entries, which is 0 for an empty row. The
-- output has the same shape as the segment descriptor (one result per row).
smvmRef :: (Elt a, P.Num a) => SparseMatrix a -> Vector a -> Vector a
smvmRef (segd, smat) vec =
  fromList (arrayShape segd) (P.map rowDot rows)
  where
    rows       = splitPlaces (toList segd) (toList smat)
    rowDot row = P.sum [ val * indexArray vec (Z :. i) | (i, val) <- row ]