{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.LinearAlgebra.Private where

import qualified Data.Array.Accelerate.Utility.Loop as Loop
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
          (Acc, Array, Exp, Any(Any), All(All), Z(Z), (:.)((:.)))



{- |
A batch of scalars: an array of shape @ix@, one element per batch index.
-}
type Scalar ix a = Acc (Array ix a)
{- |
A batch of vectors: shape @ix :. n@,
where the innermost dimension indexes the vector elements.
-}
type Vector ix a = Acc (Array (ix :. Int) a)
{- |
A batch of matrices: shape @ix :. rows :. cols@,
i.e. the column index is the innermost dimension.
-}
type Matrix ix a = Acc (Array (ix :. Int :. Int) a)

{- |
Transpose every matrix of the batch,
i.e. exchange the two innermost dimensions.
The result has shape @ix :. cols :. rows@
and element @(ix, r, c)@ of the result is element @(ix, c, r)@ of the input.
-}
transpose ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Matrix ix a
transpose m =
   A.backpermute
      (A.lift (swapIndex (matrixShape m)))
      (withMatrixIndex swapIndex)
      m

{- |
Exchange the two innermost components of an unlifted matrix index or shape,
leaving the batch part untouched.
-}
swapIndex ::
   Exp ix :. Exp Int :. Exp Int ->
   Exp ix :. Exp Int :. Exp Int
swapIndex (batch :. row :. col) = batch :. col :. row


{- |
Length of the vectors in the batch,
that is, the extent of the innermost dimension.
-}
numElems :: (A.Shape ix, A.Slice ix, A.Elt a) => Vector ix a -> Exp Int
numElems = A.indexHead . A.shape

{- |
Number of rows of the matrices in the batch,
that is, the extent of the second-innermost dimension.
-}
numRows :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numRows = A.indexHead . A.indexTail . A.shape

{- |
Number of columns of the matrices in the batch,
that is, the extent of the innermost dimension.
-}
numCols :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numCols = A.indexHead . A.shape

{- |
Shape of a batch of vectors, unlifted into
the batch shape and the vector length.
-}
vectorShape ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Vector ix a -> Exp ix :. Exp Int
vectorShape = A.unlift . A.shape

{- |
Shape of a batch of matrices, unlifted into
the batch shape, the number of rows and the number of columns.
-}
matrixShape ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Exp ix :. Exp Int :. Exp Int
matrixShape = A.unlift . A.shape

{- |
Turn a function on an unlifted vector index (batch index and element position)
into a function on the lifted index expression.
The argument is unlifted before and the result is lifted after calling @f@.
-}
withVectorIndex ::
   (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
   (Exp ix :. Exp Int -> a) ->
   (Exp (ix :. Int) -> Exp (A.Plain a))
withVectorIndex f ix = A.lift (f (A.unlift ix))

{- |
Turn a function on an unlifted matrix index (batch index, row and column)
into a function on the lifted index expression.
The argument is unlifted before and the result is lifted after calling @f@.
-}
withMatrixIndex ::
   (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
   (Exp ix :. Exp Int :. Exp Int -> a) ->
   (Exp (ix :. Int :. Int) -> Exp (A.Plain a))
withMatrixIndex f ix = A.lift (f (A.unlift ix))


{- |
Batched outer product.
The result has as many rows as @x@ has elements
and as many columns as @y@ has elements;
element @(r, c)@ is @x[r] * y[c]@.
-}
outer ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Vector ix a -> Vector ix a -> Matrix ix a
outer x y =
   A.zipWith (*) spreadX spreadY
   where
      -- x is repeated along a new innermost (column) dimension
      spreadX = A.replicate (A.lift (Any :. All :. numElems y)) x
      -- y is repeated along a new second-innermost (row) dimension
      spreadY = A.replicate (A.lift (Any :. numElems x :. All)) y

{- |
Batched matrix-vector product.

The vector @v@ is replicated once per row of @m@,
multiplied element-wise with @m@
and the products are summed along the innermost (column) dimension.
The result has one element per row of @m@.

The sum starts from @0@,
so a matrix with zero columns yields an all-zero vector.
'A.fold1' would be undefined on such an empty innermost dimension.
-}
multiplyMatrixVector ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyMatrixVector m v =
   case matrixShape m of
      (_ix :. rows :. _cols) ->
         A.fold (+) 0 $
         A.zipWith (*) m
            (A.replicate (A.lift $ Any :. rows :. All) v)

{- |
Batched matrix-matrix product @x * y@.

Both operands are broadcast to the common shape
@ix :. rows :. inner :. cols@ and multiplied element-wise.
'transpose' then moves the @inner@ dimension to the innermost position,
where it is summed away, leaving shape @ix :. rows :. cols@.

The sum starts from @0@,
so an empty inner dimension yields an all-zero matrix.
'A.fold1' would be undefined on such an empty innermost dimension.
-}
multiplyMatrixMatrix ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
multiplyMatrixMatrix x y =
   case (matrixShape x, matrixShape y) of
      (_ :. rows :. _cols, _ :. _rows :. cols) ->
         A.fold (+) 0 $ transpose $
         A.zipWith (*)
            (A.replicate (A.lift $ Any :. All :. All :. cols) x)
            (A.replicate (A.lift $ Any :. rows :. All :. All) y)

{- |
One step of the Newton iteration for the matrix inverse of @a@:
maps the current approximation @x@ to @2*x - x*a*x@.
-}
newtonInverseStep ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
newtonInverseStep a x =
   A.zipWith (-) doubled correction
   where
      doubled = A.map (2*) x
      correction = multiplyMatrixMatrix x (multiplyMatrixMatrix a x)

{- |
Batch of matrices of the given shape
with ones where row and column index coincide and zeros elsewhere.
-}
identity ::
   (A.Shape ix, A.Slice ix, A.Elt a, A.FromIntegral Int a) =>
   Exp (ix :. Int :. Int) -> Matrix ix a
identity sh =
   A.generate sh $ withMatrixIndex $ \(_batch :. row :. col) ->
      A.fromIntegral (A.boolToInt (row A.== col))

{- |
Approximate the inverse of @a@ by Newton iteration.
'newtonInverseStep' is passed to 'Loop.nest'
together with the count @n@ and the initial approximation @seed@
(presumably the step is applied @n@ times - see 'Loop.nest').

Note the argument order: the seed comes before the matrix to be inverted.
-}
newtonInverse ::
   (A.Shape ix, A.Slice ix, A.Num a) =>
   Exp Int ->
   Matrix ix a ->
   Matrix ix a ->
   Matrix ix a
newtonInverse n seed a =
   Loop.nest n step seed
   where
      step = newtonInverseStep a



{- |
Multiply every row of the matrix by the corresponding element of the vector:
element @(r, c)@ of the result is @s[r] * x[r, c]@.
-}
scaleRows ::
   (A.Slice ix, A.Shape ix, A.Num a) =>
   Vector ix a -> Matrix ix a -> Matrix ix a
scaleRows = zipScalarVectorWith (*)



{- |
Combine every element of each vector in @ys@
with the scalar at the same batch index in @x@.
The scalars are replicated along a new innermost dimension
whose extent is the vector length.
-}
zipScalarVectorWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Vector ix b -> Vector ix c
zipScalarVectorWith f x ys =
   A.zipWith f spread ys
   where
      spread = A.replicate (A.lift (Any :. numElems ys)) x

{- |
Combine every element of each matrix in @ys@
with the scalar at the same batch index in @x@.
The scalars are replicated along two new innermost dimensions
matching the rows and columns of @ys@.
-}
zipScalarMatrixWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Matrix ix b -> Matrix ix c
zipScalarMatrixWith f x ys =
   A.zipWith f spread ys
   where
      spread = A.replicate (A.lift (Any :. numRows ys :. numCols ys)) x



{- |
Reinterpret every vector as a matrix with a single column
by appending an innermost dimension of extent 1 to the shape.
-}
columnFromVector ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Vector ix a -> Matrix ix a
columnFromVector a =
   A.reshape (A.shape a `Exp.indexCons` 1) a

{- |
Reinterpret a single-column matrix as a vector
by dropping the innermost dimension from the shape.

The input must be a matrix with exactly one column.
-}
vectorFromColumn ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix a -> Vector ix a
vectorFromColumn a =
   A.reshape (A.indexTail (A.shape a)) a



{- |
Concatenate the rows of every matrix into one vector of length @rows*cols@.
'flattenMatrix' is the default and currently uses the
'flattenMatrixBackPermute' implementation;
'flattenMatrixReshape' is the alternative based on 'A.reshape'.
-}
flattenMatrix, flattenMatrixReshape, flattenMatrixBackPermute ::
   (A.Slice ix, A.Shape ix, A.Elt a) =>
   Matrix ix a -> Vector ix a
flattenMatrix m = flattenMatrixBackPermute m

-- Flatten by reshaping to @ix :. rows*cols@.
flattenMatrixReshape m =
   A.reshape flatShape m
   where
      batch :. rows :. cols = matrixShape m
      flatShape = A.lift (batch :. rows * cols)

{- |
Quotient and remainder of a flooring division as a pair.
It is composed of separate 'div' and 'mod' calls
rather than a single 'divMod' call
(presumably so that it can be used on embedded expressions - to be confirmed).
-}
accDivMod :: Integral a => a -> a -> (a, a)
accDivMod x y = (x `div` y, x `mod` y)

-- Flatten by index transformation:
-- position n of the result reads matrix element (n div cols, n mod cols).
flattenMatrixBackPermute m =
   A.backpermute
      (A.lift (batch :. rows * cols))
      (withVectorIndex $ \(vix :. n) ->
         let (r, c) = accDivMod n cols
         in  vix :. r :. c)
      m
   where
      batch :. rows :. cols = matrixShape m


{- |
Split every vector into rows of the given number of columns,
yielding matrices of shape @div n cols :. cols@,
where @n@ is the vector length.
'restoreMatrix' is the default and currently uses the
'restoreMatrixBackPermute' implementation;
'restoreMatrixReshape' is the alternative based on 'A.reshape'.
-}
restoreMatrix, restoreMatrixReshape, restoreMatrixBackPermute ::
   (A.Slice ix, A.Shape ix, A.Elt a) =>
   Exp Int -> Vector ix a -> Matrix ix a
restoreMatrix cols v = restoreMatrixBackPermute cols v

-- Restore by reshaping to @ix :. div n cols :. cols@.
-- NOTE(review): presumably requires n to be a multiple of cols,
-- since a reshape cannot change the number of elements - confirm.
restoreMatrixReshape cols v =
   A.reshape (A.lift (batch :. n `div` cols :. cols)) v
   where
      batch :. n = vectorShape v

-- Restore by index transformation:
-- matrix element (k, j) reads vector position k*cols+j.
-- The row count is @div n cols@, so when n is not a multiple of cols
-- the trailing vector elements are never read.
restoreMatrixBackPermute cols v =
   A.backpermute
      (A.lift (batch :. n `div` cols :. cols))
      (withMatrixIndex $ \(vix :. row :. col) -> vix :. row * cols + col)
      v
   where
      batch :. n = vectorShape v



{- |
Repeat a single vector over all indices of the batch shape @shape@.
Every vector of the result is a copy of @y@:
the batch part of an index is ignored
and only the element position is used to look up @y@.
-}
extrudeVector ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Exp ix -> Vector Z a -> Vector ix a
extrudeVector shape y =
   -- alternative formulation, not in use:
   -- A.replicate (A.lift $ shape :. All) y
   A.backpermute
      (A.lift (shape :. numElems y))
      (\ix -> A.index1 (A.indexHead ix))
      y

{- |
Repeat a single matrix over all indices of the batch shape @shape@.
Every matrix of the result is a copy of @y@:
the batch part of an index is ignored
and only row and column are used to look up @y@.
-}
extrudeMatrix ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Exp ix -> Matrix Z a -> Matrix ix a
extrudeMatrix shape y =
   A.backpermute
      (A.lift (shape :. numRows y :. numCols y))
      (withMatrixIndex (\(_batch :. row :. col) -> Z :. row :. col))
      y

{- |
Combine a single vector @x@ element-wise with every vector of the batch @y@.
@x@ is extruded to the batch shape of @y@ by 'extrudeVector' before zipping.
-}
zipExtrudedVectorWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Vector Z a ->
   Vector ix b ->
   Vector ix c
zipExtrudedVectorWith f x y =
   A.zipWith f extruded y
   where
      -- batch shape of y: its shape without the innermost dimension
      extruded = extrudeVector (A.indexTail (A.shape y)) x

{- |
Combine a single matrix @x@ element-wise with every matrix of the batch @y@.
@x@ is extruded to the batch shape of @y@ by 'extrudeMatrix' before zipping.
-}
zipExtrudedMatrixWith ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Matrix Z a ->
   Matrix ix b ->
   Matrix ix c
zipExtrudedMatrixWith f x y =
   A.zipWith f extruded y
   where
      -- batch shape of y: its shape without the two innermost dimensions
      extruded = extrudeMatrix (A.indexTail (A.indexTail (A.shape y))) x

{- |
Look up elements of a single vector by position.
Every @Int@ in @indices@ is wrapped into a one-dimensional index
and the result of that mapping is handed to 'Arrange.gather'
together with the source vector.
-}
gatherFromVector ::
   (A.Shape ix, A.Elt a) =>
   Scalar ix Int -> Vector Z a -> Scalar ix a
gatherFromVector indices source =
   Arrange.gather (A.map A.index1 indices) source