-- | Helpers for two-dimensional arrays.
module Rainbox.Array2d
  (
  -- * Tables
    Table
  , lCols
  , lRows
  , cells
  , table
  , labelCols
  , labelRows
  , mapTable
  , mapColLabels
  , mapRowLabels

  -- * Two-dimensional arrays
  , cols
  , rows
  , arrayByRows
  , arrayByCols
  ) where

import Data.Array

-- * Tables

-- | A Table is a two-dimensional array with two associated
-- one-dimensional arrays: an array of labels for each column, and
-- an array of labels for each row.
--
-- Cells are indexed by @(column, row)@ pairs.  A 'Table' built with
-- 'table' has column labels covering exactly the column bounds of
-- 'cells' and row labels covering exactly its row bounds.  'mapTable'
-- looks labels up by each cell's column and row index, so it relies
-- on that correspondence.
data Table lCol lRow col row a = Table
  { lCols :: Array col lCol
  -- ^ One label for each column, indexed by column
  , lRows :: Array row lRow
  -- ^ One label for each row, indexed by row
  , cells :: Array (col, row) a
  -- ^ Two-dimensional array of cells, indexed by @(column, row)@
  } deriving (Eq, Show)

-- | Maps over the cells only; the column and row labels are left
-- untouched.  No 'Ix' context is required, because the underlying
-- 'Array' is a 'Functor' for every index type.
instance Functor (Table lCol lRow col row) where
  fmap f t = t { cells = fmap f (cells t) }

-- | Make a new 'Table' from a two-dimensional array of cells,
-- computing the column and row labels with the supplied functions.
table
  :: (Ix col, Ix row)
  => (col -> [(row, a)] -> lCol)
  -- ^ Column-label generator.  Receives the column index together
  -- with every @(row index, cell)@ pair of that column, in ascending
  -- row order.
  -> (row -> [(col, a)] -> lRow)
  -- ^ Row-label generator.  Receives the row index together with
  -- every @(column index, cell)@ pair of that row, in ascending
  -- column order.
  -> Array (col, row) a
  -- ^ Cells of the table
  -> Table lCol lRow col row a
table mkColLabel mkRowLabel grid = Table
  { lCols = labelCols mkColLabel grid
  , lRows = labelRows mkRowLabel grid
  , cells = grid
  }

-- | Build one label per column of a two-dimensional array.  The
-- resulting array has exactly the column bounds of the input.
labelCols
  :: (Ix col, Ix row)
  => (col -> [(row, a)] -> lCol)
  -- ^ Function to generate the column labels.  It is applied to the
  -- column index and the full contents of the column, given as
  -- @(row index, cell)@ pairs in ascending row order.
  -> Array (col, row) a
  -> Array col lCol
labelCols f a = listArray (c0, c1) (map labelFor (range (c0, c1)))
  where
    ((c0, r0), (c1, r1)) = bounds a
    -- Hand the generator this column's index and its cells.
    labelFor c = f c [ (r, a ! (c, r)) | r <- range (r0, r1) ]

-- | Build one label per row of a two-dimensional array.  The
-- resulting array has exactly the row bounds of the input.
labelRows
  :: (Ix col, Ix row)
  => (row -> [(col, a)] -> lRow)
  -- ^ Function to generate the row labels.  It is applied to the
  -- row index and the full contents of the row, given as
  -- @(column index, cell)@ pairs in ascending column order.
  -> Array (col, row) a
  -> Array row lRow
labelRows f a = listArray (r0, r1) (map labelFor (range (r0, r1)))
  where
    ((c0, r0), (c1, r1)) = bounds a
    -- Hand the generator this row's index and its cells.
    labelFor r = f r [ (c, a ! (c, r)) | c <- range (c0, c1) ]

-- | Transform the cells of the table.  Like 'fmap', except that the
-- mapping function also sees the labels and indices of each cell.
-- Labels and cell bounds are carried over unchanged.
mapTable
  :: (Ix col, Ix row)
  => (lCol -> lRow -> col -> row -> a -> b)
  -- ^ Receives the column label, the row label, the column index,
  -- the row index, and the current cell contents; returns the new
  -- cell.
  -> Table lCol lRow col row a
  -> Table lCol lRow col row b
mapTable f (Table colLabels rowLabels grid) =
  Table colLabels rowLabels newGrid
  where
    newGrid = listArray (bounds grid)
      [ f (colLabels ! c) (rowLabels ! r) c r x
      | ((c, r), x) <- assocs grid
      ]

-- | Transform the column labels.  Cells and row labels are left as
-- they are.
mapColLabels
  :: (Ix col, Ix row)
  => (lCol -> col -> [(lRow, row, a)] -> lCol')
  -- ^ Receives the current column label, the column index, and the
  -- full contents of the column as @(row label, row index, cell)@
  -- triples in ascending row order.
  -> Table lCol lRow col row a
  -> Table lCol' lRow col row a
mapColLabels f (Table colLabels rowLabels grid) =
  Table newColLabels rowLabels grid
  where
    ((loCol, loRow), (hiCol, hiRow)) = bounds grid
    newColLabels = listArray (loCol, hiCol)
      [ f lbl ix (columnOf ix) | (ix, lbl) <- assocs colLabels ]
    -- One (row label, row index, cell) triple per row of column ix.
    columnOf ix = zipWith3 (,,) (elems rowLabels) (indices rowLabels)
      (map (grid !) (range ((ix, loRow), (ix, hiRow))))

-- | Transform the row labels.  Cells and column labels are left as
-- they are.
mapRowLabels
  :: (Ix col, Ix row)
  => (lRow -> row -> [(lCol, col, a)] -> lRow')
  -- ^ Receives the current row label, the row index, and the full
  -- contents of the row as @(column label, column index, cell)@
  -- triples in ascending column order.
  -> Table lCol lRow col row a
  -> Table lCol lRow' col row a
mapRowLabels f (Table colLabels rowLabels grid) =
  Table colLabels newRowLabels grid
  where
    ((loCol, loRow), (hiCol, hiRow)) = bounds grid
    newRowLabels = listArray (loRow, hiRow)
      [ f lbl ix (rowOf ix) | (ix, lbl) <- assocs rowLabels ]
    -- One (column label, column index, cell) triple per column of row ix.
    rowOf ix = zipWith3 (,,) (elems colLabels) (indices colLabels)
      (map (grid !) (range ((loCol, ix), (hiCol, ix))))

-- * Two-dimensional arrays

-- | List the columns of a two-dimensional array in ascending column
-- order.  Each column lists its cells in ascending row order.
cols
  :: (Ix col, Ix row)
  => Array (col, row) a
  -> [[a]]
cols ay =
  [ [ ay ! (c, r) | r <- range (r0, r1) ]
  | c <- range (c0, c1)
  ]
  where
    ((c0, r0), (c1, r1)) = bounds ay

-- | List the rows of a two-dimensional array in ascending row order.
-- Each row lists its cells in ascending column order.
rows
  :: (Ix col, Ix row)
  => Array (col, row) a
  -> [[a]]
rows ay =
  [ [ ay ! (c, r) | c <- range (c0, c1) ]
  | r <- range (r0, r1)
  ]
  where
    ((c0, r0), (c1, r1)) = bounds ay

-- | Generate a two-dimensional array from a list of rows.  The first
-- row fixes the width of the array.  Every later row that is shorter
-- is padded at the end with the given empty value, and every later
-- row that is longer is truncated.  Therefore, the resulting 'Array'
-- will have no undefined values.  Indices are @(column, row)@, both
-- counted from zero.
arrayByRows
  :: a
  -- ^ Append this empty value to rows that are too short.
  -> [[a]]
  -- ^ One list per row
  -> Array (Int, Int) a
arrayByRows empty ls = array ((0, 0), (colMax, rowMax)) indexed
  where
    indexed = indexRows (padder empty ls)
    rowMax = length ls - 1
    colMax = widthOf ls - 1
    -- Width is taken from the first row; no rows means zero width.
    widthOf (firstRow : _) = length firstRow
    widthOf [] = 0

-- | Make every row the same length as the first row.  Later rows are
-- padded at the end with the empty element, or truncated, as needed.
-- The first row itself is returned unchanged.
padder
  :: a
  -- ^ Empty element
  -> [[a]]
  -> [[a]]
padder _ [] = []
padder emp (firstRow : rest) = firstRow : map fit rest
  where
    width = length firstRow
    fit r = take width (r ++ repeat emp)

-- | Pair every element of a list of rows with its @(column, row)@
-- index, both counted from zero.  Elements are produced row by row,
-- and in order within each row.
indexRows :: [[a]] -> [((Int, Int),a)]
indexRows xss =
  [ ((cl, rw), x)
  | (rw, xs) <- zip [0 ..] xss
  , (cl, x) <- zip [0 ..] xs
  ]

-- | Generate a two-dimensional array from a list of columns.  The
-- first column fixes the height of the array.  Later columns are
-- padded or truncated at the bottom, as needed.  Therefore the
-- resulting 'Array' will have no undefined elements.  Indices are
-- @(column, row)@, both counted from zero.
arrayByCols
  :: a
  -- ^ Append this value to columns that are too short.
  -> [[a]]
  -- ^ One list per column; the head of each list is the top of the
  -- column.
  -> Array (Int, Int) a
arrayByCols empty ls = listArray ((0, 0), (colMax, rowMax)) flat
  where
    -- (column, row) indices enumerate column by column, which matches
    -- concatenating the (padded) columns in order.
    flat = concat (padder empty ls)
    colMax = length ls - 1
    rowMax = heightOf ls - 1
    -- Height is taken from the first column; no columns means zero height.
    heightOf (firstCol : _) = length firstCol
    heightOf [] = 0