{-# LANGUAGE BangPatterns, FlexibleContexts, RecordWildCards #-}

module Control.Monad.LPMonad.Internal (
-- 	module Data.LinearProgram.Common,
	-- * Monad definitions
	LPM,
	LPT,
	runLPM,
	runLPT,
	execLPM,
	execLPT,
	evalLPM,
	evalLPT,
	-- * Constructing the LP
	-- ** Objective configuration
	setDirection,
	setObjective,
	addObjective,
	addWeightedObjective,
	-- ** Two-function constraints
	leq,
	equal,
	geq,
	leq',
	equal',
	geq',
	-- ** One-function constraints
	leqTo,
	equalTo,
	geqTo,
	constrain,
	leqTo',
	equalTo',
	geqTo',
	constrain',
	-- ** Variable constraints
	varLeq,
	varEq,
	varGeq,
	varBds,
	setVarBounds,
	setVarKind,
-- 	newVariables,
-- 	newVariables'
	) where

import Control.Monad.State.Strict
import Control.Monad.Identity

import Data.Map
import Data.Monoid
import Data.Functor

import Data.LinearProgram.Common

-- | A simple monad for constructing linear programs.  This library is intended to be able to link to 
-- a variety of different linear programming implementations.
type LPM v c = LPT v c Identity

-- | A simple monad transformer for constructing linear programs in an arbitrary monad.
type LPT v c = StateT (LP v c)

runLPM :: (Ord v, Group c) => LPM v c a -> (a, LP v c)
runLPM = runIdentity . runLPT

runLPT :: (Ord v, Group c) => LPT v c m a -> m (a, LP v c)
runLPT m = runStateT m (LP Max zero [] mempty mempty)

-- | Constructs a linear programming problem.
execLPM :: (Ord v, Group c) => LPM v c a -> LP v c
execLPM = runIdentity . execLPT

-- | Constructs a linear programming problem in the specified monad.
execLPT :: (Ord v, Group c, Monad m) => LPT v c m a -> m (LP v c)
execLPT = liftM snd . runLPT

-- | Runs the specified operation in the linear programming monad.
evalLPM :: (Ord v, Group c) => LPM v c a -> a
evalLPM = runIdentity . evalLPT

-- | Runs the specified operation in the linear programming monad transformer.
evalLPT :: (Ord v, Group c, Monad m) => LPT v c m a -> m a
evalLPT = liftM fst . runLPT

-- | Sets the optimization direction of the linear program: maximization or minimization.
{-# SPECIALIZE setDirection :: Direction -> LPM v c (), Monad m => Direction -> LPT v c m () #-}
setDirection :: (MonadState (LP v c) m) => Direction -> m ()
setDirection dir = modify (\ lp -> lp{direction = dir})

{-# SPECIALIZE equal :: (Ord v, Group c) => LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => LinFunc v c -> LinFunc v c -> LPT v c m () #-}
{-# SPECIALIZE leq :: (Ord v, Group c) => LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => LinFunc v c -> LinFunc v c -> LPT v c m () #-}
{-# SPECIALIZE geq :: (Ord v, Group c) => LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => LinFunc v c -> LinFunc v c -> LPT v c m () #-}
-- | Specifies the relationship between two functions in the variables.  So, for example,
-- 
-- > equal (f ^+^ g) h
-- 
-- constrains the value of @h@ to be equal to the value of @f@ plus the value of @g@.
equal, leq, geq :: (Ord v, Group c, MonadState (LP v c) m) => LinFunc v c -> LinFunc v c -> m ()
equal f g = equalTo (f ^-^ g) zero
leq f g = leqTo (f ^-^ g) zero
geq = flip leq

{-# SPECIALIZE equal' :: (Ord v, Group c) => String -> LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => String -> LinFunc v c -> LinFunc v c -> LPT v c m () #-}
{-# SPECIALIZE geq' :: (Ord v, Group c) => String -> LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => String -> LinFunc v c -> LinFunc v c -> LPT v c m () #-}
{-# SPECIALIZE leq' :: (Ord v, Group c) => String -> LinFunc v c -> LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => String -> LinFunc v c -> LinFunc v c -> LPT v c m () #-}
-- | Specifies the relationship between two functions in the variables, with a label on the constraint.
equal', leq', geq' :: (Ord v, Group c, MonadState (LP v c) m) => String -> LinFunc v c -> LinFunc v c -> m ()
equal' lab f g = equalTo' lab (f ^-^ g) zero
leq' lab f g = leqTo' lab (f ^-^ g) zero
geq' = flip . leq'

{-# SPECIALIZE equalTo :: LinFunc v c -> c -> LPM v c (), Monad m => LinFunc v c -> c -> LPT v c m () #-}
{-# SPECIALIZE geqTo :: LinFunc v c -> c -> LPM v c (), Monad m => LinFunc v c -> c -> LPT v c m () #-}
{-# SPECIALIZE leqTo :: LinFunc v c -> c -> LPM v c (), Monad m => LinFunc v c -> c -> LPT v c m () #-}
-- | Sets a constraint on a linear function in the variables.
equalTo, leqTo, geqTo :: MonadState (LP v c) m => LinFunc v c -> c -> m ()
equalTo f v = constrain f (Equ v)
leqTo f v = constrain f (UBound v)
geqTo f v = constrain f (LBound v)

{-# SPECIALIZE equalTo' :: String -> LinFunc v c -> c -> LPM v c (),
	Monad m => String -> LinFunc v c -> c -> LPT v c m () #-}
{-# SPECIALIZE geqTo' :: String -> LinFunc v c -> c -> LPM v c (),
	Monad m => String -> LinFunc v c -> c -> LPT v c m () #-}
{-# SPECIALIZE leqTo' :: String -> LinFunc v c -> c -> LPM v c (),
	Monad m => String -> LinFunc v c -> c -> LPT v c m () #-}
-- | Sets a labeled constraint on a linear function in the variables.
equalTo', leqTo', geqTo' :: MonadState (LP v c) m => String -> LinFunc v c -> c -> m ()
equalTo' lab f v = constrain' lab f (Equ v)
leqTo' lab f v = constrain' lab f (UBound v)
geqTo' lab f v = constrain' lab f (LBound v)

-- {-# SPECIALIZE newVariables :: (Ord v, Enum v) => Int -> LPM v c [v],
-- 	(Ord v, Enum v, Monad m) => Int -> LPT v c m [v] #-}
-- -- | Returns a list of @k@ unused variables.  If the program is currently empty,
-- -- starts at @'toEnum' 0@.  Otherwise, if @v@ is the biggest variable currently in use
-- -- (by the 'Ord' ordering), then this returns @take k (tail [v..])@, which uses the 'Enum'
-- -- implementation.  Note that if the 'Enum' instance doesn't play well with 'Ord',
-- -- bad things can happen.
-- newVariables :: (MonadState (LP v c) m, Ord v, Enum v) => Int -> m [v]
-- newVariables !k = do	LP{..} <- get
-- 			let allVars0 = () <$ objective `union`
-- 				unions [() <$ f | Constr _ f _ <- constraints] `union`
-- 				(() <$ varBounds) `union` (() <$ varTypes)
-- 			case minViewWithKey allVars0 of
-- 				Nothing	-> return $ take k [toEnum 0..]
-- 				Just ((start, _), _)
-- 					-> return $ take k $ tail [start..]
-- 					
-- {-# SPECIALIZE newVariables' :: (Ord v, Enum v) => LPM v c [v],
-- 	(Ord v, Enum v, Monad m) => LPT v c m [v] #-}
-- -- | Returns an infinite list of unused variables.  If the program is currently empty,
-- -- starts at @'toEnum' 0@.  Otherwise, if @v@ is the biggest variable currently in use
-- -- (by the 'Ord' ordering), then this returns @tail [v..]@, which uses the 'Enum'
-- -- implementation.  Note that if the 'Enum' instance doesn't play well with 'Ord',
-- -- bad things can happen.
-- newVariables' :: (MonadState (LP v c) m, Ord v, Enum v) => m [v]
-- newVariables' = do	LP{..} <- get
-- 			let allVars0 = () <$ objective `union`
-- 				unions [() <$ f | Constr _ f _ <- constraints] `union`
-- 				(() <$ varBounds) `union` (() <$ varTypes)
-- 			case minViewWithKey allVars0 of
-- 				Nothing	-> return [toEnum 0..]
-- 				Just ((start, _), _)
-- 					-> return $ tail [start..]

{-# SPECIALIZE varEq :: (Ord v, Ord c) => v -> c -> LPM v c (),
	(Ord v, Ord c, Monad m) => v -> c -> LPT v c m () #-}
{-# SPECIALIZE varLeq :: (Ord v, Ord c) => v -> c -> LPM v c (),
	(Ord v, Ord c, Monad m) => v -> c -> LPT v c m () #-}
{-# SPECIALIZE varGeq :: (Ord v, Ord c) => v -> c -> LPM v c (),
	(Ord v, Ord c, Monad m) => v -> c -> LPT v c m () #-}
-- | Sets a constraint on the value of a variable.  If you constrain a variable more than once,
-- the constraints will be combined.  If the constraints are mutually contradictory,
-- an error will be generated.  This is more efficient than adding an equivalent function constraint.
varEq, varLeq, varGeq :: (Ord v, Ord c, MonadState (LP v c) m) => v -> c -> m ()
varEq v c = setVarBounds v (Equ c)
varLeq v c = setVarBounds v (UBound c)
varGeq v c = setVarBounds v (LBound c)

{-# SPECIALIZE varBds :: (Ord v, Ord c) => v -> c -> c -> LPM v c (),
	(Ord v, Ord c, Monad m) => v -> c -> c -> LPT v c m () #-}
-- | Bounds the value of a variable on both sides.  If you constrain a variable more than once,
-- the constraints will be combined.  If the constraints are mutually contradictory,
-- an error will be generated.  This is more efficient than adding an equivalent function constraint.
varBds :: (Ord v, Ord c, MonadState (LP v c) m) => v -> c -> c -> m ()
varBds v l u = setVarBounds v (Bound l u)

{-# SPECIALIZE constrain :: LinFunc v c -> Bounds c -> LPM v c (),
	Monad m => LinFunc v c -> Bounds c -> LPT v c m () #-}
-- | The most general form of an unlabeled constraint.
constrain :: MonadState (LP v c) m => LinFunc v c -> Bounds c -> m ()
constrain f bds = modify addConstr where
	addConstr lp@LP{..}
		= lp{constraints = Constr Nothing f bds:constraints}

{-# SPECIALIZE constrain' :: String -> LinFunc v c -> Bounds c -> LPM v c (),
	Monad m => String -> LinFunc v c -> Bounds c -> LPT v c m () #-}
-- | The most general form of a labeled constraint.
constrain' :: MonadState (LP v c) m => String -> LinFunc v c -> Bounds c -> m ()
constrain' lab f bds = modify addConstr where
	addConstr lp@LP{..}
		= lp{constraints = Constr (Just lab) f bds:constraints}

{-# SPECIALIZE setObjective :: LinFunc v c -> LPM v c (),
	Monad m => LinFunc v c -> LPT v c m () #-}
-- | Sets the objective function, overwriting the previous objective function.
setObjective :: MonadState (LP v c) m => LinFunc v c -> m ()
setObjective obj = modify setObj where
	setObj lp = lp{objective = obj}

{-# SPECIALIZE addObjective :: (Ord v, Group c) => LinFunc v c -> LPM v c (),
	(Ord v, Group c, Monad m) => LinFunc v c -> LPT v c m () #-}
-- | Adds this function to the objective function.
addObjective :: (Ord v, Group c, MonadState (LP v c) m) => LinFunc v c -> m ()
addObjective obj = modify addObj where
	addObj lp@LP{..} = lp {objective = obj ^+^ objective}

{-# SPECIALIZE addWeightedObjective :: (Ord v, Module r c) => r -> LinFunc v c -> LPM v c (),
	(Ord v, Module r c, Monad m) => r -> LinFunc v c -> LPT v c m () #-}
-- | Adds this function to the objective function, with the specified weight.  Equivalent to
-- @'addObjective' (wt '*^' obj)@.
addWeightedObjective :: (Ord v, Module r c, MonadState (LP v c) m) => r -> LinFunc v c -> m ()
addWeightedObjective wt obj = addObjective (wt *^ obj)

{-# SPECIALIZE setVarBounds :: (Ord v, Ord c) => v -> Bounds c -> LPM v c (),
	(Ord v, Ord c, Monad m) => v -> Bounds c -> LPT v c m () #-}
-- | The most general way to set constraints on a variable.
-- If you constrain a variable more than once, the constraints will be combined.
-- If you combine mutually contradictory constraints, an error will be generated.
-- This is more efficient than creating an equivalent function constraint.
setVarBounds :: (Ord v, Ord c, MonadState (LP v c) m) => v -> Bounds c -> m ()
setVarBounds var bds = modify addBds where
	addBds lp@LP{..} = lp{varBounds = insertWith mappend var bds varBounds}

{-# SPECIALIZE setVarKind :: Ord v => v -> VarKind -> LPM v c (),
	(Ord v, Monad m) => v -> VarKind -> LPT v c m () #-}
-- | Sets the kind ('type') of a variable.  See 'VarKind'.
setVarKind :: (Ord v, MonadState (LP v c) m) => v -> VarKind -> m ()
setVarKind v k = modify setK where
	setK lp@LP{..} = lp{varTypes = insertWith mappend v k varTypes}