module Numeric.LAPACK.Format where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import Foreign.Storable (Storable)

import Text.Printf (PrintfArg, printf)

import qualified Data.List.HT as ListHT
import Data.Complex (Complex((:+)))


infix 0 ##

-- | Print a value to standard output, one 'format'ted line per output line.
--
-- The right operand is the printf-style format string passed to 'format'.
(##) :: (Format a) => a -> String -> IO ()
(##) value fmt = mapM_ putStrLn (format fmt value)


-- | Values that can be rendered as a list of text lines.
--
-- The 'String' argument is a 'printf' format used for numeric entries;
-- an instance may ignore it (the 'Int' instance does).
class Format a where
   format ::
      String ->  -- ^ printf-style format for numbers
      a ->       -- ^ value to render
      [String]   -- ^ rendered output lines

-- | Integers are rendered with 'show'; the format string is not used.
instance Format Int where
   format _ = pure . show

-- | Single-precision numbers are rendered on one line with 'printf'.
instance Format Float where
   format fmt = pure . printf fmt

-- | Double-precision numbers are rendered on one line with 'printf'.
instance Format Double where
   format fmt = pure . printf fmt

-- | Complex numbers are rendered on one line via 'printfComplex'.
instance (PrintfArg a) => Format (Complex a) where
   format fmt = pure . printfComplex fmt

-- | Pairs render both components, separated by one empty line.
instance (Format a, Format b) => Format (a,b) where
   format fmt (x,y) = concat [format fmt x, [""], format fmt y]

-- | Triples render all three components, separated by empty lines.
instance (Format a, Format b, Format c) => Format (a,b,c) where
   format fmt (x,y,z) =
      concat [format fmt x, [""], format fmt y, [""], format fmt z]

-- | Arrays delegate to the shape-specific 'formatArray'.
instance
   (FormatArray sh, Class.Floating a, Storable a) =>
      Format (Array sh a) where
   format fmt arr = formatArray fmt arr


-- | Array shapes that know how to lay out their elements as text lines.
class (Shape.C sh) => FormatArray sh where
   -- | Render all elements of an array using the given printf format,
   -- returning one string per output line.
   formatArray ::
      (Storable a, Class.Floating a) =>
      String -> Array sh a -> [String]

-- | Zero-based vectors print on a single line, elements separated by spaces.
instance (Integral i) => FormatArray (Shape.ZeroBased i) where
   formatArray fmt =
      pure . unwords . map (printfFloating fmt) . Array.toList

-- | One-based vectors print on a single line, elements separated by spaces.
instance (Integral i) => FormatArray (Shape.OneBased i) where
   formatArray fmt =
      pure . unwords . map (printfFloating fmt) . Array.toList

-- | General matrices are rendered by 'formatGeneral'.
instance
   (Shape.C height, Shape.C width) =>
      FormatArray (MatrixShape.General height width) where
   formatArray fmt arr = formatGeneral fmt arr

-- | Render a general matrix as one line per row.
--
-- Every cell is padded on the left to the width of the widest cell in its
-- column, and the cells of a row are joined with single spaces.
formatGeneral ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   String -> Array (MatrixShape.General height width) a -> [String]
formatGeneral fmt m = map renderRow cells
   where
      MatrixShape.General order height width = Array.shape m
      cells = formatRows fmt order (height,width) (Array.toList m)
      widths = columnWidths cells
      renderRow = unwords . zipWith (ListHT.padLeft ' ') widths

-- | Householder-shaped arrays are rendered by 'formatHouseholder'.
instance
   (Shape.C height, Shape.C width) =>
      FormatArray (MatrixShape.Householder height width) where
   formatArray fmt arr = formatHouseholder fmt arr

-- | Render a Householder-shaped matrix as one line per row.
--
-- Cells are left-padded to their column's maximal width like in
-- 'formatGeneral', but instead of being joined with spaces, every cell is
-- prefixed with one marker character: @'|'@ for diagonal cells (row index
-- equals column index) and @' '@ for all others.
-- NOTE(review): the marker presumably separates the stored R part from the
-- reflector part of the representation — confirm against the shape docs.
formatHouseholder ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   String -> Array (MatrixShape.Householder height width) a -> [String]
formatHouseholder fmt m = zipWith renderRow [0..] cells
   where
      MatrixShape.Householder order height width = Array.shape m
      cells = formatRows fmt order (height,width) (Array.toList m)
      widths = columnWidths cells
      renderRow :: Int -> [String] -> String
      renderRow r xs =
         concat
            [ (if r == c then '|' else ' ') : cell
            | (c, cell) <- zip [0..] (zipWith (ListHT.padLeft ' ') widths xs) ]

-- | Render every element with 'printfFloating' and group the rendered cells
-- into the rows of a @height × width@ matrix.
--
-- The flat element list (the whole matrix, as passed by the callers) is
-- interpreted according to the storage 'Order':
--
-- * 'RowMajor': consecutive runs of @width@ elements form a row.
--
-- * 'ColumnMajor': row @r@ consists of every @height@-th element, starting
--   at index @r@.
--
-- Both orders produce exactly @height@ rows, so a matrix without columns
-- renders as @height@ empty rows independently of its storage order.
formatRows ::
   (Class.Floating a, Shape.C height, Shape.C width) =>
   String -> Order -> (height, width) -> [a] -> [[String]]
formatRows fmt order (height,width) =
   let numRows = Shape.size height
       numCols = Shape.size width
       toRows =
          case order of
             -- 'ListHT.sliceVertical' stops at the first empty chunk and
             -- would therefore return no rows at all when numCols == 0,
             -- whereas the ColumnMajor branch returns numRows empty rows.
             -- Taking exactly numRows chunks keeps both orders consistent.
             RowMajor ->
                take numRows . map (take numCols) . iterate (drop numCols)
             ColumnMajor -> ListHT.sliceHorizontal numRows
   in  toRows . map (printfFloating fmt)

-- | Width of the widest cell in each column of a table of rendered cells.
--
-- The result has one entry per column of the longest row. Unlike
-- @zipWith max@, shorter rows do not truncate the result. Callers that zip
-- the widths with a row therefore never lose cells of a longer row.
-- An empty table yields no widths.
columnWidths :: [[[a]]] -> [Int]
columnWidths = foldr (maxLongest . map length) []
   where
      -- Element-wise maximum that keeps the tail of the longer list.
      maxLongest (x:xs) (y:ys) = max x y : maxLongest xs ys
      maxLongest xs [] = xs
      maxLongest [] ys = ys


-- | A formatter for values of type @a@: takes a printf format string and a
-- value and produces the rendered text. Wrapping it in a newtype lets
-- 'printfFloating' pick one formatter per element type.
newtype Printf a =
   Printf {
      runPrintf :: String -> a -> String
   }

-- | Format one element of any 'Class.Floating' type.
--
-- Real element types (Float, Double) go through 'printf' directly.
-- Complex element types go through 'printfComplex'.
printfFloating :: (Class.Floating a) => String -> a -> String
printfFloating = runPrintf (Class.switchFloating real real complex complex)
   where
      real :: (PrintfArg b) => Printf b
      real = Printf printf
      complex :: (PrintfArg b) => Printf (Complex b)
      complex = Printf printfComplex

-- | Format a complex number as @\<re\>+i\<im\>@, applying the same printf
-- format to the real and the imaginary part.
--
-- The sign of the imaginary part is not folded into the separator. A
-- negative imaginary part therefore appears as e.g. @1.0+i-2.0@.
printfComplex :: (PrintfArg a) => String -> Complex a -> String
printfComplex fmt (re :+ im) = printf (concat [fmt, "+i", fmt]) re im