{-# LANGUAGE TypeFamilies #-}

{-# LANGUAGE TransformListComp #-}

{-# LANGUAGE MultiParamTypeClasses #-}

{-# LANGUAGE FlexibleContexts #-}



module Control.CP.EnumTerm (

  EnumTerm(..),

  assignment, assignments,

  inOrder, firstFail, middleOut, endsOut,

  labelling, levelList, enumerate

) where



import GHC.Exts (sortWith)



import Control.CP.Solver

import Control.CP.SearchTree



-- | Terms of type @t@ whose domains can be inspected and split by solver @s@,
-- so that they can be labelled by the search combinators of this module.
--
-- 'getDomain' and 'setValue' have no default; every other method has a
-- default definition built (directly or indirectly) on top of them.
class (Solver s, Term s t, Show (TermBaseType s t)) => EnumTerm s t where
  -- | Type of the values that make up the domain of a term.
  type TermBaseType s t :: *

  -- | Number of values currently in the domain of a term.
  -- Default: the length of 'getDomain'.
  getDomainSize :: t -> s Int

  -- | The values currently in the domain of a term.
  getDomain :: t -> s [TermBaseType s t]

  -- | Constraints for the branch in which the term takes the given value.
  -- The default 'splitDomain' builds one such branch per domain value.
  setValue :: t -> TermBaseType s t -> s [Constraint s]

  -- | Split the domain of a term into alternative branches (each a list of
  -- constraints). The flag tells 'splitDomains' whether the term can be
  -- dropped from the pending terms ('True') or must be kept ('False').
  splitDomain :: t -> s ([[Constraint s]],Bool)

  -- | One branching step over a list of terms: the alternatives to try and
  -- the terms still pending afterwards.
  splitDomains :: [t] -> s ([[Constraint s]],[t])

  -- | By default, 'Just' the value of a term whose domain is a singleton and
  -- 'Nothing' otherwise.
  getValue :: t -> s (Maybe (TermBaseType s t))

  -- | Variable ordering used by 'enumerate' (default: 'firstFail').
  defaultOrder :: [t] -> s [t]

  -- | Optional custom enumeration strategy; when 'Nothing' (the default),
  -- 'enumerate' falls back to @'labelling' 'defaultOrder'@.
  enumerator :: (MonadTree m, TreeSolver m ~ s) => Maybe ([t] -> m ())

  getDomainSize x = getDomain x >>= return . length

  getValue x = getDomain x >>= return . onlyValue
    where
      onlyValue [v] = Just v
      onlyValue _   = Nothing

  -- Empty domain: no alternatives. Singleton: one empty alternative (nothing
  -- left to decide). Otherwise: one alternative per value via 'setValue'.
  splitDomain x = do
    dom <- getDomain x
    alternatives <- case dom of
      []  -> return []
      [_] -> return [[]]
      _   -> mapM (setValue x) dom
    return (alternatives, True)

  -- No terms left: a single empty alternative. Otherwise inspect the head:
  -- an empty domain yields no alternatives, a fixed term (size 1) is
  -- skipped, and any other term is split with 'splitDomain'.
  splitDomains [] = return ([[]],[])
  splitDomains terms@(x:rest) = do
    size <- getDomainSize x
    case size of
      0 -> return ([],[])
      1 -> splitDomains rest
      _ -> do
        (alternatives, consumed) <- splitDomain x
        if consumed
          then return (alternatives, rest)
          else return (alternatives, terms)

  defaultOrder = firstFail

  enumerator = Nothing



-- | Label the given terms, using the term type's custom 'enumerator' when it
-- provides one and @'labelling' 'defaultOrder'@ otherwise.
enumerate :: (MonadTree m, TreeSolver m ~ s, EnumTerm s t) => [t] -> m ()
enumerate = maybe (labelling defaultOrder) id enumerator



-- | Search-tree action yielding the value of a term whose domain is a
-- singleton (as reported by 'getValue'); it fails ('false') otherwise.
assignment :: (EnumTerm s t, MonadTree m, TreeSolver m ~ s) => t -> m (TermBaseType s t)
assignment q = label $ do
  value <- getValue q
  case value of
    Just v  -> return (return v)
    Nothing -> return false



-- | 'assignment' applied to every term, in list order.
assignments :: (EnumTerm s t, MonadTree m, TreeSolver m ~ s) => [t] -> m [TermBaseType s t]
assignments qs = mapM assignment qs



-- | First-fail ordering: sort the terms by ascending current domain size.
-- Terms with equal sizes keep their input order ('sortWith' is built on the
-- stable 'Data.List.sortBy').
firstFail :: EnumTerm s t => [t] -> s [t]
firstFail qs = do
  sizes <- mapM getDomainSize qs
  return . map snd . sortWith fst $ zip sizes qs



-- | Identity ordering: label the terms in the order they are given.
inOrder :: EnumTerm s t => [t] -> s [t]
inOrder terms = return terms



-- | Start in the middle and alternate outwards: the second half (in order)
-- is interleaved with the reversed first half, e.g. @[1,2,3,4,5]@ becomes
-- @[3,2,4,1,5]@.
middleOut :: EnumTerm s t => [t] -> s [t]
middleOut l = return (interleave back (reverse front))
  where
    (front, back) = splitAt (length l `div` 2) l



-- | Start at both ends and move inwards: the reversed second half is
-- interleaved with the first half (in order), e.g. @[1,2,3,4,5]@ becomes
-- @[5,1,4,2,3]@.
endsOut :: EnumTerm s t => [t] -> s [t]
endsOut l = return (interleave (reverse back) front)
  where
    (front, back) = splitAt (length l `div` 2) l



-- | Alternate elements of two lists, starting with the first list; once one
-- list runs out, the rest of the other is appended unchanged.
--
-- > interleave [1,3,5] [2,4] == [1,2,3,4,5]
interleave :: [a] -> [a] -> [a]
interleave []     ys = ys
-- swap the arguments so the next element comes from the other list
interleave (x:xs) ys = x : interleave ys xs



-- | Combine alternative search trees into a balanced binary tree of '\/'
-- disjunctions by repeatedly halving the list; no alternatives is 'false'.
levelList :: (Solver s, MonadTree m, TreeSolver m ~ s) => [m ()] -> m ()
levelList alternatives =
  case alternatives of
    []  -> false
    [a] -> a
    _   -> levelList front \/ levelList back
  where
    (front, back) = splitAt (length alternatives `div` 2) alternatives

--levelList [] = false

--levelList [a] = a

--levelList (a:b) = a \/ levelList b



-- | Build a search tree that labels the given terms.
--
-- The ordering function is applied once, at the top; each step asks
-- 'splitDomains' for the alternatives of the next branching decision,
-- turns every alternative into a conjunction of 'addC' constraints, combines
-- them with 'levelList', and then recurses on the remaining terms keeping
-- their order ('return' as the ordering function).
labelling :: (MonadTree m, TreeSolver m ~ s, EnumTerm s t) => ([t] -> s [t]) -> [t] -> m ()
labelling _ [] = true
labelling order terms = label $ do
  ordered <- order terms
  (alternatives, remaining) <- splitDomains ordered
  -- each alternative is a list of constraints; post all of them together
  let branches = map (foldr (/\) true . map addC) alternatives
  return $ do
    levelList branches
    labelling return remaining