-- | Vector partitions. See:
--
--  * Donald E. Knuth: The Art of Computer Programming, vol 4, pre-fascicle 3B.
--

{-# LANGUAGE BangPatterns #-}
module Math.Combinat.Partitions.Vector where

--------------------------------------------------------------------------------

import Data.Array.Unboxed
import Data.List

--------------------------------------------------------------------------------

-- | Integer vectors, represented as unboxed arrays of 'Int' indexed by 'Int'.
-- The indexing starts from 1.
type IntVector = UArray Int Int

-- | Vector partitions. Basically a synonym for 'fasc3B_algorithm_M'.
--
-- Only the elements of the argument are used (they are extracted with
-- 'elems' and handed to 'fasc3B_algorithm_M' as a plain list); the bounds
-- of the input array are not passed on.
vectorPartitions :: IntVector -> [[IntVector]]
vectorPartitions = fasc3B_algorithm_M . elems

-- | List-based variant of 'vectorPartitions': runs 'fasc3B_algorithm_M' on
-- the given list and converts every part of every partition back to a
-- plain list with 'elems'.
_vectorPartitions :: [Int] -> [[[Int]]]
_vectorPartitions = map (map elems) . fasc3B_algorithm_M

-- | Generates all vector partitions
--   (\"algorithm M\" in Knuth).
--   The order is decreasing lexicographic.
--
--   The argument is a list of non-negative integers. Every part of every
--   returned partition is an array with bounds @(1, length xs)@. Within a
--   single partition the parts are listed in the reverse of the order in
--   which the algorithm generates them (the internal stack is returned
--   top-first).
--
--   Special cases:
--
--   * if all components are zero (this includes the empty list), the
--     result is @[[]]@, i.e. only the empty partition;
--
--   * a negative component is rejected with a call to 'error'.
fasc3B_algorithm_M :: [Int] -> [[IntVector]]
{- note to self: Knuth's descriptions of algorithms are still totally unreadable -}
fasc3B_algorithm_M xs
  | any (< 0) xs = error "fasc3B_algorithm_M: negative component"
  | null start   = [[]]
  | otherwise    = worker [start]
  where

    -- n = sum xs
    m :: Int
    m = length xs

    -- Initial frame: one triple (c,u,v) per nonzero component, where c is
    -- the 1-based component index, u the amount still to be distributed and
    -- v the amount taken by the current part.
    --
    -- Zero components are left out on purpose: a triple (c,0,0) at the head
    -- of a frame is reproduced unchanged by 'subtract_a', so 'subtract_rec'
    -- would push the same non-empty frame forever. 'to_vector' fills the
    -- missing indices with 0, so the output is unaffected.
    start :: [(Int,Int,Int)]
    start = [ (j,x,x) | (j,x) <- zip [1..] xs , x > 0 ]

    -- Completes the current stack to a full partition, emits it, and then
    -- continues from the next smaller stack (if there is one).
    worker :: [[(Int,Int,Int)]] -> [[IntVector]]
    worker stack =
      case decrease completed of
        Nothing   -> [visited]
        Just next -> visited : worker next
      where
        completed = subtract_rec stack
        visited   = map to_vector completed

    -- Steps to the next smaller stack: in the top frame, decrement the last
    -- nonzero v and reset every v after it to its u. When the only nonzero
    -- v is that of the frame's first triple and equals 1, nothing smaller
    -- exists in this frame, so it is popped and the search continues below.
    -- 'Nothing' means the whole stack is exhausted.
    decrease :: [[(Int,Int,Int)]] -> Maybe [[(Int,Int,Int)]]
    decrease [] = Nothing
    decrease (top:rest) =
      case span (\(_,_,v) -> v == 0) (reverse top) of
        ( _ , [(_,_,1)] ) -> decrease rest
        ( second , (c,u,v):first ) -> Just (modified : rest)
          where
            modified =
              reverse first ++
              (c,u,v-1) :
              [ (c',u',u') | (c',u',_) <- reverse second ]
        _ -> error "fasc3B_algorithm_M: should not happen"

    -- Converts a frame to a vector holding its v values; components that
    -- do not occur in the frame are 0.
    to_vector :: [(Int,Int,Int)] -> IntVector
    to_vector cuvs =
      accumArray (flip const) 0 (1,m)
        [ (c,v) | (c,_,v) <- cuvs ]

    -- Keeps pushing the remainder frame (u - v) on top of the stack until
    -- the remainder is empty.
    subtract_rec :: [[(Int,Int,Int)]] -> [[(Int,Int,Int)]]
    subtract_rec [] = []
    subtract_rec stack@(top:_) =
      case subtract_a top of
        []  -> stack
        new -> subtract_rec (new : stack)

    -- Remainder frame while the previous part still fits (u - v >= v): the
    -- same v is reused. On the first triple where it does not fit, control
    -- switches to 'subtract_b' for that triple and all following ones.
    subtract_a :: [(Int,Int,Int)] -> [(Int,Int,Int)]
    subtract_a [] = []
    subtract_a full@((c,u,v):rest)
      | w >= v    = (c,w,v) : subtract_a rest
      | otherwise = subtract_b full
      where w = u - v

    -- Remainder frame after the switch: v is set to the whole remainder,
    -- and triples whose remainder is 0 are dropped.
    subtract_b :: [(Int,Int,Int)] -> [(Int,Int,Int)]
    subtract_b [] = []
    subtract_b ((c,u,v):rest)
      | w /= 0    = (c,w,w) : subtract_b rest
      | otherwise = subtract_b rest
      where w = u - v

--------------------------------------------------------------------------------