module Data.Array.Accelerate.LinearAlgebra.Private where
import qualified Data.Array.Accelerate.Utility.Loop as Loop
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
(Acc, Array, Exp, Any(Any), All(All), Z(Z), (:.)((:.)))
-- | A batch of scalars: one element for every index of the batch shape @sh@.
type Scalar sh e = Acc (Array sh e)

-- | A batch of vectors: the innermost dimension indexes the vector elements.
type Vector sh e = Acc (Array (sh :. Int) e)

-- | A batch of matrices: the two innermost dimensions are rows and columns.
type Matrix sh e = Acc (Array (sh :. Int :. Int) e)
-- | Transpose every matrix of a batch by swapping the two innermost
-- dimensions (rows and columns); the batch prefix is left untouched.
transpose ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Matrix ix a
transpose m = A.backpermute swappedShape (withMatrixIndex swapIndex) m
   where
      -- target extent: rows and columns exchanged
      swappedShape = A.lift (swapIndex (matrixShape m))
-- | Exchange the two innermost components of an unlifted matrix index,
-- keeping the batch prefix.
swapIndex ::
   Exp ix :. Exp Int :. Exp Int ->
   Exp ix :. Exp Int :. Exp Int
swapIndex idx =
   case idx of
      prefix :. row :. col -> prefix :. col :. row
-- | Length of the vectors in a batch, i.e. the size of the innermost dimension.
numElems :: (A.Shape ix, A.Slice ix, A.Elt a) => Vector ix a -> Exp Int
numElems = (\(_batch :. len) -> len) . vectorShape
-- | Number of rows of the matrices in a batch (second-innermost dimension).
numRows :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numRows = (\(_batch :. rows :. _cols) -> rows) . matrixShape
-- | Number of columns of the matrices in a batch (innermost dimension).
numCols :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numCols = (\(_batch :. _rows :. cols) -> cols) . matrixShape
-- | Shape of a vector batch, unlifted into the batch index and the vector
-- length so both parts can be pattern matched.
vectorShape ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Vector ix a -> Exp ix :. Exp Int
vectorShape = A.unlift . A.shape
-- | Shape of a matrix batch, unlifted into the batch index, the number of
-- rows and the number of columns.
matrixShape ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Exp ix :. Exp Int :. Exp Int
matrixShape = A.unlift . A.shape
-- | Turn a function on an unlifted vector index (batch index @:.@ element
-- index) into one on lifted 'Exp' indices, lifting its result.  Handy for
-- writing index maps for 'A.backpermute' and 'A.generate'.
withVectorIndex ::
   (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
   (Exp ix :. Exp Int -> a) ->
   (Exp (ix :. Int) -> Exp (A.Plain a))
withVectorIndex f ix = A.lift (f (A.unlift ix))
-- | Turn a function on an unlifted matrix index (batch index @:.@ row @:.@
-- column) into one on lifted 'Exp' indices, lifting its result.
withMatrixIndex ::
   (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
   (Exp ix :. Exp Int :. Exp Int -> a) ->
   (Exp (ix :. Int :. Int) -> Exp (A.Plain a))
withMatrixIndex f ix = A.lift (f (A.unlift ix))
-- | Batched outer product.  For vectors @x@ of length @m@ and @y@ of length
-- @n@ the result is an @m@ by @n@ matrix whose element @(i, j)@ is
-- @x!i * y!j@.
outer ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Vector ix a -> Vector ix a -> Matrix ix a
outer x y = A.zipWith (*) xAlongRows yAlongCols
   where
      -- x!i is copied into every column of row i
      xAlongRows = A.replicate (A.lift $ Any :. All :. numElems y) x
      -- y!j is copied into every row of column j
      yAlongCols = A.replicate (A.lift $ Any :. numElems x :. All) y
-- | Batched matrix-vector product.
--
-- For a matrix with shape @ix :. rows :. cols@ and a vector with shape
-- @ix :. cols@ the result has shape @ix :. rows@.  The vector is copied
-- into every row, multiplied element-wise with the matrix, and every row is
-- summed.
--
-- The row sums use 'A.fold' with the neutral element @0@ instead of
-- 'A.fold1'.  'A.fold1' requires a non-empty innermost dimension, so a
-- matrix with zero columns would fail.  With 'A.fold' it yields a zero vector.
--
-- NOTE(review): the vector length is not checked against @cols@; 'A.zipWith'
-- works on the intersection of both extents.
multiplyMatrixVector ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyMatrixVector m v =
   case matrixShape m of
      (_ix :. rows :. _cols) ->
         A.fold (+) 0 $
         A.zipWith (*) m
            (A.replicate (A.lift $ Any :. rows :. All) v)
-- | Batched matrix product.
--
-- For @x@ with shape @ix :. rows :. k@ and @y@ with shape @ix :. k :. cols@
-- the result has shape @ix :. rows :. cols@.  Both operands are replicated
-- into a common @ix :. rows :. k :. cols@ array and multiplied element-wise.
-- The two innermost dimensions are then swapped so that @k@ becomes
-- innermost, and @k@ is summed away.
--
-- The summation uses 'A.fold' with the neutral element @0@ instead of
-- 'A.fold1'.  'A.fold1' requires a non-empty innermost dimension, so an
-- inner dimension @k == 0@ would fail.  With 'A.fold' it yields a zero matrix.
--
-- NOTE(review): mismatching inner dimensions are not detected; 'A.zipWith'
-- works on the intersection of both extents.
multiplyMatrixMatrix ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
multiplyMatrixMatrix x y =
   case (matrixShape x, matrixShape y) of
      (_ :. rows :. _cols, _ :. _rows :. cols) ->
         A.fold (+) 0 $ transpose $
         A.zipWith (*)
            (A.replicate (A.lift $ Any :. All :. All :. cols) x)
            (A.replicate (A.lift $ Any :. rows :. All :. All) y)
-- | One Newton–Schulz iteration towards the inverse of @a@:
--
-- > x' = 2*x - x*(a*x)
--
-- where the matrix products are 'multiplyMatrixMatrix'.
--
-- The original code passed @()@ to 'A.zipWith'.  That is not a binary
-- function and does not type-check.  The step needs element-wise
-- subtraction @(-)@.
newtonInverseStep ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
newtonInverseStep a x =
   A.zipWith (-) (A.map (2*) x) $
   multiplyMatrixMatrix x $ multiplyMatrixMatrix a x
-- | A batch of identity-like matrices with the given shape
-- (batch @:.@ rows @:.@ columns).  An element is 1 where its row index equals
-- its column index and 0 elsewhere.
identity ::
   (A.Shape ix, A.Slice ix, A.Elt a, A.FromIntegral Int a) =>
   Exp (ix :. Int :. Int) -> Matrix ix a
identity sh =
   A.generate sh $ withMatrixIndex $ \(_batch :. row :. col) ->
      A.fromIntegral (A.boolToInt (row A.== col))
-- | Iterate 'newtonInverseStep' for matrix @a@, starting from @seed@.
--
-- NOTE(review): presumably 'Loop.nest' applies the step @n@ times; confirm
-- against "Data.Array.Accelerate.Utility.Loop".
newtonInverse ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Exp Int ->
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
newtonInverse n seed a = Loop.nest n refine seed
   where refine = newtonInverseStep a
-- | Multiply every row of every matrix by the corresponding entry of @s@.
-- Element @(ix, r, c)@ of the result is @s(ix, r) * x(ix, r, c)@.
scaleRows ::
   (A.Slice ix, A.Shape ix, A.Num a) =>
   Vector ix a -> Matrix ix a -> Matrix ix a
scaleRows = zipScalarVectorWith (*)
-- | Combine each scalar of a batch with every element of the corresponding
-- vector.  The scalars are replicated along the vector dimension before
-- 'A.zipWith' is applied.
zipScalarVectorWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Vector ix b -> Vector ix c
zipScalarVectorWith f x ys = A.zipWith f broadcast ys
   where broadcast = A.replicate (A.lift (Any :. numElems ys)) x
-- | Combine each scalar of a batch with every element of the corresponding
-- matrix.  The scalars are replicated over rows and columns before
-- 'A.zipWith' is applied.
zipScalarMatrixWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Matrix ix b -> Matrix ix c
zipScalarMatrixWith f x ys =
   let spread (_batch :. rows :. cols) =
          A.replicate (A.lift (Any :. rows :. cols)) x
   in  A.zipWith f (spread (matrixShape ys)) ys
-- | View each vector of a batch as a single-column matrix by appending an
-- innermost dimension of size 1 to its shape.
columnFromVector ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Vector ix a -> Matrix ix a
columnFromVector v = A.reshape columnShape v
   where columnShape = Exp.indexCons (A.shape v) 1
-- | Drop the innermost (column) dimension of a matrix batch.
--
-- 'A.reshape' requires the old and new sizes to agree, so this is only
-- meaningful for single-column matrices.
vectorFromColumn ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Vector ix a
vectorFromColumn m = A.reshape (A.indexTail (A.shape m)) m
-- | Flatten every matrix of a batch into a vector by concatenating its rows.
-- 'flattenMatrix' uses the 'A.backpermute'-based variant.
flattenMatrix, flattenMatrixReshape, flattenMatrixBackPermute ::
   (A.Slice ix, A.Shape ix, A.Elt a) =>
   Matrix ix a -> Vector ix a
flattenMatrix m = flattenMatrixBackPermute m
-- Reshape-based flattening: keep the batch prefix and merge the two
-- innermost dimensions into one of length rows*cols.
flattenMatrixReshape m =
   let batch :. rows :. cols = matrixShape m
   in  A.reshape (A.lift (batch :. rows * cols)) m
-- | Floor division and modulus, like 'divMod', computed with separate 'div'
-- and 'mod' calls.
--
-- NOTE(review): presumably kept instead of 'divMod' so that only 'div' and
-- 'mod' are needed from the 'Integral' instance of Accelerate's @Exp Int@;
-- confirm before replacing it with 'divMod'.
accDivMod :: Integral a => a -> a -> (a, a)
accDivMod numerator denominator =
   let quotient  = numerator `div` denominator
       remainder = numerator `mod` denominator
   in  (quotient, remainder)
-- Backpermute-based flattening: element k of the result vector is read from
-- row (k `div` cols), column (k `mod` cols) of the source matrix.
flattenMatrixBackPermute m =
   case matrixShape m of
      batch :. rows :. cols ->
         let source (vix :. k) =
                case accDivMod k cols of
                   (r, c) -> vix :. r :. c
         in  A.backpermute
                (A.lift $ batch :. rows*cols)
                (withVectorIndex source)
                m
-- | Cut every vector of a batch into rows of the given column count, the
-- inverse of flattening.  'restoreMatrix' uses the 'A.backpermute'-based
-- variant.
restoreMatrix, restoreMatrixReshape, restoreMatrixBackPermute ::
   (A.Slice ix, A.Shape ix, A.Elt a) =>
   Exp Int -> Vector ix a -> Matrix ix a
restoreMatrix cols v = restoreMatrixBackPermute cols v
-- Reshape-based restore: the vector length n becomes (n `div` cols) rows of
-- cols columns.  'A.reshape' requires matching sizes, so n should be a
-- multiple of cols.
restoreMatrixReshape cols v =
   let batch :. n = vectorShape v
   in  A.reshape (A.lift $ batch :. n `div` cols :. cols) v
-- Backpermute-based restore: element (k, j) is read from position
-- k*cols + j of the vector.  With n `div` cols rows, trailing elements that
-- do not fill a whole row are ignored.
restoreMatrixBackPermute cols v =
   case vectorShape v of
      batch :. n ->
         let rows = div n cols
             source (vix :. k :. j) = vix :. k*cols + j
         in  A.backpermute
                (A.lift $ batch :. rows :. cols)
                (withMatrixIndex source)
                v
-- | Broadcast a single (unbatched) vector over the batch shape @shape@.
-- Every vector of the result is a copy of @y@.
extrudeVector ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Exp ix -> Vector Z a -> Vector ix a
extrudeVector shape y =
   A.backpermute
      (A.lift $ shape :. numElems y)
      (withVectorIndex $ \(_batch :. i) -> Z :. i)
      y
-- | Broadcast a single (unbatched) matrix over the batch shape @shape@.
-- Every matrix of the result is a copy of @y@.
extrudeMatrix ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Exp ix -> Matrix Z a -> Matrix ix a
extrudeMatrix shape y =
   A.backpermute targetShape (withMatrixIndex dropBatch) y
   where
      targetShape = A.lift $ shape :. numRows y :. numCols y
      -- every batch position reads the same (row, column) of y
      dropBatch (_batch :. r :. c) = Z :. r :. c
-- | Combine one unbatched vector @x@ element-wise with every vector of the
-- batch @y@, after broadcasting @x@ over the batch shape of @y@.
zipExtrudedVectorWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Vector Z a ->
   Vector ix b ->
   Vector ix c
zipExtrudedVectorWith f x y = A.zipWith f xBroadcast y
   where xBroadcast = extrudeVector (A.indexTail (A.shape y)) x
-- | Combine one unbatched matrix @x@ element-wise with every matrix of the
-- batch @y@, after broadcasting @x@ over the batch shape of @y@.
zipExtrudedMatrixWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Matrix Z a ->
   Matrix ix b ->
   Matrix ix c
zipExtrudedMatrixWith f x y = A.zipWith f xBroadcast y
   where
      -- strip the row and column dimensions to get the batch shape
      batchShape = A.indexTail (A.indexTail (A.shape y))
      xBroadcast = extrudeMatrix batchShape x
-- | Look up elements of a single vector at the given integer positions.
-- Each index is wrapped into a one-dimensional shape index for
-- 'Arrange.gather'.
--
-- NOTE(review): behaviour for out-of-range indices depends on
-- 'Arrange.gather'; confirm before relying on it.
gatherFromVector ::
   (A.Shape ix, A.Elt a) =>
   Scalar ix Int -> Vector Z a -> Scalar ix a
gatherFromVector indices source =
   Arrange.gather (A.map A.index1 indices) source