{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module NumHask.Array.Fixed
(
Array (..),
with,
shape,
toDynamic,
reshape,
transpose,
diag,
selects,
selectsExcept,
folds,
extracts,
joins,
maps,
concatenate,
insert,
append,
reorder,
expand,
contract,
dot,
slice,
squeeze,
ident,
singleton,
Scalar,
fromScalar,
toScalar,
Vector,
Matrix,
col,
row,
safeCol,
safeRow,
mmult,
)
where
import Data.Distributive (Distributive (..))
import Data.Functor.Rep
import Data.List ((!!))
import qualified Data.Vector as V
import GHC.Exts (IsList (..))
import GHC.Show (Show (..))
import GHC.TypeLits
import qualified NumHask.Array.Dynamic as D
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (identity, outer, transpose)
-- | A multi-dimensional array whose shape @s@ exists only at the type level.
--
-- The elements are held in a single flat boxed 'V.Vector' ('unArray'); the
-- shape is not stored in the value and is recovered from @s@ through
-- 'HasShape' when needed.
newtype Array s a = Array {unArray :: V.Vector a} deriving (Eq, Ord, NFData, Functor, Foldable, Generic, Traversable)
-- | Arrays are rendered by converting to the dynamic representation and
-- reusing its 'Show' instance.
instance (HasShape s, Show a) => Show (Array s a) where
  show = GHC.Show.show . toDynamic
-- | 'Distributive' is obtained for free from the 'Representable' instance.
instance (HasShape s) => Data.Distributive.Distributive (Array s) where
  distribute wa = distributeRep wa
  {-# INLINE distribute #-}
-- | Arrays are representable functors indexed by a list of 'Int's
-- (one position per dimension).
instance forall s. (HasShape s) => Representable (Array s) where
  type Rep (Array s) = [Int]

  -- Build the backing vector with one element per flat position; 'shapen'
  -- turns each flat position back into the index handed to @f@.
  tabulate f = Array (V.generate (size dims) (\k -> f (shapen dims k)))
    where
      dims = shapeVal (toShape @s)
  {-# INLINE tabulate #-}

  -- 'flatten' converts the index to a flat position. The lookup itself is
  -- unchecked ('V.unsafeIndex'), so the index is trusted to be in range.
  index (Array v) ix = V.unsafeIndex v (flatten dims ix)
    where
      dims = shapeVal (toShape @s)
  {-# INLINE index #-}
-- | Element-wise addition; 'zero' is the array filled with the element 'zero'.
instance (Additive a, HasShape s) => Additive (Array s a) where
  a + b = liftR2 (+) a b
  zero = pureRep zero
-- | Element-wise negation.
instance (Subtractive a, HasShape s) => Subtractive (Array s a) where
  negate a = fmapRep negate a
-- | The 'Actor' associated with an array is its element type @a@
-- (used by the 'MultiplicativeAction' instance for 'Array').
type instance Actor (Array s a) = a
-- | Hadamard product: multiply two same-shaped arrays element by element.
instance (Multiplicative a, HasShape s) => HadamardMultiplication (Array s) a where
  a .*. b = liftR2 (*) a b
-- | Hadamard division: divide two same-shaped arrays element by element.
instance (Divisive a, HasShape s) => HadamardDivision (Array s) a where
  a ./. b = liftR2 (/) a b
-- | Multiplication of every element by a value of the element type.
instance (HasShape s, Multiplicative a) => MultiplicativeAction (Array s a) where
  -- array on the left: each element is multiplied by @s@ on its right
  r .* s = fmap (\x -> x * s) r
  {-# INLINE (.*) #-}

  -- array on the right: each element is multiplied by @s@ on its left
  s *. r = fmap (\x -> s * x) r
  {-# INLINE (*.) #-}
-- | Element-wise join.
instance (HasShape s, JoinSemiLattice a) => JoinSemiLattice (Array s a) where
  a \/ b = liftR2 (\/) a b
-- | Element-wise meet.
instance (HasShape s, MeetSemiLattice a) => MeetSemiLattice (Array s a) where
  a /\ b = liftR2 (/\) a b
-- | 'epsilon' is the array filled with the element 'epsilon'; an array is
-- 'nearZero' when every one of its elements is.
instance (HasShape s, Subtractive a, Epsilon a) => Epsilon (Array s a) where
  epsilon = tabulate (const epsilon)
  nearZero = all nearZero . unArray
-- | Conversion to and from flat lists of elements.
--
-- 'fromList' throws @NumHaskException "shape mismatch"@ unless the list
-- length equals the size of the shape @s@ (or the list is a singleton and
-- the shape is empty).
instance (HasShape s) => IsList (Array s a) where
  type Item (Array s a) = a

  fromList l
    | scalarCase || sizeMatches = Array (V.fromList l)
    | otherwise = throw (NumHaskException "shape mismatch")
    where
      ds = shapeVal (toShape @s)
      -- a single element for an array with no dimensions
      scalarCase = length l == 1 && null ds
      -- exactly as many elements as the shape holds
      sizeMatches = length l == size ds

  toList = V.toList . unArray
-- | The shape of an array as a list of 'Int's.
--
-- The value argument is never inspected: the result comes entirely from the
-- type-level shape @s@.
shape :: forall a s. (HasShape s) => Array s a -> [Int]
shape = const (shapeVal (toShape @s))
{-# INLINE shape #-}
-- | Convert to a dynamic array, pairing the value-level shape with the flat
-- list of elements.
toDynamic :: (HasShape s) => Array s a -> D.Array a
toDynamic a = D.fromFlatList dims elems
  where
    dims = shape a
    elems = P.toList a
-- | Run a function expecting a fixed-shape array on the contents of a
-- dynamic array.
--
-- The dynamic array's own shape is discarded and its element vector is
-- rewrapped as-is; no check is made that it agrees with @s@.
with ::
  forall a r s.
  (HasShape s) =>
  D.Array a ->
  (Array s a -> r) ->
  r
with dyn k = case dyn of
  D.Array _ v -> k (Array v)
-- | Reinterpret an array under a different shape of the same total size.
--
-- Each index of the result is passed through 'flatten' for the new shape and
-- then 'shapen' for the old shape to find the source element.
reshape ::
  forall a s s'.
  ( Size s ~ Size s',
    HasShape s,
    HasShape s'
  ) =>
  Array s a ->
  Array s' a
reshape a = tabulate (\target -> index a (shapen from (flatten to target)))
  where
    from = shapeVal (toShape @s)
    to = shapeVal (toShape @s')
-- | Reverse the order of all dimensions: the element at index @i@ of the
-- result is the element at @reverse i@ of the input.
transpose :: forall a s. (HasShape s, HasShape (Reverse s)) => Array s a -> Array (Reverse s) a
transpose a = tabulate (\i -> index a (reverse i))
-- | The identity array: 'one' wherever all components of the index are
-- equal, 'zero' everywhere else.
--
-- Indices with fewer than two components always count as diagonal.
ident :: forall a s. (HasShape s, Additive a, Multiplicative a) => Array s a
ident = tabulate (\i -> if allEqual i then one else zero)
  where
    -- adjacent components are compared pairwise along the index
    allEqual (x : rest@(y : _)) = x == y && allEqual rest
    allEqual _ = True
-- | Extract the diagonal as a one-dimensional array.
--
-- Position @k@ of the result is the element whose index is @k@ repeated once
-- per dimension of the input. Throws @NumHaskException "Rank Underflow"@ if
-- asked for an element at an empty index.
diag ::
  forall a s.
  ( HasShape s,
    HasShape '[Minimum s]
  ) =>
  Array s a ->
  Array '[Minimum s] a
diag a = tabulate pick
  where
    -- number of dimensions of the input
    rnk = length (shapeVal (toShape @s))
    pick ix = case ix of
      (k : _) -> index a (replicate rnk k)
      [] -> throw (NumHaskException "Rank Underflow")
-- | An array in which every element is the given value.
singleton :: (HasShape s) => a -> Array s a
singleton a = tabulate (\_ -> a)
-- | Fix the positions @i@ along the dimensions @ds@ and return what remains.
--
-- Each index @s@ of the result is looked up in the source at
-- @addIndexes s ds i@.
selects ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ DropIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selects _ i a = tabulate (\s -> index a (addIndexes s ds i))
  where
    ds = shapeVal (toShape @ds)
-- | Keep only the dimensions @ds@, using @i@ for every other dimension.
--
-- Each index @s@ of the result is looked up in the source at
-- @addIndexes i ds s@ (the argument roles are swapped relative to 'selects').
selectsExcept ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ TakeIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selectsExcept _ i a = tabulate (\s -> index a (addIndexes i ds s))
  where
    ds = shapeVal (toShape @ds)
-- | Reduce sub-arrays with @f@.
--
-- For every index @s@ of the result, the sub-array @selects d s a@ is
-- extracted and handed to @f@.
folds ::
  forall ds st si so a b.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  (Array si a -> b) ->
  Proxy ds ->
  Array st a ->
  Array so b
folds f d a = tabulate (\s -> f (selects d s a))
-- | Split an array into an array of sub-arrays.
--
-- The element at index @s@ of the result is @selects d s a@.
extracts ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extracts d a = tabulate (\s -> selects d s a)
-- | Counterpart of 'extracts' built on 'selectsExcept': the outer array has
-- shape @DropIndexes st ds@ and each inner array has shape
-- @TakeIndexes st ds@.
--
-- The element at index @s@ of the result is @selectsExcept d s a@.
extractsExcept ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    so ~ DropIndexes st ds,
    si ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extractsExcept d a = tabulate (\s -> selectsExcept d s a)
-- | Flatten an array of arrays into a single array.
--
-- For each index @s@ of the result, @takeIndexes s ds@ addresses the outer
-- array and @dropIndexes s ds@ addresses the inner array found there.
joins ::
  forall ds si st so a.
  ( HasShape st,
    HasShape ds,
    st ~ AddIndexes si ds so,
    HasShape si,
    HasShape so
  ) =>
  Proxy ds ->
  Array so (Array si a) ->
  Array st a
joins _ a = tabulate go
  where
    ds = shapeVal (toShape @ds)
    go s =
      let cell = index a (takeIndexes s ds)
       in index cell (dropIndexes s ds)
-- | Apply a function to every sub-array and reassemble the results.
--
-- Implemented as 'extracts', then mapping @f@ over the outer array, then
-- 'joins' along the same dimensions.
maps ::
  forall ds st st' si si' so a b.
  ( HasShape st,
    HasShape st',
    HasShape ds,
    HasShape si,
    HasShape si',
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds,
    st' ~ AddIndexes si' ds so,
    st ~ AddIndexes si ds so
  ) =>
  (Array si a -> Array si' b) ->
  Proxy ds ->
  Array st a ->
  Array st' b
maps f d a = joins d . fmapRep f $ extracts d a
-- | Join two arrays along dimension @d@.
--
-- For a result index whose position along @d@ is below the first array's
-- extent in that dimension, the element comes from the first array at the
-- same index. Otherwise it comes from the second array, with the position
-- along @d@ reduced by the first array's extent.
concatenate ::
  forall a s0 s1 d s.
  ( CheckConcatenate d s0 s1 s,
    Concatenate d s0 s1 ~ s,
    HasShape s0,
    HasShape s1,
    HasShape s,
    KnownNat d
  ) =>
  Proxy d ->
  Array s0 a ->
  Array s1 a ->
  Array s a
concatenate _ s0 s1 = tabulate go
  where
    d = fromIntegral $ natVal @d Proxy
    ds0 = shapeVal (toShape @s0)
    go s
      | pos >= len0 = index s1 (addIndex (dropIndex s d) d (pos - len0))
      | otherwise = index s0 s
      where
        -- position of this index along the concatenation dimension
        pos = s !! d
        -- extent of the first array along the concatenation dimension
        len0 = ds0 !! d
-- | Insert the lower-rank array @b@ into @a@ at position @i@ of dimension @d@.
--
-- For a result index with position @p@ along @d@:
--
-- * @p == i@: the element comes from @b@, at the index with dimension @d@
--   removed;
-- * @p < i@: the element comes from @a@ at the same index;
-- * otherwise: the element comes from @a@ at @decAt d s@ (presumably the
--   same index with position @d@ decremented — confirm against
--   "NumHask.Array.Shape").
insert ::
  forall a s s' d i.
  ( DropIndex s d ~ s',
    CheckInsert d i s,
    KnownNat i,
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Proxy i ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
insert _ _ a b = tabulate go
  where
    d = fromIntegral $ natVal @d Proxy
    i = fromIntegral $ natVal @i Proxy
    go s =
      let pos = s !! d
       in if pos == i
            then index b (dropIndex s d)
            else
              if pos < i
                then index a s
                else index a (decAt d s)
-- | 'insert' specialised to position @Dimension s d - 1@ of dimension @d@.
--
-- NOTE(review): given 'insert''s guards, position @Dimension s d - 1@ puts
-- the new slice immediately before the slice of @a@ found at that position,
-- rather than after it. If @Dimension s d@ is the extent of @s@ along @d@,
-- that means "before the last slice" — confirm this is the intended meaning
-- of "append".
append ::
  forall a d s s'.
  ( DropIndex s d ~ s',
    CheckInsert d (Dimension s d - 1) s,
    KnownNat (Dimension s d - 1),
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
append d a b = insert d (Proxy :: Proxy (Dimension s d - 1)) a b
-- | Rearrange the dimensions of an array according to @ds@.
--
-- Each index @s@ of the result is looked up in the source at
-- @addIndexes [] ds s@.
reorder ::
  forall a ds s.
  ( HasShape ds,
    HasShape s,
    HasShape (Reorder s ds),
    CheckReorder ds s
  ) =>
  Proxy ds ->
  Array s a ->
  Array (Reorder s ds) a
reorder _ a = tabulate (\s -> index a (addIndexes [] ds s))
  where
    ds = shapeVal (toShape @ds)
-- | Combine every element of the first array with every element of the
-- second, producing an array whose shape is the two shapes appended.
--
-- A result index is split after the first array's rank: the leading part
-- addresses the first array, the remainder addresses the second.
expand ::
  forall s s' a b c.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  (a -> b -> c) ->
  Array s a ->
  Array s' b ->
  Array ((++) s s') c
expand f a b = tabulate combine
  where
    r = rank (shapeVal (toShape @s))
    combine i = f (index a left) (index b right)
      where
        left = take r i
        right = drop r i
-- | Contract an array over the dimensions @ds@.
--
-- The array is split with 'extractsExcept'; for each resulting sub-array the
-- diagonal is taken with 'diag' and reduced with @f@.
contract ::
  forall a b s ss s' ds.
  ( KnownNat (Minimum (TakeIndexes s ds)),
    HasShape (TakeIndexes s ds),
    HasShape s,
    HasShape ds,
    HasShape ss,
    HasShape s',
    s' ~ DropIndexes s ds,
    ss ~ '[Minimum (TakeIndexes s ds)]
  ) =>
  (Array ss a -> b) ->
  Proxy ds ->
  Array s a ->
  Array s' b
contract f xs a = fmap (\sub -> f (diag sub)) (extractsExcept xs a)
-- | Generalised dot product.
--
-- The two arrays are first combined element-wise-over-all-pairs with @g@
-- ('expand'), and the result is then contracted with @f@ over the last
-- dimension of the first array and the first dimension of the second
-- (positions @Rank sa - 1@ and @Rank sa@ of the combined shape).
dot ::
  forall a b c d sa sb s' ss se.
  ( HasShape sa,
    HasShape sb,
    HasShape (sa ++ sb),
    se ~ TakeIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape se,
    KnownNat (Minimum se),
    KnownNat (Rank sa - 1),
    KnownNat (Rank sa),
    ss ~ '[Minimum se],
    HasShape ss,
    s' ~ DropIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape s'
  ) =>
  (Array ss c -> d) ->
  (a -> b -> c) ->
  Array sa a ->
  Array sb b ->
  Array s' d
dot f g a b = contract f axes combined
  where
    -- the two adjacent dimensions where the operands meet
    axes = Proxy :: Proxy '[Rank sa - 1, Rank sa]
    combined = expand g a b
-- | Select, per dimension, a list of positions to keep.
--
-- @pss@ supplies one list of source positions per dimension. Component @k@
-- of a result index is replaced by the @k@-th entry of the corresponding
-- list to obtain the source index.
slice ::
  forall (pss :: [[Nat]]) s s' a.
  ( HasShape s,
    HasShape s',
    KnownNatss pss,
    KnownNat (Rank pss),
    s' ~ Ranks pss
  ) =>
  Proxy pss ->
  Array s a ->
  Array s' a
slice pss a = tabulate (\s -> index a (zipWith (\choices k -> choices !! k) table s))
  where
    table = natValss pss
-- | Retype an array to the shape @Squeeze s@.
--
-- The underlying vector is reused unchanged; only the type-level shape
-- differs.
squeeze ::
  forall s t a.
  (t ~ Squeeze s) =>
  Array s a ->
  Array t a
squeeze arr = Array (unArray arr)
-- | An array with an empty shape (no dimensions).
type Scalar a = Array ('[] :: [Nat]) a
-- | Unwrap a scalar array by looking up the element at the empty index.
fromScalar :: (HasShape ('[] :: [Nat])) => Array ('[] :: [Nat]) a -> a
fromScalar a = a `index` emptyIx
  where
    emptyIx = [] :: [Int]
-- | Wrap a single value as a scalar array (via 'fromList' on a one-element
-- list).
toScalar :: (HasShape ('[] :: [Nat])) => a -> Array ('[] :: [Nat]) a
toScalar a = fromList (a : [])
-- | A one-dimensional array with @s@ as its single dimension.
type Vector s a = Array '[s] a
-- | A two-dimensional array with shape @[m, n]@.
type Matrix m n a = Array '[m, n] a
-- | Square matrices multiply by matrix multiplication ('mmult'), with the
-- identity matrix ('ident') as 'one'.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    KnownNat m,
    HasShape '[m, m]
  ) =>
  Multiplicative (Matrix m m a)
  where
  a * b = mmult a b
  one = ident
-- | Extract row @i@ of a matrix as a vector of length @n@.
--
-- The row is the contiguous run of @n@ elements starting at flat position
-- @i * n@ of the backing vector.
row :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector n a
row i arr = Array (V.slice start n (unArray arr))
  where
    n = fromIntegral $ natVal @n Proxy
    start = i * n
-- | Extract row @j@ of a matrix, with the row number supplied at the type
-- level and constrained by @CheckIndex j m@.
--
-- The proxy argument only carries the type; the row number is read back with
-- 'natVal'. The row is the contiguous run of @n@ elements starting at flat
-- position @j * n@.
safeRow :: forall m n a j. ('True ~ CheckIndex j m, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector n a
safeRow _j arr = Array (V.slice start n (unArray arr))
  where
    n = fromIntegral $ natVal @n Proxy
    j = fromIntegral $ natVal @j Proxy
    start = j * n
-- | Extract column @i@ of a matrix as a vector of length @m@.
--
-- A column has one element per row, so the result is a @Vector m a@. (The
-- result was previously typed @Vector n a@ although @m@ elements were
-- generated, which made the type-level shape disagree with the backing
-- vector for non-square matrices.)
--
-- Element @x@ of the result is read from flat position @i + x * n@ of the
-- backing vector. Throws @NumHaskException@ when @i@ is not a valid column
-- number (@0 <= i < n@), because the element reads themselves are unchecked.
col :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector m a
col i (Array a)
  | i < 0 || i >= n = throw (NumHaskException "column index out of bounds")
  | otherwise = Array $ V.generate m (\x -> V.unsafeIndex a (i + x * n))
  where
    m = fromIntegral $ natVal @m Proxy
    n = fromIntegral $ natVal @n Proxy
-- | Extract column @j@ of a matrix, with the column number supplied at the
-- type level and constrained by @CheckIndex j n@.
--
-- A column has one element per row, so the result is a @Vector m a@. (The
-- result was previously typed @Vector n a@ although @m@ elements were
-- generated, which made the type-level shape disagree with the backing
-- vector for non-square matrices.)
--
-- The proxy argument only carries the type; the column number is read back
-- with 'natVal'. Element @x@ of the result is read from flat position
-- @j + x * n@ of the backing vector.
safeCol :: forall m n a j. ('True ~ CheckIndex j n, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector m a
safeCol _j (Array a) = Array $ V.generate m (\x -> V.unsafeIndex a (j + x * n))
  where
    m = fromIntegral $ natVal @m Proxy
    n = fromIntegral $ natVal @n Proxy
    j = fromIntegral $ natVal @j Proxy
-- | Matrix multiplication of an @m x k@ matrix by a @k x n@ matrix.
--
-- Element @(i, j)@ of the result is the sum of the element-wise products of
-- row @i@ of the first matrix (a contiguous slice of length @k@) and column
-- @j@ of the second (every @n@-th element starting at @j@).
--
-- Throws @NumHaskException "Needs two dimensions"@ if asked for an element
-- at an index with fewer than two components.
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult (Array x) (Array y) = tabulate go
  where
    n = fromIntegral $ natVal @n Proxy
    k = fromIntegral $ natVal @k Proxy
    go (i : j : _) = sum (V.zipWith (*) rowI colJ)
      where
        rowI = V.slice (fromIntegral i * k) k x
        colJ = V.generate k (\x' -> y V.! (fromIntegral j + x' * n))
    go _ = throw (NumHaskException "Needs two dimensions")
{-# INLINE mmult #-}