module Game.Mastermind.CodeSet.Union (
   T, member, size,
   fromLists, cube,
   overlappingPairs, overlapping,
   ) where

import qualified Game.Mastermind.CodeSet as CodeSet
import Game.Utility (nonEmptySetToList, )

import qualified Data.NonEmpty.Set as NonEmptySet
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Set as Set
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.Maybe (mapMaybe, )

import Control.Monad (liftM2, guard, )


{- |
@Cons [[a,b,c,d], [e,f,g,h]]@
expresses  a x b x c x d  union  e x f x g x h,
where @x@ denotes the set product.
-}
newtype T a = Cons [[NonEmptySet.T a]]

instance (Ord a, Show a) => Show (T a) where
   showsPrec n cs =
      showParen (n>=10) $
      showString "CodeSet.fromLists " . shows (toLists cs)

instance CodeSet.C T where
   empty = Cons []
   union = union
   intersection = intersection
   unit = Cons [[]]
   leftNonEmptyProduct c (Cons xs) = Cons (map (c:) xs)
   flatten = flatten
   symbols = symbols
   null (Cons xs) = null xs
   size = size
   select = select
   representationSize = representationSize
   compress = id


toLists :: (Ord a) => T a -> [[[a]]]
toLists (Cons xs) = map (map nonEmptySetToList) xs

fromLists :: (Ord a) => [[NonEmpty.T [] a]] -> T a
fromLists = Cons . map (map NonEmptySet.fromList)

flatten :: (Ord a) => T a -> [[a]]
flatten = concatMap sequence . toLists

symbols :: (Ord a) => T a -> Set.Set a
symbols = Set.unions . map Set.unions . flattenFactors

cube :: Int -> NonEmptySet.T a -> T a
cube n alphabet = Cons [replicate n alphabet]


size :: T a -> Integer
size = sum . productSizes

productSizes :: T a -> [Integer]
productSizes (Cons x) =
   map (product . map (fromIntegral . NonEmptySet.size)) x

select :: T a -> Integer -> [a]
select set@(Cons xs) n0 =
   let sizes = productSizes set
   in  if n0<0
         then error "CodeSet.select: negative index"
         else
           case dropWhile (\(n1,sz,_) -> n1>=sz) $
                zip3 (scanl (-) n0 sizes) sizes xs of
             [] -> error "CodeSet.select: index too large"
             (n1,_,prod) : _ ->
                (\(n3,cs) ->
                   if n3==0
                     then cs
                     else error "CodeSet.select: at the end index must be zero") $
                List.mapAccumR
                   (\n2 componentSet ->
                      let (n3,i) =
                              divMod n2
                                 (fromIntegral $ NonEmptySet.size componentSet)
                      in  (n3,
                           nonEmptySetToList componentSet !! fromInteger i))
                   n1 prod

representationSize :: T a -> Int
representationSize (Cons x) =
   sum . map (sum . map NonEmptySet.size) $ x


{- |
We could try to merge set products.
I'll first want to see, whether this is needed in a relevant number of cases.
-}
union :: T a -> T a -> T a
union (Cons x) (Cons y) = Cons (x++y)

intersection :: (Ord a) => T a -> T a -> T a
intersection x y =
   normalize $
   liftM2 (zipWith Set.intersection) (flattenFactors x) (flattenFactors y)

member :: (Ord a) => [a] -> T a -> Bool
member code (Cons xs) =
   any (and . zipWith NonEmptySet.member code) xs

{- |
Remove empty set products.
-}
normalize :: (Ord a) => [[Set.Set a]] -> T a
normalize = Cons . mapMaybe (mapM NonEmptySet.fetch)

flattenFactors :: (Ord a) => T a -> [[Set.Set a]]
flattenFactors (Cons xs) = map (map NonEmptySet.flatten) xs


disjointProduct :: (Ord a) => [Set.Set a] -> [Set.Set a] -> Bool
disjointProduct prod0 prod1 =
   any Set.null $ zipWith Set.intersection prod0 prod1

{- |
for debugging: list all pairs of products, that overlap
-}
overlappingPairs :: (Ord a) => T a -> [([Set.Set a], [Set.Set a])]
overlappingPairs set = do
   prod0:rest <- ListHT.tails $ flattenFactors set
   prod1 <- rest
   guard $ not $ disjointProduct prod0 prod1
   return (prod0, prod1)

{- |
for debugging: list all subsets, that are contained in more than one product
-}
overlapping :: (Ord a) => T a -> [([Set.Set a], [[Set.Set a]])]
overlapping set = do
   let xs = flattenFactors set
   subset <- Set.toList $ Set.fromList $ do
      prod0:rest <- ListHT.tails xs
      prod1 <- rest
      let sec = zipWith Set.intersection prod0 prod1
      guard $ all (not . Set.null) $ sec
      return sec
   return (subset, filter (not . disjointProduct subset) xs)