-- | Counting partitions of integers.

{-# LANGUAGE CPP, BangPatterns, ScopedTypeVariables #-}
module Math.Combinat.Partitions.Integer.Count where

--------------------------------------------------------------------------------

import Data.List
import Control.Monad ( liftM , replicateM )

-- import Data.Map (Map)
-- import qualified Data.Map as Map

import Math.Combinat.Numbers ( factorial , binomial , multinomial )
import Math.Combinat.Numbers.Integers -- Primes
import Math.Combinat.Helper

import Data.Array
import System.Random

--------------------------------------------------------------------------------
-- * Infinite tables of integers

-- | A data structure which is essentially an infinite list of @Integer@-s,
-- but with fast lookup (for reasonably small indices).
--
-- The entries are stored as a list of arrays; 'lookupInteger' and
-- 'makeTableOfIntegers' both treat every array as a chunk of 1024
-- consecutive entries.
newtype TableOfIntegers = TableOfIntegers [Array Int Integer]

-- | Reads the entry at index @n@ of a 'TableOfIntegers'.
--
-- Negative indices yield @0@. For a non-negative index the chunk number and
-- the offset inside that chunk are obtained by dividing by the chunk size
-- (1024); reaching the chunk walks the list of chunks.
lookupInteger :: TableOfIntegers -> Int -> Integer
lookupInteger (TableOfIntegers chunks) !ix
  | ix < 0    = 0
  | otherwise = (chunks !! chunkNo) ! offset
  where
    (chunkNo, offset) = ix `divMod` 1024

-- | Builds a 'TableOfIntegers' from a rule for computing a single entry.
--
-- The rule receives a lookup function into the very table being built, and
-- returns the function computing the entry at a given index. The table is a
-- list of arrays of 1024 entries each; chunk number @j@ holds the indices
-- @1024*j .. 1024*j + 1023@.
makeTableOfIntegers
  :: ((Int -> Integer) -> (Int -> Integer))
  -> TableOfIntegers
makeTableOfIntegers user = result where
  result  = TableOfIntegers (map chunk [0..])
  -- the user's rule, tied back to the table it is filling
  entry   = user (lookupInteger result)
  chunk j = listArray (0,1023) [ entry i | i <- [lo .. lo+1023] ] where
    lo = 1024*j

--------------------------------------------------------------------------------
-- * Counting partitions

-- | The number of partitions of @n@, read from 'partitionCountTable'
-- (a table filled in using Euler's algorithm).
countPartitions :: Int -> Integer
countPartitions n = lookupInteger partitionCountTable n

-- | Number of partitions of @k@, taken from
-- 'partitionCountListInfiniteProduct' (the power series expansion of the
-- infinite product). It is slower than 'countPartitions'; the result is
-- found by indexing into a list.
--
-- A negative argument gives @0@, in agreement with 'countPartitions',
-- instead of the index error that @(!!)@ raises for negative indices.
countPartitionsInfiniteProduct :: Int -> Integer
countPartitionsInfiniteProduct k
  | k < 0     = 0
  | otherwise = partitionCountListInfiniteProduct !! k

-- | Number of partitions of @d@, obtained by calling the naive recursion
-- 'countPartitions'' with a @d@-by-@d@ bounding box. This is (very) slow.
countPartitionsNaive :: Int -> Integer
countPartitionsNaive d = countPartitions' box d where
  box = (d,d)

--------------------------------------------------------------------------------

-- | The table of partition numbers @p(n)@, filled in with Euler's algorithm
-- (the recurrence over the generalized pentagonal numbers @k*(3*k±1)/2@).
--
-- See eg.:
-- NEIL CALKIN, JIMENA DAVIS, KEVIN JAMES, ELIZABETH PEREZ, AND CHARLES SWANNACK
-- COMPUTING THE INTEGER PARTITION FUNCTION
-- <http://www.math.clemson.edu/~kevja/PAPERS/ComputingPartitions-MathComp.pdf>
--
partitionCountTable :: TableOfIntegers
partitionCountTable = makeTableOfIntegers euler where

  -- @euler p n@ computes the entry at index @n@; @p@ is the lookup into the
  -- table being built, and is only consulted at indices smaller than @n@.
  euler p !n
    | n <  0    = 0
    | n <= 1    = 1
    | otherwise = foldl' (+) 0 (map term [1 .. bound n])
    where
      -- the k-th pair of terms enters with sign (-1)^(k+1)
      term k
        | odd k     = pair
        | otherwise = negate pair
        where
          pair = p (n - (k*(3*k+1)) `div` 2) + p (n - (k*(3*k-1)) `div` 2)

  -- Upper end of the summation range used for index @m@; the arithmetic is
  -- done in 'Integer' before converting back.
  bound :: Int -> Int
  bound !m = fromInteger (ceilingSquareRoot (1 + (mm+mm+1) `div` 3)) where
    mm = toInteger m

-- | The infinite list @[p(0), p(1), p(2), ...]@ of partition numbers,
-- each entry computed with 'countPartitions'.
partitionCountList :: [Integer]
partitionCountList = [ countPartitions n | n <- [0..] ]

--------------------------------------------------------------------------------

-- | Infinite list of the number of partitions of @0,1,2,...@
--
-- This uses the infinite product formula for the generating function of
-- partitions, recursively expanding it; it is reasonably fast for small
-- numbers.
--
-- > partitionCountListInfiniteProduct == map countPartitions [0..]
--
partitionCountListInfiniteProduct :: [Integer]
partitionCountListInfiniteProduct = result where

  result = unfold 1 (1 : repeat 0)

  -- Writing @f m@ for the expansion of the product truncated after its
  -- first @m@ factors, step @m@ receives @c:cs == drop (m-1) (f (m-1))@,
  -- emits @c@ and continues with @next == drop m (f m)@.
  unfold !m (c:cs) = c : unfold (m+1) next where
    -- @next@ is defined in terms of itself and of the already emitted
    -- prefix of @result@; laziness makes this well-defined
    next   = zipWith (+) cs (prefix ++ next)
    prefix = take m result

{-

Full explanation of 'partitionCountListInfiniteProduct':
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

let f k = productPSeries $ map (:[]) [1..k]

f 0 = [1,0,0,0,0,0,0,0...]
f 1 = [1,1,1,1,1,1,1,1...]
f 2 = [1,1,2,2,3,3,4,4...]
f 3 = [1,1,2,3,4,5,7,8...]

observe: 

* take (k+1) (f k) == take (k+1) partitionCountList
* f (k+1) == zipWith (+) (f k) (replicate (k+1) 0 ++ f (k+1))

now apply (drop (k+1)) to the second one : 

* drop (k+1) (f (k+1)) == zipWith (+) (drop (k+1) $ f k) (f (k+1))
* f (k+1) = take (k+1) final ++ drop (k+1) (f (k+1))

-}

--------------------------------------------------------------------------------

-- | Naive infinite list of the number of partitions of @0,1,2,...@,
-- each entry computed independently with 'countPartitionsNaive':
--
-- > partitionCountListNaive == map countPartitionsNaive [0..]
--
-- This is very slow.
--
partitionCountListNaive :: [Integer]
partitionCountListNaive = [ countPartitionsNaive n | n <- [0..] ]

--------------------------------------------------------------------------------
-- * Counting all partitions

-- | Total number of partitions of the integers @0,1,...,d@, that is, the
-- sum of @countPartitions i@ for @i@ from @0@ to @d@.
countAllPartitions :: Int -> Integer
countAllPartitions d = sum' (map countPartitions [0..d])

-- | Count all partitions fitting into an @h@-by-@w@ rectangle.
-- # = \\binom { h+w } { h }
--
-- The smaller of @h@ and @w@ is the one handed to 'binomial'.
countAllPartitions' :: (Int,Int) -> Integer
countAllPartitions' (h,w) = binomial total smaller where
  total   = h + w
  smaller = min h w
  -- alternative: sum [ countPartitions' (h,w) i | i <- [0..d] ] where d = h*w

--------------------------------------------------------------------------------
-- * Counting partitions fitting into a rectangle

-- | Number of partitions of @d@ fitting into a given rectangle @(h,w)@:
-- every part is at most @h@, and there are at most @w@ parts.
-- Naive recursive algorithm.
--
-- The empty partition of @0@ is always counted, so the result for @d == 0@
-- is @1@ whatever the rectangle. For @d /= 0@, a rectangle with @h <= 0@ or
-- @w <= 0@ contains nothing and the result is @0@ (the width check matters
-- for negative @w@, which the recursion would otherwise never stop at).
countPartitions' :: (Int,Int) -> Int -> Integer
countPartitions' (h,w) d
  | d == 0    = 1
  | w <= 0    = 0
  -- choose the largest part @i@; the remainder must fit below it, using one
  -- part fewer. The range is empty when @h <= 0@ or @d < 0@.
  | otherwise = sum [ countPartitions' (i,w-1) (d-i) | i <- [1..min d h] ]

--------------------------------------------------------------------------------
-- * Partitions with given number of parts

-- | Count the partitions of @n@ into exactly @k@ parts.
--
-- Naive recursive algorithm: the largest part is chosen first, and the
-- remaining parts are counted recursively with that part as an upper bound.
--
countPartitionsWithKParts
  :: Int    -- ^ @k@ = number of parts
  -> Int    -- ^ @n@ = the integer we partition
  -> Integer
countPartitionsWithKParts k n = count n k n where
  -- @count cap parts rest@: the number of ways to split @rest@ into @parts@
  -- positive parts, none larger than @cap@, listed in non-increasing order
  count !cap !parts !rest = case compare parts 1 of
    -- the upper end @rest-parts+1@ leaves at least 1 for each remaining part
    GT -> sum' [ count a (parts-1) (rest-a) | a <- [1 .. min cap (rest-parts+1)] ]
    EQ -> if 1 <= rest && rest <= cap then 1 else 0
    LT | parts == 0 && cap >= 0 && rest == 0 -> 1
       | otherwise                            -> 0

--------------------------------------------------------------------------------
-- Partitions with only odd\/distinct parts

{-
-- | Partitions of @n@ with only odd parts
partitionsWithOddParts :: Int -> [Partition]
partitionsWithOddParts d = map Partition (go d d) where
  go _  0  = [[]]
  go !h !n = [ a:as | a<-[1,3..min n h], as <- go a (n-a) ]
-}

{-
-- > length (partitionsWithDistinctParts d) == length (partitionsWithOddParts d)
--
partitionsWithDistinctParts :: Int -> [Partition]
partitionsWithDistinctParts d = map Partition (go d d) where
  go _  0  = [[]]
  go !h !n = [ a:as | a<-[1..min n h], as <- go (a-1) (n-a) ]
-}

--------------------------------------------------------------------------------