{- |
Count and create combinatorial objects.
Also see 'combinat' package.
-}
module Combinatorics (
   permute,
   permuteFast,
   permuteShare,
   permuteMSL,
   runPermuteRep,
   permuteRep,
   permuteRepM,
   choose,
   chooseMSL,
   variateRep,
   variateRepMSL,
   variate,
   variateMSL,
   tuples,
   tuplesMSL,
   tuplesRec,
   partitions,
   rectifications,
   setPartitions,
   chooseFromIndex,
   chooseFromIndexList,
   chooseFromIndexMaybe,
   chooseToIndex,
   factorial,
   binomial,
   binomialSeq,
   binomialGen,
   binomialSeqGen,
   multinomial,
   factorials,
   binomials,
   catalanNumber,
   catalanNumbers,
   derangementNumber,
   derangementNumbers,
   derangementNumbersAlt,
   derangementNumbersInclExcl,
   setPartitionNumbers,
   surjectiveMappingNumber,
   surjectiveMappingNumbers,
   surjectiveMappingNumbersStirling,
   fibonacciNumber,
   fibonacciNumbers,
   ) where

import qualified PowerSeries
import Combinatorics.Utility (scalarProduct, )

import Data.Function.HT (nest, )
import Data.Maybe.HT (toMaybe, )
import Data.Maybe (mapMaybe, catMaybes, )
import Data.Tuple.HT (mapFst, )
import qualified Data.List.Match as Match
import Data.List.HT (tails, partition, mapAdjacent, removeEach, splitEverywhere, viewL, )
import Data.List (mapAccumL, intersperse, genericIndex, genericReplicate, genericTake, )

import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.Monad (liftM, liftM2, replicateM, forM, guard, )


{-* Generate compositions from a list of elements. -}

-- several functions for permutation
-- cf. Equation.hs

{- |
Generate list of all permutations of the input list.
The list is sorted lexicographically.

Every pair delivered by 'removeEach' provides the head of a permutation
and the list from which the remainder of the permutation is built.
-}
permute :: [a] -> [[a]]
permute [] = [[]]
permute xs =
   [ y : rest
   | (y, ys) <- removeEach xs
   , rest <- permute ys
   ]

{- |
Generate list of all permutations of the input list.
It is not lexicographically sorted.
It is slightly faster and consumes less memory
than the lexicographical ordering 'permute'.

This is 'permuteFastStep' started with an empty suffix
and an empty list of already generated permutations.
-}
permuteFast :: [a] -> [[a]]
permuteFast elems = permuteFastStep noSuffix elems noResults
   where
      noSuffix = []
      noResults = []

{- |
Worker of 'permuteFast'.

@permuteFastStep suffix todo acc@ prepends to @acc@ all lists
that consist of an arrangement of @todo@ followed by @suffix@
(the elements are consed onto @suffix@ one by one).
Each element of @allCycles todo@ has a different element at the front.
That front element is moved to the suffix
and the remaining elements of the rotation are processed recursively.
-}
permuteFastStep :: [a] -> [a] -> [[a]] -> [[a]]
permuteFastStep suffix [] acc = suffix : acc
permuteFastStep suffix todo acc =
   foldr step acc (allCycles todo)
   where
      -- 'head' and 'tail' are applied to rotations of the non-empty list todo
      step rotation =
         permuteFastStep (head rotation : suffix) (tail rotation)

{- |
All permutations share as much suffixes as possible.
The reversed permutations are sorted lexicographically.

Starting with an empty permutation and all elements left to do,
'permuteShareStep' is applied once per input element.
After that only the first components of the pairs are kept.
Matching the second components against @[]@ would be safer
but less efficient.
-}
permuteShare :: [a] -> [[a]]
permuteShare xs =
   let start = [([], xs)]
       finished = nest (length xs) (concatMap permuteShareStep) start
   in  map fst finished

{- |
One step of 'permuteShare':
For every pair delivered by 'removeEach' for the remaining elements
the first component is consed onto the partial permutation
and the second component becomes the new list of remaining elements.
-}
permuteShareStep :: ([a], [a]) -> [([a], [a])]
permuteShareStep (done, remaining) =
   map extend (removeEach remaining)
   where
      extend = mapFst (: done)

{- |
Permutations generated with a 'MS.StateT' over the list monad.
The state holds the elements that are not used so far.
'removeEach' serves as the state transition
and is run once per element of the input.
-}
permuteMSL :: [a] -> [[a]]
permuteMSL xs =
   MS.evalStateT (replicateM (length xs) pick) xs
   where
      pick = MS.StateT removeEach




{- |
Prepare the input for a generator of permutations with repetitions.

The pairs consist of an element and its multiplicity.
If any multiplicity is negative, the result is empty.
Otherwise the pairs with zero multiplicity are dropped
and the generator @f@ is applied to the pairs with positive multiplicity.
-}
runPermuteRep :: ([(a,Int)] -> [[a]]) -> [(a,Int)] -> [[a]]
runPermuteRep f xs
   | any ((<0) . snd) nonPositive = []
   | otherwise = f positive
   where
      (positive, nonPositive) = partition ((>0) . snd) xs

{- |
Permutations of elements that are given together with their multiplicity.
The input is checked by 'runPermuteRep'
and the permutations are generated by 'permuteRepAux'.
-}
permuteRep :: [(a,Int)] -> [[a]]
permuteRep xs = runPermuteRep permuteRepAux xs

{- |
Worker of 'permuteRep'.

Every entry with positive multiplicity is chosen as head in turn.
Its multiplicity is decremented for the recursive call
and the entry is removed by '?:' once the multiplicity drops to zero.
-}
permuteRepAux :: [(a,Int)] -> [[a]]
permuteRepAux [] = [[]]
permuteRepAux xs =
   [ a : rest
   | (ys, (a, n), zs) <- splitEverywhere xs
   , n > 0
   , let m = pred n
   , rest <- permuteRepAux (ys ++ (m>0, (a, m)) ?: zs)
   ]

{- |
Like 'permuteRep' but the permutations are generated by 'permuteRepMAux',
which is written in monadic style.
-}
permuteRepM :: [(a,Int)] -> [[a]]
permuteRepM xs = runPermuteRep permuteRepMAux xs

{- |
Worker of 'permuteRepM', written in the list monad.

In contrast to 'permuteRepAux' the multiplicities are not checked
for being positive before an entry is chosen.
-}
permuteRepMAux :: [(a,Int)] -> [[a]]
permuteRepMAux [] = [[]]
permuteRepMAux xs =
   splitEverywhere xs >>= \(ys, (a, n), zs) ->
      let m = pred n
          remaining = ys ++ (m>0, (a, m)) ?: zs
      in  liftM (a:) (permuteRepMAux remaining)


infixr 5 ?:

{- |
Conditional cons:
The element is prepended to the list only if the flag is 'True'.
-}
(?:) :: (Bool, a) -> [a] -> [a]
(keep, a) ?: xs =
   if keep
     then a : xs
     else xs


{- |
@choose n k@ generates all lists of @n@ flags
that contain exactly @k@ times 'True'.
The result is empty if @k@ is negative or greater than @n@.
Lists starting with 'False' precede lists starting with 'True'.
-}
choose :: Int -> Int -> [[Bool]]
choose n k
   | k<0 || k>n = []
   | n==0 = [[]]
   | otherwise =
        map (False:) (choose (pred n) k) ++
        map (True:)  (choose (pred n) (pred k))

{- |
Formulation of the selection of @k0@ out of @n0@ positions
with a 'MS.StateT' over the list monad.
The state is the number of 'True's that still have to be emitted.
Presumably the result equals that of 'choose' - TODO confirm by a test.

For every @n@ from @n0@ down to @0@ there is a range check
and 'intersperse' puts a branching step between two consecutive checks.
The checks return 'Nothing' and the branching steps return 'Just' a flag,
such that 'catMaybes' keeps only the flags.
-}
chooseMSL :: Int -> Int -> [[Bool]]
chooseMSL n0 k0 =
   flip MS.evalStateT k0 $ fmap catMaybes $ sequence $
   -- branching step: emit False and keep k, or emit True and decrement k
   intersperse (MS.StateT $ \k -> [(Just False, k), (Just True, pred k)]) $
   flip map [n0,n0-1..0] $ \n ->
   -- range check: abandon the branch unless 0<=k<=n holds for the state k
   MS.gets (\k -> 0<=k && k<=n) >>= guard >> return Nothing

{- |
Alternative to 'chooseMSL' that is not exported.
For every @n@ from @n0@ down to @1@ the state @k@ is checked
for @0<=k<=n@ and then a flag is chosen,
where choosing 'True' decrements @k@.
Finally only those branches are kept where the state arrived at zero.
-}
_chooseMSL :: Int -> Int -> [[Bool]]
_chooseMSL n0 k0 =
   MS.evalStateT selection k0
   where
      selection = do
         flags <- forM [n0,n0-1..1] step
         MS.gets (0==) >>= guard
         return flags
      step n =
         MS.StateT $ \k ->
            guard (0<=k && k<=n) >> [(False, k), (True, pred k)]


{- |
Generate all choices of n elements out of the list x with repetitions.
\"variation\" seems to be used historically,
but I like it more than \"k-permutation\".

There is no way to choose a negative number of elements,
thus the result is empty for negative @n@.
-}
variateRep :: Int -> [a] -> [[a]]
variateRep n x
   | n < 0 = []
   | otherwise = go n
   where
      -- go m generates all choices of m elements;
      -- the shorter choices are computed once and shared by all heads
      go 0 = [[]]
      go m =
         let shorter = go (m-1)
         in  concatMap (\z -> map (z:) shorter) x

{- |
Choices of @n@ elements with repetitions,
expressed by 'replicateM' in the list monad.
-}
variateRepMSL :: Int -> [a] -> [[a]]
variateRepMSL n xs = replicateM n xs


{- |
Generate all choices of n elements out of the list x without repetitions.
It holds
   @ variate (length xs) xs == permute xs @

The structure is that of 'permute',
except that the recursion stops after @n@ elements.
-}
variate :: Int -> [a] -> [[a]]
variate 0 _ = [[]]
variate n xs =
   [ y : rest
   | (y, ys) <- removeEach xs
   , rest <- variate (n-1) ys
   ]

{- |
Choices of @n@ elements without repetitions,
generated with a 'MS.StateT' over the list monad.
The state holds the elements that are not used so far
and 'removeEach' is run @n@ times as state transition.
-}
variateMSL :: Int -> [a] -> [[a]]
variateMSL n xs =
   MS.evalStateT (replicateM n pick) xs
   where
      pick = MS.StateT removeEach


{- |
Generate all choices of n elements out of the list x
respecting the order in x and without repetitions.

The head of a choice is the head of a non-empty suffix of the input
and the remaining elements are chosen from the tail of that suffix.
-}
tuples :: Int -> [a] -> [[a]]
tuples 0 _  = [[]]
tuples r xs =
   [ y : rest
   | (y:ys) <- tails xs   -- the empty suffix does not match and is skipped
   , rest <- tuples (r-1) ys
   ]

{- |
Formulation of 'tuples' with a 'MS.StateT' over the list monad.
The state transition is @mapMaybe viewL . tails@ and is run @n@ times.
Presumably 'viewL' splits a non-empty list into head and tail,
such that the tail becomes the new state - verify against "Data.List.HT".
-}
tuplesMSL :: Int -> [a] -> [[a]]
tuplesMSL n xs =
   MS.evalStateT (replicateM n pickSuffix) xs
   where
      pickSuffix = MS.StateT (mapMaybe viewL . tails)

{- |
Alternative to 'tuplesMSL' that is not exported.
In every of the @n@ steps a suffix of the state is selected.
Its head is the chosen element and its tail becomes the new state.
The empty suffix fails the pattern match
and thus terminates that branch of the list monad.
-}
_tuplesMSL :: Int -> [a] -> [[a]]
_tuplesMSL n xs =
   MS.evalStateT (replicateM n pickOne) xs
   where
      pickOne = do
         (y:ys) <- MS.get >>= MT.lift . tails
         MS.put ys
         return y

{- |
Choices of @k@ elements respecting the order of the input,
generated by deciding for every element whether it is omitted or chosen.
The choices that omit the first element
precede the choices that contain it.
The result is empty for negative @k@.
-}
tuplesRec :: Int -> [a] -> [[a]]
tuplesRec k _ | k<0 = []
tuplesRec k [] = [ [] | k==0 ]
tuplesRec k (x:xs) =
   tuplesRec k xs ++ map (x:) (tuplesRec (pred k) xs)


{- |
All ways to distribute the elements of a list to two lists
while maintaining the order of the elements within both lists.
Every element is put either into the left or into the right list.
-}
partitions :: [a] -> [([a],[a])]
partitions [] = [([],[])]
partitions (x:xs) =
   [ pair
   | (ls, rs) <- partitions xs
   , pair <- [(x:ls, rs), (ls, x:rs)]
   ]

{- |
Number of possibilities arising in rectification of a predicate
in deductive database theory.
Stefan Brass, \"Logische Programmierung und deduktive Datenbanken\", 2007,
page 7-60
This is isomorphic to the partition of @n@-element sets
into @k@ non-empty subsets.
<http://oeis.org/A048993>

> *Combinatorics> map (length . uncurry rectifications) $ do x<-[0..10]; y<-[0..x]; return (x,[1..y::Int])
> [1,0,1,0,1,1,0,1,3,1,0,1,7,6,1,0,1,15,25,10,1,0,1,31,90,65,15,1,0,1,63,301,350,140,21,1,0,1,127,966,1701,1050,266,28,1,0,1,255,3025,7770,6951,2646,462,36,1,0,1,511,9330,34105,42525,22827,5880,750,45,1]

Every position of a result list either repeats an element
that occurred earlier in that list
or introduces the next element of the input that is not used so far.
A result is only complete if all elements of the input are used.
-}
rectifications :: Int -> [a] -> [[a]]
rectifications = go []
   where
      -- arguments: elements used so far, number of positions left,
      -- elements of the input that are not used so far
      go _ 0 unused = [ [] | null unused ]
      go used n unused =
         let n1 = pred n
             repeated = liftM2 (:) used (go used n1 unused)
             fresh =
                case unused of
                   [] -> []
                   (x:xs) -> map (x:) (go (used++[x]) n1 xs)
         in  repeated ++ fresh

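{-
The following comment belongs to the disabled function setPartitionsEmpty
and must not be a Haddock comment,
since otherwise it would be attached to 'setPartitions':

Their number is @k^n@.
-}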
{-
setPartitionsEmpty :: Int -> [a] -> [[[a]]]
setPartitionsEmpty k =
   let recourse [] = [replicate k []]
       recourse (x:xs) =
          map (\(ys0,y,ys1) -> ys0 ++ [x:y] ++ ys1) $
          concatMap splitEverywhere (recourse xs)
{-
          do xs1 <- recourse xs
             (ys0,y,ys1) <- splitEverywhere xs1
             return (ys0 ++ [x:y] ++ ys1)
-}
   in  recourse
-}

{- |
@setPartitions k xs@ generates all partitions of the list @xs@
into @k@ non-empty parts.
The first element of the list is always a member of the first part.
The other members of the first part are selected using 'partitions'
and the elements that are not selected
are partitioned into @k-1@ parts recursively.
-}
setPartitions :: Int -> [a] -> [[[a]]]
setPartitions 0 xs = [ [] | null xs ]
setPartitions _ [] = []
setPartitions 1 xs = [[xs]]  -- unnecessary for correctness, but useful for efficiency
setPartitions k (x:xs) =
   [ (x:chosen) : parts
   | (rest, chosen) <- partitions xs
   , parts <- setPartitions (pred k) rest
   ]


{-* Compute the number of certain compositions from a number of elements. -}

{- |
@chooseFromIndex n k i == choose n k !! i@

The selections starting with 'False' come first
and there are @binomial (n-1) k@ of them.
Thus the first flag is 'True' if and only if @i@ is at least that number
and in this case the number is subtracted from the index.
The index is not checked for being in range,
see 'chooseFromIndexMaybe' for a checked variant.
-}
chooseFromIndex :: Integral a => a -> a -> a -> [Bool]
chooseFromIndex n 0 _ = genericReplicate n False
chooseFromIndex n k i =
   let n1 = pred n
       p = binomial n1 k
       b = i>=p
       -- lazy pattern: the flag b is emitted before it is inspected
       (k1, i1) =
          if b
            then (pred k, i-p)
            else (k, i)
   in  b : chooseFromIndex n1 k1 i1

{- |
Variant of 'chooseFromIndex' that reads the binomial coefficients
from the first @n@ rows of 'binomials' instead of calling 'binomial'.
The accumulator @(k,i)@ is updated in the same way
as the arguments of 'chooseFromIndex'.
Presumably both functions agree for valid arguments - TODO confirm by a test.
-}
chooseFromIndexList :: Integral a => a -> a -> a -> [Bool]
chooseFromIndexList n k0 i0 =
--   (\((0,0), xs) -> xs) $
   -- the final accumulator is dropped, only the flags are returned
   snd $
   mapAccumL
      (\(k,i) bins ->
          -- a zero is appended such that the index k
          -- may be one position beyond the end of the row
          let p = genericIndex (bins++[0]) k
              b = i>=p
          in  (if b
                 then (pred k, i-p)
                 else (k, i),
               b))
      (k0,i0) $
   -- process the longest of the first n rows first
   reverse $
   genericTake n binomials


{- |
Variant of 'chooseFromIndex' with a check of the index.
The condition @0 <= i && i < binomial n k@ is passed to 'toMaybe',
which presumably returns 'Nothing' if the condition is violated
- verify against "Data.Maybe.HT".
-}
chooseFromIndexMaybe :: Int -> Int -> Int -> Maybe [Bool]
chooseFromIndexMaybe n k i =
   let inRange = 0 <= i && i < binomial n k
   in  toMaybe inRange (chooseFromIndex n k i)
-- error ("chooseFromIndex: out of range " ++ show (n, k, i))


{- |
Compute a triple @(n,k,i)@ from a list of flags,
where @n@ is the length of the list and @k@ is the number of 'True's.
The index @i@ is a sum of binomial coefficients
taken from the rows of 'binomials'.
Presumably this is the inverse of 'chooseFromIndex', that is
@chooseToIndex (chooseFromIndex n k i) == (n,k,i)@ for valid arguments
- TODO confirm by a test.
-}
chooseToIndex :: Integral a => [Bool] -> (a, a, a)
chooseToIndex =
   foldl
      (\(n,k0,i0) (bins,b) ->
        -- k1 occurs in its own definition.
        -- This works because the pair is constructed lazily:
        -- k1 is available before the second component is evaluated.
        -- A zero is appended such that k1 may be
        -- one position beyond the end of the row.
        let (k1,i1) = if b then (succ k0, i0 + genericIndex (bins++[0]) k1) else (k0,i0)
        in  (succ n, k1, i1))
      (0,0,0) .
   -- the last flag is paired with the shortest row
   zip binomials .
   reverse


{-* Generate complete lists of combinatorial numbers. -}

{- |
Product of the numbers from @1@ to @n@.
The range is empty for @n@ less than @1@
and thus the result is @1@ then.
-}
factorial :: Integral a => a -> a
factorial = product . enumFromTo 1

{-|
@binomial n k@ is the binomial coefficient \"@n@ over @k@\".

The coefficient is looked up in @binomialSeq n@.
If @n < 2*k@ then the index @n-k@ is used instead of @k@,
that is the smaller of two symmetric positions.
The result is zero if the index for the lookup is negative,
which is the case for @k<0@ and for @k>n@ if @n@ is not negative.
-}
binomial :: Integral a => a -> a -> a
binomial n k =
   entry (if n<2*k then n-k else k)
   where
      entry j
         | j<0 = 0
         | otherwise = genericIndex (binomialSeq n) j

{- |
@binomialSeq n@ is the list of the binomial coefficients
\"@n@ over @k@\" for @k@ from @0@ to @n@.
Every coefficient is computed from its predecessor
by a multiplication followed by a division.
-}
binomialSeq :: Integral a => a -> [a]
binomialSeq n =
   {- this does not work because the corresponding numbers are not always divisible
    product (zipWith div [n', pred n' ..] [1..k'])
   -}
   scanl step 1 (zip [n, pred n ..] [1..n])
   where
      -- multiply first, then divide
      step acc (num, den) = div (acc*num) den


{- |
Binomial coefficient with a fractional upper argument:
@binomialGen n k@ is the element with index @k@ of @binomialSeqGen n@.
-}
binomialGen :: (Integral a, Fractional b) => b -> a -> b
binomialGen n = genericIndex (binomialSeqGen n)

{- |
Infinite list of the binomial coefficients \"@n@ over @k@\"
for @k = 0, 1, 2, ...@ with a fractional upper argument @n@.
Every coefficient is the predecessor
multiplied with @n-k+1@ and divided by @k@.
-}
binomialSeqGen :: (Fractional b) => b -> [b]
binomialSeqGen n =
   scanl step 1 (zip numerators denominators)
   where
      step acc (num, den) = acc*num / den
      numerators = iterate (subtract 1) n
      denominators = iterate (1+) 1


{- |
Multinomial coefficient computed as a product of binomial coefficients.
The binomial coefficients are formed from adjacent elements
of the list of sums of all suffixes of the argument.
-}
multinomial :: Integral a => [a] -> a
multinomial xs =
   let suffixSums = scanr1 (+) xs
   in  product (mapAdjacent binomial suffixSums)


{-* Generate complete lists of factorial numbers. -}

{- |
Infinite list of the factorials @0!, 1!, 2!, ...@
-}
factorials :: Num a => [a]
factorials =
   go 1 1
   where
      -- acc is the current factorial, n is the next factor
      go acc n = acc : go (acc*n) (n+1)

{-|
Pascal's triangle containing the binomial coefficients.
Only efficient if a prefix of all rows is required.
It is not efficient for picking particular rows
or even particular elements.

Every row is the sum of its predecessor shifted by one position
and the unshifted predecessor.
-}
binomials :: Num a => [[a]]
binomials =
   iterate nextRow [1]
   where
      nextRow row = zipWith (+) (0 : row) (row ++ [0])


{- |
@catalanNumber n@ computes the number of binary trees with @n@ nodes.

It is the binomial coefficient \"@2*n@ over @n@\" divided by @n+1@.
A non-zero remainder of that division is reported as an error.
-}
catalanNumber :: Integer -> Integer
catalanNumber n =
   case divMod (binomial (2*n) n) (n+1) of
      (c, 0) -> c
      _ -> error "catalanNumber: Integer implementation broken"

{- |
Compute the sequence of Catalan numbers by recurrence identity.
It is @catalanNumbers !! n == catalanNumber n@

The sequence starts with @1@
and continues with the result of 'PowerSeries.mul'
applied to the sequence and itself.
-}
catalanNumbers :: Num a => [a]
catalanNumbers =
   series
   where
      series = 1 : PowerSeries.mul series series



{- |
Number of fix-point-free permutations with @n@ elements,
computed as a sum of products.
The products start with the sign @(-1)^n@
and are multiplied successively with @-n, 1-n, ..., -1@.
-}
derangementNumber :: Integer -> Integer
derangementNumber n =
   let sign = if even n then 1 else -1
       terms = scanl (*) sign [-n, 1-n .. (-1)]
   in  sum terms

{- |
Number of fix-point-free permutations with @n@ elements.

<http://oeis.org/A000166>

The list is defined in terms of itself using operations on power series.
-}
derangementNumbers :: Num a => [a]
derangementNumbers =
   series
   where
      -- OEIS-A166: a(n) = n·a(n-1)+(-1)^n
      -- y(x) = 1/(1+x) + x · (t -> y(t)·t)'(x)
      series =
         PowerSeries.add
            (cycle [1,-1])
            (0 : PowerSeries.differentiate (0 : series))

{- |
Alternative to 'derangementNumbers' based on a different recurrence.
The list starts with @1@ and @0@
and is continued using operations on power series applied to the list itself.
-}
derangementNumbersAlt :: Num a => [a]
derangementNumbersAlt =
   series
   where
      -- OEIS-A166: a(n) = (n-1)·(a(n-1)+a(n-2))
      -- y(x) = 1 + x^2 · (t -> y(t)·(1+t))'(x)
      series =
         1 : 0 :
             PowerSeries.differentiate
                (PowerSeries.add series (0 : series))

{- |
Alternative to 'derangementNumbers' that refers to
'factorials' and 'binomials'.
The element with index @n@ is @n!@ minus the 'scalarProduct' of the list itself
with the row @n@ of Pascal's triangle without its last element.
Presumably 'scalarProduct' is the sum of the element-wise products
- verify against "Combinatorics.Utility".
-}
derangementNumbersInclExcl :: Num a => [a]
derangementNumbersInclExcl =
   -- xs is defined in terms of itself;
   -- 'init' removes the last coefficient of every row,
   -- presumably such that an element depends only on its predecessors
   let xs = zipWith (-) factorials (map (scalarProduct xs . init) binomials)
   in  xs


-- generation of all possibilities and computation of their number should be in different modules

{- |
Number of partitions of an @n@ element set into @k@ non-empty subsets.
Known as Stirling numbers <http://oeis.org/A048993>.

The list contains one row per @n@, starting with the row @[1]@ for @n=0@.
-}
setPartitionNumbers :: Num a => [[a]]
setPartitionNumbers =
   iterate nextRow [1]
   where
      -- s_{n+1,k} = s_{n,k-1} + k·s_{n,k}
      nextRow row =
         0 : PowerSeries.add row (PowerSeries.differentiate row)


{- |
@surjectiveMappingNumber n k@ computes the number of surjective mappings
from a @n@ element set to a @k@ element set.

<http://oeis.org/A019538>

The result is an alternating sum of the products @j^n * binomial k j@
for @j@ from @0@ to @k@, where the last product has positive sign.
-}
surjectiveMappingNumber :: Integer -> Integer -> Integer
surjectiveMappingNumber n k =
   -- every step subtracts the sum so far from the next product
   foldl (flip (-)) 0 $
   zipWith (\base coeff -> base^n * coeff) [0..] (binomialSeq k)

{- |
Table of numbers defined by a recurrence on rows:
The first row is @[1]@ and every further row is derived
from its predecessor using operations on power series.
Presumably the row @n@ contains @surjectiveMappingNumber n k@
for all @k@ - TODO confirm by a test.
-}
surjectiveMappingNumbers :: Num a => [[a]]
surjectiveMappingNumbers =
   iterate nextRow [1]
   where
      nextRow row =
         0 : PowerSeries.differentiate (PowerSeries.add row (0 : row))

{- |
Every row of 'setPartitionNumbers' multiplied element-wise
with the 'factorials'.
-}
surjectiveMappingNumbersStirling :: Num a => [[a]]
surjectiveMappingNumbersStirling =
   [ zipWith (*) factorials row | row <- setPartitionNumbers ]


{- |
Multiply two Fibonacci matrices, that is matrices of the form

> /F[n-1] F[n]  \
> \F[n]   F[n+1]/

A matrix is represented by the triple of its distinct entries.
-}
fiboMul ::
   (Integer,Integer,Integer) ->
   (Integer,Integer,Integer) ->
   (Integer,Integer,Integer)
fiboMul (f0,f1,f2) (g0,g1,g2) =
   ( f0*g0 + f1*g1
   , f0*g1 + f1*g2   -- alternative: f1*g0 + f2*g1
   , f1*g1 + f2*g2
   )


{- |
Fast computation using matrix power of

> /0 1\
> \1 1/

Hard-coded fast power with integer exponent.
Better use a generic algorithm.

The exponent is halved in every step using 'divMod'.
There are base cases for the exponents @0@ and @-1@,
such that negative arguments are handled, too.
-}
fibonacciNumber :: Integer -> Integer
fibonacciNumber x =
   middle (power x)
   where
      middle (_, y, _) = y
      power   0  = (1,0,1)
      power (-1) = (-1,1,0)
      power n =
         let (m, r) = divMod n 2
             half = power m
             square = fiboMul half half
         in  if r==0
               then square
               else fiboMul (0,1,1) square


{- |
Number of possibilities to compose a 2 x n rectangle of n bricks.

>  |||   |--   --|
>  |||   |--   --|

The list starts with @0@ and @1@
and every further element is the sum of its two predecessors.
-}
fibonacciNumbers :: [Integer]
fibonacciNumbers =
   0 : 1 : zipWith (+) fibonacciNumbers (drop 1 fibonacciNumbers)



{- * Auxiliary functions -}

{- candidates for Useful -}

{- |
Create a list of all possible rotations of the input list.

The rotations are prefixes of the suffixes of the cyclically repeated input.
Presumably @Match.take xs@ shortens a list to the length of @xs@,
such that there are as many rotations as elements
and every rotation has the length of the input
- verify against "Data.List.Match".
-}
allCycles :: [a] -> [[a]]
allCycles xs =
   let shifted = iterate tail (cycle xs)
   in  Match.take xs (map (Match.take xs) shifted)