module Numeric.LAPACK.Matrix.Lazy.UpperTriangular where

import qualified Numeric.LAPACK.Matrix.Array.Mosaic as Mosaic
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Quadratic as Quad
import Numeric.LAPACK.Matrix.Array.Indexed ((#!))
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as StArray
import qualified Data.Array.Comfort.Boxed.Unchecked as Array
import qualified Data.Array.Comfort.Boxed as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Boxed ((!))

import Prelude hiding (sqrt)



-- | Boxed upper triangular matrix whose size shape is wrapped in
-- 'Shape.Deferred', so rows and columns are addressed by
-- 'Shape.DeferredIndex' values and elements by pairs of them.
type Upper sh = Array.Array (Shape.Triangular Shape.Upper (Shape.Deferred sh))
-- | Boxed vector indexed like the rows and columns of an 'Upper' matrix.
type Vector sh = Array.Array (Shape.Deferred sh)


-- | @sample shp f@ builds a boxed array of shape @shp@
-- whose element at index @ix@ is @f ix@.
-- Elements are not evaluated until they are demanded.
sample :: Shape.Indexed sh => sh -> (Shape.Index sh -> b) -> Array.Array sh b
sample shp f = f <$> CheckedArray.indices shp


-- | Convert a storable upper triangular matrix (packed or unpacked)
-- into a boxed 'Upper' matrix.
-- The size shape is wrapped in 'Shape.Deferred'.
fromStorable ::
   (Shape.C sh, Class.Floating a) =>
   Mosaic.FlexUpperP pack diag sh a -> Upper sh a
fromStorable m =
   sample (Shape.Triangular Shape.Upper (Quad.size deferred)) (deferred #!)
   where
      deferred = Quad.mapSize Shape.Deferred m

-- | Convert a boxed 'Upper' matrix back into a storable upper triangular
-- matrix. Elements are read in row-major order of the upper triangle,
-- and the 'Shape.Deferred' wrapper is removed from the size shape.
toStorable ::
   (Shape.C sh, Class.Floating a) => Upper sh a -> Mosaic.Upper sh a
toStorable m =
   Quad.mapSize undefer $ Triangular.fromUpperRowMajor $ StArray.fromBoxed m
   where
      undefer (Shape.Deferred size) = size


-- | Multiply column @j@ of the matrix by @d!j@,
-- i.e. compute @A*diag(d)@.
scaleColumns :: (Shape.C sh, Num a) => Vector sh a -> Upper sh a -> Upper sh a
scaleColumns d a = sample (Array.shape a) scaleAt
   where
      scaleAt ix@(_, col) = a!ix * d!col

-- | Multiply row @i@ of the matrix by @d!i@,
-- i.e. compute @diag(d)*A@.
scaleRows :: (Shape.C sh, Num a) => Vector sh a -> Upper sh a -> Upper sh a
scaleRows d a = sample (Array.shape a) scaleAt
   where
      scaleAt ix@(row, _) = a!ix * d!row


{- |
Product of two upper triangular matrices.
Element @(i,j)@ sums @a!(i,k) * b!(k,j)@ over @i <= k <= j@.

The result takes the shape of the first operand.
The shapes are not compared, so passing matrices
of different sizes is an error that is not detected up front.
-}
multiply :: (Shape.C sh, Num a) => Upper sh a -> Upper sh a -> Upper sh a
multiply a b = sample (Array.shape a) entry
   where
      entry (i@(Shape.DeferredIndex di), j@(Shape.DeferredIndex dj)) =
         sum [a!(i,k) * b!(k,j) | k <- Shape.DeferredIndex <$> [di..dj]]

{- |
@multiplyStrictPart a b@
is almost @multiply (clearDiagonal a) (clearDiagonal b)@,
but it is lazier:
it sums only over @i < k < j@ and thus never accesses
the elements that would be multiplied by zero,
including the diagonals of both operands.
-}
multiplyStrictPart ::
   (Shape.C sh, Num a) => Upper sh a -> Upper sh a -> Upper sh a
multiplyStrictPart a b = sample (Array.shape a) entry
   where
      entry (i@(Shape.DeferredIndex di), j@(Shape.DeferredIndex dj)) =
         let inner = Shape.DeferredIndex <$> [di+1 .. dj-1]
         in sum [a!(i,k) * b!(k,j) | k <- inner]

-- | Extract the diagonal elements @a!(i,i)@ as a vector.
takeDiagonal :: (Shape.C sh, Num a) => Upper sh a -> Vector sh a
takeDiagonal a = sample diagonalShape (\i -> a!(i,i))
   where
      diagonalShape = Shape.triangularSize (Array.shape a)

-- | Put the vector @d@ on the diagonal and keep the strictly upper part
-- of @a@. The diagonal elements of @a@ are never accessed.
replaceDiagonal ::
   (Shape.C sh, Num a) => Vector sh a -> Upper sh a -> Upper sh a
replaceDiagonal d a = sample (Array.shape a) pick
   where
      pick (i,j)
         | i==j = d!i
         | otherwise = a!(i,j)

-- | Upper triangle of @toColumn(d)*1^T + 1*toRow(d)@,
-- i.e. element @(i,j)@ is @d!i + d!j@.
rank2Part :: (Shape.C sh, Num a) => Vector sh a -> Upper sh a
rank2Part d = sample triangle (\(i,j) -> d!i + d!j)
   where
      triangle = Shape.Triangular Shape.Upper (Array.shape d)

-- | Upper triangle of @toColumn(d)*1^T - 1*toRow(d)@,
-- i.e. element @(i,j)@ is @d!i - d!j@ (zero on the diagonal).
rank2DiffPart :: (Shape.C sh, Num a) => Vector sh a -> Upper sh a
rank2DiffPart d = sample triangle (\(i,j) -> d!i - d!j)
   where
      triangle = Shape.Triangular Shape.Upper (Array.shape d)


{- |
Square root of an upper triangular matrix, computed by a lazy implicit solver.

The first argument computes the square root of a scalar.
It is applied to every diagonal element and thus determines
which branch of the square root is chosen.

The strictly upper part of the result is defined in terms of itself.
Element @(i,j)@ only depends on elements @(i,k)@ and @(k,j)@ with @i<k<j@,
so lazy evaluation resolves the recursion.

Element @(i,j)@ with @i<j@ is divided by @sqrtF a_ii + sqrtF a_jj@.
NOTE(review): diagonal roots that sum to zero therefore lead to
a division by zero; this is presumably a precondition of the caller.
-}
{-
Write the result as D+U with D diagonal and U strictly upper triangular:

A = (D+U)*(D+U) = D*D + D*U + U*D + U*U

For i<j this means A_ij = (d_i + d_j)*U_ij + (U*U)_ij, hence

let U = divide (A-D*D - U*U) (toColumn(D)*1^T+1*toRow(D))
-}
sqrt :: (Shape.C sh, Fractional a) => (a -> a) -> Upper sh a -> Upper sh a
sqrt sqrtF = applyUnchecked $ \a ->
   let d = fmap sqrtF $ takeDiagonal a
       -- Self-referential strictly upper part.
       -- NOTE(review): taking the shape from 'a' rather than from the
       -- zipWith result presumably keeps the shape of 'u' independent
       -- of 'u' itself; confirm before restructuring.
       u =
         Array.reshape (Array.shape a) $
         Array.zipWith (/)
            (Array.zipWith (-) a (multiplyStrictPart u u))
            (rank2Part d)
   -- The diagonal of 'u' (a_ii / (d_i + d_i)) is never accessed;
   -- it is replaced by 'd'.
   in replaceDiagonal d u


{- |
Parlett recursion for lifting a scalar function to an upper triangular matrix.

The diagonal of f(A) is obtained by applying the scalar function
to the diagonal of A.
Given A and that diagonal, it solves A*f(A) = f(A)*A
for the strictly upper part of f(A).

Element @(i,j)@ with @i<j@ is divided by @a_ii - a_jj@,
thus the diagonal values must be pairwise distinct.
Even nearly equal diagonal values can produce dramatic errors.
In exchange, the recursion admits a neat lazy implicit implementation.
-}
{-
Example: commutator of a 3x3 upper triangular A
with a strictly upper triangular matrix with entries x, y, z:

/g h i\   /a b c\   /  x y\   /  x y\   /a b c\
|  j k| = |  d e| * |    z| - |    z| * |  d e|
\    l/   \    f/   \     /   \     /   \    f/

h = a*x - x*d = (a-d)*x
k = d*z - z*f = (d-f)*z
i = a*y+b*z - (x*e+y*f) = b*z-x*e + (a-f)*y


General derivation:

0 = A*f(A) - f(A)*A

Split both matrices into diagonal and strictly upper part:

A = E+V
f(A) = D+U

0 = A*(D+U) - (D+U)*A
D*A-A*D = (E+V)*U - U*(E+V)
D*A-A*D + U*V-V*U = E*U-U*E

Since (E*U-U*E)_ij = (e_i-e_j)*U_ij, this can be solved element-wise:

U = divide (D*A-A*D + U*V-V*U) (toColumn(E)*1^T - 1*toRow(E))
-}
parlett :: (Shape.C sh, Fractional a) => (a -> a) -> Upper sh a -> Upper sh a
parlett f = applyUnchecked $ \a ->
   let e = takeDiagonal a
       d = fmap f e
       -- Self-referential strictly upper part U of f(A).
       -- Element (i,j) only uses u!(i,k) and u!(k,j) with i<k<j
       -- (via 'multiplyStrictPart'), so the recursion is well-founded.
       -- NOTE(review): taking the shape from 'a' presumably keeps the
       -- shape of 'u' independent of 'u' itself; confirm.
       u =
         Array.reshape (Array.shape a) $
         Array.zipWith (/)
            (Array.zipWith (-)
               -- D*A + U*V
               (Array.zipWith (+) (scaleRows    d a) (multiplyStrictPart u a))
               -- A*D + V*U
               (Array.zipWith (+) (scaleColumns d a) (multiplyStrictPart a u)))
            -- e_i - e_j
            (rank2DiffPart e)
   -- The diagonal of 'u' would divide by e_i - e_i = 0;
   -- it is never accessed because it is replaced by 'd'.
   in replaceDiagonal d u


{- |
Apply a transformation to a matrix whose size shape is temporarily
wrapped in 'Unchecked', then unwrap the shape of the result again.
NOTE(review): presumably 'Unchecked' relaxes shape comparisons,
which the self-referential definitions in 'sqrt' and 'parlett' rely on.
-}
applyUnchecked ::
   (Shape.C sh) =>
   (Upper (Unchecked sh) a -> Upper (Unchecked sh) a) ->
   Upper sh a -> Upper sh a
applyUnchecked f = Array.mapShape recheck . f . Array.mapShape uncheck
   where
      uncheck (Shape.Triangular uplo (Shape.Deferred size)) =
         Shape.Triangular uplo (Shape.Deferred (Unchecked size))
      recheck (Shape.Triangular uplo (Shape.Deferred (Unchecked size))) =
         Shape.Triangular uplo (Shape.Deferred size)