module Data.Array.Accelerate.TypeLits.Internal where
import GHC.TypeLits ( Nat, KnownNat, natVal)
import Control.Monad (replicateM)
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as I
import Data.Proxy (Proxy(..))
import Data.Array.Accelerate ( (:.)((:.)), Array
, Exp
, DIM0, DIM1, DIM2, Z(Z)
, Elt, Acc
)
import Test.SmallCheck.Series
import Test.QuickCheck.Arbitrary
-- | Wrapper around an Accelerate computation that yields a
-- zero-dimensional ('DIM0') array with elements of type @a@.
newtype AccScalar a = AccScalar
  { unScalar :: Acc (Array DIM0 a)
    -- ^ Unwrap the underlying Accelerate computation.
  }
  deriving (Show)
-- | Equality is decided by evaluating both computations with the
-- Accelerate interpreter ("Data.Array.Accelerate.Interpreter") and
-- comparing the resulting element lists.
instance forall a. (Eq a, Elt a) => Eq (AccScalar a) where
  lhs == rhs = elems lhs == elems rhs
    where
      -- Run the wrapped computation and flatten the result to a list.
      elems x = A.toList (I.run (unScalar x))
-- | Wrapper around an Accelerate computation that yields a
-- one-dimensional ('DIM1') array. The type-level natural @dim@ is a
-- phantom parameter; the smart constructors in this module use it as
-- the vector's length.
newtype AccVector (dim :: Nat) a = AccVector
  { unVector :: Acc (Array DIM1 a)
    -- ^ Unwrap the underlying Accelerate computation.
  }
  deriving (Show)
-- | Equality is decided by evaluating both computations with the
-- Accelerate interpreter and comparing the resulting element lists.
instance forall n a. (KnownNat n, Eq a, Elt a) => Eq (AccVector n a) where
  lhs == rhs = elems lhs == elems rhs
    where
      -- Run the wrapped computation and flatten the result to a list.
      elems x = A.toList (I.run (unVector x))
-- | SmallCheck series of vectors of length @n@.
--
-- Each generated vector is built by repeating a single element drawn
-- from the element type's series (@cons1 (replicate len)@), so only
-- vectors whose entries are all equal are enumerated.
instance forall mm n a. (Serial mm a, KnownNat n, Eq a, Elt a)
    => Serial mm (AccVector n a) where
  series = fmap build (cons1 (replicate len))
    where
      -- Vector length taken from the type-level natural @n@.
      len :: Int
      len = fromIntegral (natVal (Proxy :: Proxy n))

      -- Lift a plain list into an embedded Accelerate vector.
      build xs = AccVector (A.use (A.fromList (Z :. len) xs))
-- | QuickCheck generator for vectors of length @n@: draws @n@
-- arbitrary elements and packs them into an Accelerate array.
instance forall n a. (KnownNat n, Arbitrary a, Eq a, Elt a)
    => Arbitrary (AccVector n a) where
  arbitrary = fmap build (replicateM len arbitrary)
    where
      -- Vector length taken from the type-level natural @n@.
      len :: Int
      len = fromIntegral (natVal (Proxy :: Proxy n))

      -- Lift a plain list into an embedded Accelerate vector.
      build xs = AccVector (A.use (A.fromList (Z :. len) xs))
-- | Wrapper around an Accelerate computation that yields a
-- two-dimensional ('DIM2') array. The type-level naturals @rows@ and
-- @cols@ are phantom parameters; the smart constructors in this module
-- use them as the matrix extent.
newtype AccMatrix (rows :: Nat) (cols :: Nat) a = AccMatrix
  { unMatrix :: Acc (Array DIM2 a)
    -- ^ Unwrap the underlying Accelerate computation.
  }
  deriving (Show)
-- | Equality is decided by evaluating both computations with the
-- Accelerate interpreter and comparing the flattened element lists.
-- Only the elements are compared; the array shape itself is not
-- inspected here.
instance forall m n a. (KnownNat m, KnownNat n, Eq a, Elt a)
    => Eq (AccMatrix m n a) where
  lhs == rhs = elems lhs == elems rhs
    where
      -- Run the wrapped computation and flatten the result to a list.
      elems x = A.toList (I.run (unMatrix x))
-- | SmallCheck series of @m@-by-@n@ matrices.
--
-- Each generated matrix is built by repeating a single element drawn
-- from the element type's series @m * n@ times, so only matrices whose
-- entries are all equal are enumerated.
instance forall mm m n a. (Serial mm a, KnownNat m, KnownNat n, Eq a, Elt a)
    => Serial mm (AccMatrix m n a) where
  series = fmap build (cons1 (replicate (rows * cols)))
    where
      -- Extent taken from the type-level naturals @m@ and @n@.
      rows, cols :: Int
      rows = fromIntegral (natVal (Proxy :: Proxy m))
      cols = fromIntegral (natVal (Proxy :: Proxy n))

      -- Lift a plain list into an embedded Accelerate matrix.
      build xs = AccMatrix (A.use (A.fromList (Z :. rows :. cols) xs))
-- | QuickCheck generator for @m@-by-@n@ matrices: draws @m * n@
-- arbitrary elements and packs them into an Accelerate array.
instance forall m n a. (KnownNat m, KnownNat n, Arbitrary a, Eq a, Elt a)
    => Arbitrary (AccMatrix m n a) where
  arbitrary = fmap build (replicateM (rows * cols) arbitrary)
    where
      -- Extent taken from the type-level naturals @m@ and @n@.
      rows, cols :: Int
      rows = fromIntegral (natVal (Proxy :: Proxy m))
      cols = fromIntegral (natVal (Proxy :: Proxy n))

      -- Lift a plain list into an embedded Accelerate matrix.
      build xs = AccMatrix (A.use (A.fromList (Z :. rows :. cols) xs))
-- | Functor-like interface for containers of embedded Accelerate
-- values: the mapped function works on 'Exp' terms rather than on
-- plain Haskell values.
class AccFunctor f where
  -- | Apply an embedded function to every element of the container.
  afmap
    :: forall a b. (Elt a, Elt b)
    => (Exp a -> Exp b)
    -> f a
    -> f b
-- | Maps over the single element via 'A.map' on the wrapped array.
instance AccFunctor AccScalar where
  afmap g = AccScalar . A.map g . unScalar
-- | Maps over every element via 'A.map' on the wrapped array; the
-- length index @n@ is unchanged.
instance forall n. (KnownNat n) => AccFunctor (AccVector n) where
  afmap g = AccVector . A.map g . unVector
-- | Maps over every element via 'A.map' on the wrapped array; the
-- extent indices @m@ and @n@ are unchanged.
instance forall m n. (KnownNat m, KnownNat n) => AccFunctor (AccMatrix m n) where
  afmap g = AccMatrix . A.map g . unMatrix
-- | Checked constructor for 'AccVector'.
--
-- Returns 'Just' the vector when the list has exactly @n@ elements
-- (with @n@ read from the result type) and 'Nothing' otherwise.
mkVector :: forall n a. (KnownNat n, Elt a) => [a] -> Maybe (AccVector n a)
mkVector xs
  | length xs == expected = Just (unsafeMkVector xs)
  | otherwise             = Nothing
  where
    -- Required length, taken from the type-level natural @n@.
    expected :: Int
    expected = fromIntegral (natVal (Proxy :: Proxy n))
-- | Unchecked constructor for 'AccVector'.
--
-- Builds an array of extent @Z :. n@ from the list without comparing
-- the list length against @n@; use 'mkVector' for the checked variant.
-- NOTE(review): the result for a list whose length differs from @n@ is
-- whatever 'A.fromList' does in that case -- confirm against its docs.
unsafeMkVector :: forall n a. (KnownNat n, Elt a) => [a] -> AccVector n a
unsafeMkVector xs = AccVector . A.use $ A.fromList extent xs
  where
    -- One-dimensional shape whose size is the type-level natural @n@.
    extent :: DIM1
    extent = Z :. fromIntegral (natVal (Proxy :: Proxy n))
-- | Checked constructor for 'AccMatrix'.
--
-- Returns 'Just' the matrix when the list has exactly @m * n@ elements
-- (with @m@ and @n@ read from the result type) and 'Nothing' otherwise.
mkMatrix
  :: forall m n a. (KnownNat m, KnownNat n, Elt a)
  => [a]
  -> Maybe (AccMatrix m n a)
mkMatrix xs
  | length xs == rows * cols = Just (unsafeMkMatrix xs)
  | otherwise                = Nothing
  where
    -- Extent taken from the type-level naturals @m@ and @n@.
    rows, cols :: Int
    rows = fromIntegral (natVal (Proxy :: Proxy m))
    cols = fromIntegral (natVal (Proxy :: Proxy n))
-- | Unchecked constructor for 'AccMatrix'.
--
-- Builds an array of extent @Z :. m :. n@ from the list without
-- comparing the list length against @m * n@; use 'mkMatrix' for the
-- checked variant.
-- NOTE(review): the result for a list whose length differs from
-- @m * n@ is whatever 'A.fromList' does in that case -- confirm
-- against its docs.
unsafeMkMatrix
  :: forall m n a. (KnownNat m, KnownNat n, Elt a)
  => [a]
  -> AccMatrix m n a
unsafeMkMatrix xs = AccMatrix . A.use $ A.fromList extent xs
  where
    -- Two-dimensional shape built from the type-level naturals.
    extent :: DIM2
    extent = Z :. rows :. cols

    rows, cols :: Int
    rows = fromIntegral (natVal (Proxy :: Proxy m))
    cols = fromIntegral (natVal (Proxy :: Proxy n))
-- | Wrap an embedded expression as an 'AccScalar' by turning it into a
-- zero-dimensional array with 'A.unit'.
mkScalar :: forall a. Elt a => Exp a -> AccScalar a
mkScalar e = AccScalar (A.unit e)