{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE InstanceSigs           #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UnboxedTuples          #-}
{-# LANGUAGE UndecidableInstances   #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.DataFrame.Contraction
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module provides a generalization of the matrix product:
--  tensor-like contraction.
-- For matrices and vectors this is the usual matrix*matrix, vector*matrix, or matrix*vector product;
-- for higher-rank tensors it computes the scalar product over the "adjacent" dimensions of the two tensors
-- (the last dimension of the left operand and the first dimension of the right operand).
--
-----------------------------------------------------------------------------

module Numeric.DataFrame.Contraction
  ( Contraction (..), (%*)
  ) where

import GHC.Base

import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.Type
import Numeric.Dimensions



-- | Generalized tensor contraction.
--
--   @as@ are the remaining dimensions of the left operand, @bs@ the remaining
--   dimensions of the right operand, and @asbs@ their concatenation (required
--   by the 'ConcatList' superclass). By the functional dependencies, any two
--   of @as@, @bs@, @asbs@ determine the third.
class ConcatList as bs asbs
      => Contraction (t :: Type) (as :: [Nat]) (bs :: [Nat]) (asbs :: [Nat])
                             | asbs as -> bs, asbs bs -> as, as bs -> asbs where
    -- | Generalization of a matrix product: take the scalar product over one
    --   dimension @m@ (the last dimension of the left argument and the first
    --   dimension of the right argument) and concatenate the other dimensions.
    contract :: ( KnownDim m
                , PrimArray t (DataFrame t (as +: m))
                , PrimArray t (DataFrame t (m :+ bs))
                , PrimArray t (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs

infixl 7 %*

-- | Tensor contraction, an operator alias for 'contract'.
--   Left-associative with precedence 7, the same as '*'.
--   Depending on the shapes of the operands it covers, in particular:
--
--     1. matrix-matrix product;
--
--     2. matrix-vector or vector-matrix product;
--
--     3. dot product of two vectors.
(%*) :: ( Contraction t as bs asbs
        , KnownDim m
        , PrimArray t (DataFrame t (as +: m))
        , PrimArray t (DataFrame t (m :+ bs))
        , PrimArray t (DataFrame t asbs)
        ) => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
lhs %* rhs = contract lhs rhs
{-# INLINE (%*) #-}



-- | The general contraction instance.
--
--   Both operands are viewed as flat arrays with the last dimension varying
--   fastest (as the indexing @i * m + l@ and @l * k + j@ below assumes):
--   the left one as an @n x m@ matrix (@n@ = product of @as@) and the right
--   one as an @m x k@ matrix (@k@ = product of @bs@). The result is their
--   @n x k@ matrix product, laid out with the cumulative dims of @asbs@.
instance ( ConcatList as bs asbs
         , Dimensions as
         , Dimensions bs
         , Num t
         ) => Contraction t as bs asbs where

    contract :: forall m .
                ( KnownDim m
                , PrimArray t (DataFrame t (as +: m))
                , PrimArray t (DataFrame t (m :+ bs))
                , PrimArray t (DataFrame t asbs)
                )
             => DataFrame t (as +: m) -> DataFrame t (m :+ bs) -> DataFrame t asbs
    contract x y = case (# uniqueOrCumulDims x, uniqueOrCumulDims y #) of
      -- Both operands are broadcast constants: every result element is the
      -- sum of m identical products x0 * y0, i.e. m * x0 * y0.
      (# Left x0, Left y0 #) ->
        broadcast (x0 * y0 * fromIntegral (dimVal (dim @m)))
      (# ux, uy #)
        | dm <- dim @m
        , (ixX, xs) <- getStepsAndIx (Snoc dims dm) x ux
        , (ixY, ys) <- getStepsAndIx (Cons dm dims) y uy
        , (# n, m, k, steps #) <- conSteps xs ys ->
          -- loop: scalar product of row i of x and column j of y; the
          -- accumulator is forced at each step to avoid a chain of thunks.
          -- loop2: generator state (i, j), walking the result row by row.
          let loop i j l r
                | isTrue# (l ==# m) = r
                | otherwise =
                    let r' = r + ixX (i *# m +# l) * ixY (l *# k +# j)
                    in  r' `seq` loop i j (l +# 1#) r'
              loop2 (T# i j)
                | isTrue# (i ==# n) = (# T# i j, 0 #)
                | isTrue# (j ==# k) = loop2 (T# (i +# 1#) 0#)
                | otherwise         = (# T# i (j +# 1#), loop i j 0# 0 #)
          in case gen# steps loop2 (T# 0# 0#) of
              (# _, r #) -> r
      where
        -- Element accessor and cumulative dims for one operand. For a
        -- broadcast operand (Left e) every index yields e, and the cumulative
        -- dims are computed from the given shape.
        getStepsAndIx :: forall (ns :: [Nat])
                       . PrimArray t (DataFrame t ns)
                      => Dims ns
                      -> DataFrame t ns
                      -> Either t CumulDims
                      -> (Int# -> t, CumulDims)
        getStepsAndIx _  df (Right cds) = ((`ix#` df), cds)
        getStepsAndIx ds _  (Left  e)   = (\_ -> e, cumulDims ds)
        -- From the operands' cumulative dims, derive the flat sizes n, m, k
        -- and the cumulative dims of the result.
        conSteps (CumulDims xs) (CumulDims ys) = case conSteps' xs ys of
          (W# n, W# m, W# k, zs)
            -> (# word2Int# n, word2Int# m, word2Int# k, CumulDims zs #)
        -- NOTE(review): if m == 0 and `as` is non-empty, nm `quot` m below
        -- divides by zero; confirm whether zero-sized contracted dims occur.
        conSteps' :: [Word] -> [Word] -> (Word, Word, Word, [Word])
        conSteps' [m, _] (_:ys@(k:_)) = (1, m, k, ys)
        conSteps' (nm:ns) cys
          | (_, m, k, ys) <- conSteps' ns cys
          , n <- nm `quot` m
            = (n, m, k, n*k : ys )
        conSteps' _ _ = error "Numeric.DataFrame.Contraction: impossible match"

-- | Pair of unboxed row/column indices @(i, j)@, used as the generator state
--   when filling the result array in 'contract'.
data T# = T# Int# Int#