-- | Counting partitions of integers.

{-# LANGUAGE CPP, BangPatterns, ScopedTypeVariables #-}
module Math.Combinat.Partitions.Integer.Count where

--------------------------------------------------------------------------------

import Data.List
import Control.Monad ( liftM , replicateM )

-- import Data.Map (Map)
-- import qualified Data.Map as Map

import Math.Combinat.Numbers ( factorial , binomial , multinomial )
import Math.Combinat.Numbers.Integers -- Primes
import Math.Combinat.Helper

import Data.Array
import System.Random

--------------------------------------------------------------------------------
-- * Infinite tables of integers

-- | A data structure which is essentially an infinite list of @Integer@-s,
-- but with fast lookup (for reasonably small inputs).
--
-- It is represented as a list of arrays (\"chunks\"); 'makeTableOfIntegers'
-- builds chunks of 1024 entries each, and 'lookupInteger' relies on that
-- chunk size when it splits an index into a chunk number and an offset.
newtype TableOfIntegers = TableOfIntegers [Array Int Integer]

-- | Looks up the @n@-th entry of a 'TableOfIntegers'.
--
-- The table is a list of array chunks, so the index is split with 'divMod'
-- into a chunk number (position in the list) and an offset inside that
-- chunk. The chunk size used here (1024) has to agree with the size of the
-- chunks produced by 'makeTableOfIntegers'.
--
-- Negative indices are not an error: they yield @0@.
lookupInteger :: TableOfIntegers -> Int -> Integer
lookupInteger (TableOfIntegers table) !n
  | n >= 0    = (table !! k) ! r
  | otherwise = 0
  where
    -- k = which chunk, r = position inside the chunk
    (k, r) = divMod n 1024

-- | Builds a 'TableOfIntegers' from a (possibly self-referential) rule.
--
-- The argument receives a lookup function into the very table that is being
-- built, and returns the function computing the entry at a given index.
-- Entries are stored in lazily evaluated array chunks of 1024 elements, so a
-- rule may consult other entries of the table as long as those do not in turn
-- depend on the entry currently being computed (for example, only strictly
-- smaller indices); otherwise evaluation cannot terminate.
--
-- Lookups at negative indices performed by the rule return @0@
-- (see 'lookupInteger').
makeTableOfIntegers
  :: ((Int -> Integer) -> (Int -> Integer))
  -> TableOfIntegers
makeTableOfIntegers user = table
  where
    -- Number of entries per array chunk; 'lookupInteger' assumes the same value.
    chunkSize :: Int
    chunkSize = 1024

    -- The user's rule, given access to the finished table.
    calc :: Int -> Integer
    calc = user lkp

    lkp :: Int -> Integer
    lkp = lookupInteger table

    -- Infinite list of chunks; chunk number k holds the entries for the
    -- indices k*chunkSize .. (k+1)*chunkSize - 1.
    table :: TableOfIntegers
    table = TableOfIntegers
      [ listArray (0, chunkSize - 1) (map calc [a .. b])
      | k <- [0 ..]
      , let a = chunkSize * k
      , let b = a + chunkSize - 1
      ]

--------------------------------------------------------------------------------
-- * Counting partitions

-- | Number of partitions of @n@ (looking up a table built using Euler's algorithm).
--
-- The value is read from 'partitionCountTable' via 'lookupInteger', so a
-- negative @n@ gives @0@.
countPartitions :: Int -> Integer
countPartitions = lookupInteger partitionCountTable

-- | Number of partitions of @k@, taken from the power series expansion of the
-- infinite product ('partitionCountListInfiniteProduct'). It is slower than
-- 'countPartitions'.
--
-- A negative @k@ gives @0@, in agreement with 'countPartitions' and
-- 'countPartitionsNaive' (indexing the list directly would raise a
-- negative-index error instead).
countPartitionsInfiniteProduct :: Int -> Integer
countPartitionsInfiniteProduct k
  | k < 0     = 0
  | otherwise = partitionCountListInfiniteProduct !! k

-- | Number of partitions of @d@, computed as the number of partitions of @d@
-- fitting into the @d@-by-@d@ square.
--
-- This uses 'countPartitions'', and is (very) slow.
countPartitionsNaive :: Int -> Integer
countPartitionsNaive d = countPartitions' (d, d) d

--------------------------------------------------------------------------------

-- | This uses Euler's algorithm to compute p(n)
--
-- See eg.:
-- NEIL CALKIN, JIMENA DAVIS, KEVIN JAMES, ELIZABETH PEREZ, AND CHARLES SWANNACK
-- COMPUTING THE INTEGER PARTITION FUNCTION
-- <http://www.math.clemson.edu/~kevja/PAPERS/ComputingPartitions-MathComp.pdf>
--
-- The table is built with 'makeTableOfIntegers'; every entry for @n > 1@ is an
-- alternating sum of earlier entries at the offsets @k*(3*k+1)\/2@ and
-- @k*(3*k-1)\/2@.
partitionCountTable :: TableOfIntegers
partitionCountTable = makeTableOfIntegers step
  where
    -- One step of the recurrence; @lkp@ reads already-known entries of the
    -- table (and gives 0 for negative indices, see 'lookupInteger').
    step :: (Int -> Integer) -> Int -> Integer
    step lkp !n
      | n < 0     = 0
      | n <= 1    = 1          -- p(0) = p(1) = 1
      | otherwise = foldl' (+) 0
          [ sign ( lkp (n - div (k * (3 * k + 1)) 2)
                 + lkp (n - div (k * (3 * k - 1)) 2)
                 )
          | k <- [1 .. limit n]
          , let sign = if even k then negate else id
          ]

    -- Largest @k@ that is tried for a given @n@.
    -- NOTE(review): 'ceilingSquareRoot' comes from
    -- "Math.Combinat.Numbers.Integers"; presumably it is the rounded-up
    -- integer square root, which would make this an upper bound for the
    -- @k@-s whose offsets do not exceed @n@ -- confirm against that module.
    -- The arithmetic is done in 'Integer' before converting back to 'Int'.
    limit :: Int -> Int
    limit !n = fromInteger $ ceilingSquareRoot (1 + div (nn + nn + 1) 3)
      where
        nn = fromIntegral n :: Integer

-- | An infinite list containing all @p(n)@, starting from @p(0)@.
--
-- Each element is obtained with 'countPartitions'.
partitionCountList :: [Integer]
partitionCountList = map countPartitions [0 ..]

--------------------------------------------------------------------------------

-- | Infinite list of number of partitions of @0,1,2,...@
--
-- This uses the infinite product formula for the generating function of
-- partitions, recursively expanding it; it is reasonably fast for small numbers.
--
-- > partitionCountListInfiniteProduct == map countPartitions [0..]
--
partitionCountListInfiniteProduct :: [Integer]
partitionCountListInfiniteProduct = final
  where
    final :: [Integer]
    final = go 1 (1 : repeat 0)

    -- @go k@ emits the head of its input as the next element of the result
    -- and continues with the updated tail.
    --
    -- explanation:
    --   xs == drop k $ f (k-1)
    --   ys == drop k $ f (k  )
    go :: Int -> [Integer] -> [Integer]
    go _ []       = []    -- not reachable: the input list is always infinite
    go k (x : xs) = k `seq` (x : go (k + 1) ys)   -- keep the counter evaluated
      where
        -- ys is defined in terms of itself: its first k elements only need
        -- elements of 'final' that have already been produced.
        ys = zipWith (+) xs (take k final ++ ys)

{-

Full explanation of 'partitionCountListInfiniteProduct':
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

let f k = productPSeries $ map (:[]) [1..k]

f 0 = [1,0,0,0,0,0,0,0...]
f 1 = [1,1,1,1,1,1,1,1...]
f 2 = [1,1,2,2,3,3,4,4...]
f 3 = [1,1,2,3,4,5,7,8...]

observe: 

* take (k+1) (f k) == take (k+1) partitionCountList
* f (k+1) == zipWith (+) (f k) (replicate (k+1) 0 ++ f (k+1))

now apply (drop (k+1)) to the second one : 

* drop (k+1) (f (k+1)) == zipWith (+) (drop (k+1) $ f k) (f (k+1))
* f (k+1) = take (k+1) final ++ drop (k+1) (f (k+1))

-}

--------------------------------------------------------------------------------

-- | Naive infinite list of number of partitions of @0,1,2,...@
--
-- > partitionCountListNaive == map countPartitionsNaive [0..]
--
-- This is very slow.
--
partitionCountListNaive :: [Integer]
partitionCountListNaive = map countPartitionsNaive [0 ..]

--------------------------------------------------------------------------------
-- * Counting all partitions

-- | Total number of partitions of all the integers @0, 1, ..., d@, that is,
-- the sum of @'countPartitions' i@ for @0 <= i <= d@.
--
-- For negative @d@ the range @[0..d]@ is empty, so there is nothing to add up.
countAllPartitions :: Int -> Integer
countAllPartitions d = sum' [ countPartitions i | i <- [0 .. d] ]

-- | Count all partitions (of any integer) fitting into an @h@-by-@w@ rectangle.
--
-- The count is the binomial coefficient @binomial (h+w) h@; since this is
-- symmetric in @h@ and @w@, the smaller of the two is passed to 'binomial'.
countAllPartitions' :: (Int,Int) -> Integer
countAllPartitions' (h, w) = binomial (h + w) (min h w)
  -- alternative formulation kept for reference:
  --sum [ countPartitions' (h,w) i | i <- [0..d] ] where d = h*w

--------------------------------------------------------------------------------
-- * Counting fitting into a rectangle

-- | Number of partitions of @d@ fitting into a given rectangle @(h,w)@:
-- every part is at most @h@ and there are at most @w@ parts.
-- Naive recursive algorithm.
--
-- * @d == 0@ always gives @1@ (the empty partition), whatever the rectangle is.
--
-- * For @d /= 0@, a rectangle with a non-positive side (or a negative @d@)
--   gives @0@. In particular a negative width no longer lets the recursion
--   run down to the @d == 0@ case and report a positive count.
countPartitions' :: (Int,Int) -> Int -> Integer
countPartitions' _ 0 = 1
countPartitions' (h, w) d
  | d < 0 || h <= 0 || w <= 0 = 0
  | otherwise =
      -- choose the largest part i; the remaining parts are at most i,
      -- and one fewer part is allowed
      sum [ countPartitions' (i, w - 1) (d - i) | i <- [1 .. min d h] ]

--------------------------------------------------------------------------------
-- * Partitions with given number of parts

-- | Count partitions of @n@ into @k@ parts.
--
-- Naive recursive algorithm: the largest part is chosen first, and the rest
-- is counted recursively with that part as the new upper bound.
--
-- A negative @k@ gives @0@; @k == 0@ gives @1@ exactly when @n == 0@.
countPartitionsWithKParts
  :: Int    -- ^ @k@ = number of parts
  -> Int    -- ^ @n@ = the integer we partition
  -> Integer
countPartitionsWithKParts k n = go n k n
  where
    -- @go maxPart parts total@: number of partitions of @total@ into exactly
    -- @parts@ parts, none of which exceeds @maxPart@.
    go :: Int -> Int -> Int -> Integer
    go !maxPart !parts !total
      | parts <  0 = 0
      | parts == 0 = if maxPart >= 0 && total == 0 then 1 else 0
      | parts == 1 = if maxPart >= total && total >= 1 then 1 else 0
      | otherwise  =
          -- the first part a must leave at least 1 for each of the other
          -- (parts - 1) parts, hence a <= total - parts + 1
          sum' [ go a (parts - 1) (total - a)
               | a <- [1 .. min maxPart (total - parts + 1)]
               ]

--------------------------------------------------------------------------------
-- Partitions with only odd\/distinct parts

{-
-- | Partitions of @n@ with only odd parts
partitionsWithOddParts :: Int -> [Partition]
partitionsWithOddParts d = map Partition (go d d) where
  go _  0  = [[]]
  go !h !n = [ a:as | a<-[1,3..min n h], as <- go a (n-a) ]
-}

{-
-- > length (partitionsWithDistinctParts d) == length (partitionsWithOddParts d)
--
partitionsWithDistinctParts :: Int -> [Partition]
partitionsWithDistinctParts d = map Partition (go d d) where
  go _  0  = [[]]
  go !h !n = [ a:as | a<-[1..min n h], as <- go (a-1) (n-a) ]
-}

--------------------------------------------------------------------------------