{-# LANGUAGE TypeFamilies #-}
{- |
Linear algebra on knead arrays.

Every vector and matrix carries a leading @coll@ dimension, so each
operation below works on a whole collection of vectors or matrices at
once.  The shape-dependent operations merge the collection shapes (and
any dimension shared by both operands) with 'Shape.intersect'.
-}
module Data.Array.Knead.Arithmetic.LinearAlgebra where

import qualified Data.Array.Knead.Parameterized.Physical as Phys
import qualified Data.Array.Knead.Parameterized.Symbolic as SymP
import qualified Data.Array.Knead.Simple.Symbolic as Sym
import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom)
import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable)
import Control.Arrow (arr)
import Control.Monad.HT (chain)
import qualified Data.List as List


-- | One value per collection element.
type Scalar p coll a = SymP.Array p coll a
-- | One vector of dimension @dim@ per collection element.
type Vector p coll dim a = SymP.Array p (coll, dim) a
-- | One @rows@ by @cols@ matrix per collection element.
type Matrix p coll rows cols a = SymP.Array p (coll, (rows, cols)) a

-- | Physical counterpart of 'Scalar'.
type PhysScalar coll a = Phys.Array coll a
-- | Physical counterpart of 'Vector'.
type PhysVector coll dim a = Phys.Array (coll, dim) a
-- | Physical counterpart of 'Matrix'.
type PhysMatrix coll rows cols a = Phys.Array (coll, (rows, cols)) a

-- | An IO action yielding a function from the parameter @p@ to an IO
-- action that produces a 'PhysScalar'.
type IOScalar p coll a = IO (p -> IO (PhysScalar coll a))
-- | An IO action yielding a function from the parameter @p@ to an IO
-- action that produces a 'PhysVector'.
type IOVector p coll dim a = IO (p -> IO (PhysVector coll dim a))
-- | An IO action yielding a function from the parameter @p@ to an IO
-- action that produces a 'PhysMatrix'.
type IOMatrix p coll rows cols a = IO (p -> IO (PhysMatrix coll rows cols a))


-- | Inner product of corresponding vectors: the elementwise products are
-- reduced over the @dim@ dimension with 'Sym.fold1' and 'Expr.add'.
dotProduct ::
   (Shape.C coll, Shape.C dim, MultiValue.PseudoRing a) =>
   Vector p coll dim a -> Vector p coll dim a -> Scalar p coll a
dotProduct u v =
   Sym.fold1 Expr.add (Sym.zipWith Expr.mul u v)

-- | Outer product: entry @(i, j)@ of each resulting matrix is @u_i * v_j@.
outer ::
   (Shape.C coll, Shape.C rows, Shape.C cols, MultiValue.PseudoRing a) =>
   Vector p coll rows a -> Vector p coll cols a -> Matrix p coll rows cols a
outer u v =
   ShapeDep.backpermute2
      (Expr.modify2 (atom, atom) (atom, atom) $
       \(collU, rowsU) (collV, colsV) ->
          (Shape.intersect collU collV, (rowsU, colsV)))
      -- result index (c, (i, j)) reads u at (c, i) ...
      (Expr.mapSnd Expr.fst)
      -- ... and v at (c, j)
      (Expr.mapSnd Expr.snd)
      Expr.mul
      u v

-- | Matrix-vector product.  The products @m_(r,k) * v_k@ are laid out with
-- shape @((coll, rows), cols)@ and then summed over the innermost @cols@
-- dimension.
multiplyMatrixVector ::
   (Shape.C coll, Shape.C rows, Shape.C cols, MultiValue.PseudoRing a) =>
   Matrix p coll rows cols a -> Vector p coll cols a -> Vector p coll rows a
multiplyMatrixVector mat vec =
   Sym.fold1 Expr.add $
   ShapeDep.backpermute2
      (Expr.modify2 (atom, (atom, atom)) (atom, atom) $
       \(collM, (rowsM, colsM)) (collV, colsV) ->
          ((Shape.intersect collM collV, rowsM),
           Shape.intersect colsM colsV))
      -- ((c, r), k) -> (c, (r, k)) selects the matrix entry
      balanceRight
      -- ((c, r), k) -> (c, k) selects the vector entry
      (Expr.mapFst Expr.fst)
      Expr.mul
      mat vec

-- | Matrix-matrix product.  The products @x_(r,g) * y_(g,k)@ are laid out
-- with shape @((coll, (rows, cols)), glue)@ and then summed over the
-- innermost @glue@ dimension.
multiplyMatrixMatrix ::
   (Shape.C coll, Shape.C rows, Shape.C glue, Shape.C cols,
    MultiValue.PseudoRing a) =>
   Matrix p coll rows glue a -> Matrix p coll glue cols a ->
   Matrix p coll rows cols a
multiplyMatrixMatrix x y =
   Sym.fold1 Expr.add $
   ShapeDep.backpermute2
      (Expr.modify2 (atom, (atom, atom)) (atom, (atom, atom)) $
       \(collX, (rowsX, glueX)) (collY, (glueY, colsY)) ->
          ((Shape.intersect collX collY, (rowsX, colsY)),
           Shape.intersect glueX glueY))
      -- ((c, (r, k)), g) -> (c, (r, g)) selects the left factor
      (Expr.modify ((atom, (atom, atom)), atom) $
       \((c, (r, _)), g) -> (c, (r, g)))
      -- ((c, (r, k)), g) -> (c, (g, k)) selects the right factor
      (Expr.modify ((atom, (atom, atom)), atom) $
       \((c, (_, k)), g) -> (c, (g, k)))
      Expr.mul
      x y

{-
Apparently an alternative formulation of multiplyMatrixMatrix in terms of
multiplyMatrixVector, kept for reference only (not compiled):

transpose $
ShapeDep.backpermute balanceRight balanceLeft $
multiplyMatrixVector a $
ShapeDep.backpermute balanceLeft balanceRight $
transpose b
-}

{-
Computing x*a*x efficiently requires caching one of the partial
products, (a*x) or (x*a).  Open question: is there an efficient joint
multiplication that avoids the cache?

   xa_i_k  = sum_j x_i_j * a_j_k
   xax_i_l = sum_k (sum_j x_i_j * a_j_k) * x_k_l
           = sum_j sum_k x_i_j * a_j_k * x_k_l
-}

-- | One Newton step @x' = 2x - x*a*x@ towards an inverse of @a@, built
-- purely symbolically: the partial product @a*x@ is not cached (compare
-- 'matrixInverseNewtonStep').
matrixInverseNewtonStepNaive ::
   (Shape.C coll, Shape.C rows, Shape.C cols, MultiValue.PseudoRing a) =>
   Matrix p coll rows cols a ->
   Matrix p coll cols rows a -> Matrix p coll cols rows a
matrixInverseNewtonStepNaive a x =
   Sym.zipWith Expr.sub twiceX xax
   where
      twiceX = Sym.map double x
      xax = multiplyMatrixMatrix x (multiplyMatrixMatrix a x)

-- | One Newton step @x' = 2x - x*a*x@ that caches the partial product.
--
-- @a*x@ is rendered into a physical array first.  The update is then
-- rendered with its parameter extended to @(p, a*x)@, and the cached
-- product is fed back in through 'Phys.feed'.  For a given parameter,
-- the returned function runs both renderings in that order.
matrixInverseNewtonStep ::
   (Shape.C coll, Shape.C rows, Shape.C cols,
    MultiValue.PseudoRing a, MultiValueMemory.C a, Storable a,
    MultiValueMemory.C rows, Storable rows,
    MultiValueMemory.C cols, Storable cols,
    MultiValueMemory.C coll, Storable coll,
    MultiValueMemory.Struct rows ~ rowsstruct, LLVM.IsSized rowsstruct,
    MultiValueMemory.Struct cols ~ colsstruct, LLVM.IsSized colsstruct,
    MultiValueMemory.Struct coll ~ collstruct, LLVM.IsSized collstruct) =>
   Matrix p coll rows cols a ->
   Matrix p coll cols rows a -> IOMatrix p coll cols rows a
matrixInverseNewtonStep a x = do
   renderAX <- Phys.render $ multiplyMatrixMatrix a x
   -- x lifted to the extended parameter (p, a*x), ignoring the cached part
   let xExt = SymP.extendParameter fst x
   renderUpdate <-
      Phys.render $
         Sym.zipWith Expr.sub (Sym.map double xExt) $
         multiplyMatrixMatrix xExt (Phys.feed (arr snd))
   return $ \p -> renderAX p >>= curry renderUpdate p

-- | @nest i f@ chains @i@ copies of the monadic step @f@ with 'chain'.
-- A count @<= 0@ yields no steps, because 'List.genericReplicate' then
-- returns the empty list.
nest :: (Integral i, Monad m) => i -> (a -> m a) -> a -> m a
nest count step = chain (List.genericReplicate count step)

-- | Run @n@ cached Newton steps (see 'matrixInverseNewtonStep'), starting
-- from the initial guess @x@.
--
-- The step is built once.  Its @a@ is lifted to the extended parameter,
-- and its current iterate is read from the parameter via 'Phys.feed'.
-- The returned function renders the start value and then applies the
-- step @n@ times with 'nest'.
matrixInverseNewton ::
   (Shape.C coll, Shape.C rows, Shape.C cols,
    MultiValue.PseudoRing a, MultiValueMemory.C a, Storable a,
    MultiValueMemory.C rows, Storable rows,
    MultiValueMemory.C cols, Storable cols,
    MultiValueMemory.C coll, Storable coll,
    MultiValueMemory.Struct rows ~ rowsstruct, LLVM.IsSized rowsstruct,
    MultiValueMemory.Struct cols ~ colsstruct, LLVM.IsSized colsstruct,
    MultiValueMemory.Struct coll ~ collstruct, LLVM.IsSized collstruct) =>
   Int ->
   Matrix p coll rows cols a ->
   Matrix p coll cols rows a -> IOMatrix p coll cols rows a
matrixInverseNewton n a x = do
   renderStart <- Phys.render x
   step <-
      matrixInverseNewtonStep
         (SymP.extendParameter fst a)
         (Phys.feed (arr snd))
   return $ \p -> renderStart p >>= nest n (curry step p)

-- | @double e@ computes @e + e@.  The operand is bound once through
-- 'Expr.liftM', and the resulting value is added to itself.
double :: (MultiValue.Additive a) => Exp a -> Exp a
double = Expr.liftM (\v -> MultiValue.add v v)

-- | Swap the two matrix dimensions of every matrix in the collection.  The
-- same swap is applied to the shape and to each index; @coll@ is untouched.
transpose ::
   (Shape.C coll, Shape.C rows, Shape.C cols) =>
   Matrix p coll rows cols a -> Matrix p coll cols rows a
transpose mat =
   ShapeDep.backpermute (Expr.mapSnd Expr.swap) (Expr.mapSnd Expr.swap) mat

-- | Multiply row @r@ of each matrix by the vector entry @v_r@.
scaleRows ::
   (Shape.C coll, Shape.C rows, Shape.C cols, MultiValue.PseudoRing a) =>
   Vector p coll rows a -> Matrix p coll rows cols a ->
   Matrix p coll rows cols a
scaleRows vec mat =
   ShapeDep.backpermute2
      (Expr.modify2 (atom, atom) (atom, (atom, atom)) $
       \(collV, rowsV) (collM, (rowsM, colsM)) ->
          (Shape.intersect collV collM,
           (Shape.intersect rowsV rowsM, colsM)))
      -- index (c, (r, k)) reads the vector at (c, r) ...
      (Expr.mapSnd Expr.fst)
      -- ... and the matrix at the same index
      id
      Expr.mul
      vec mat

-- | Reassociate a nested pair to the left: @(a, (b, c))@ becomes
-- @((a, b), c)@.
balanceLeft :: (Expr.Value val) => val (a,(b,c)) -> val ((a,b),c)
balanceLeft =
   Expr.lift1 $
   MultiValue.modify (atom, (atom, atom)) $
   \(x, (y, z)) -> ((x, y), z)

-- | Reassociate a nested pair to the right: @((a, b), c)@ becomes
-- @(a, (b, c))@.
balanceRight :: (Expr.Value val) => val ((a,b),c) -> val (a,(b,c))
balanceRight =
   Expr.lift1 $
   MultiValue.modify ((atom, atom), atom) $
   \((x, y), z) -> (x, (y, z))