{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
The following LP problem

maximize @4 x_1 - 3 x_2 + 2 x_3@ subject to

@2 x_1 + x_2 <= 10@

@x_2 + 5 x_3 <= 20@

and

@x_i >= 0@

is used as an example in the doctest comments.


By default all indeterminates are non-negative.
A given bound for a variable completely replaces the default,
so @0 <= x_i <= b@ must be explicitly given as @i >=<. (0,b)@.
Multiple bounds for a variable are not allowed;
instead of @[i >=. a, i <=. b]@ use @i >=<. (a,b)@.
-}
module Numeric.GLPK (
   Term(..),
   Bound(..),
   Inequality(..),
   free, (<=.), (>=.), (==.), (>=<.),
   NoSolutionType(..),
   SolutionType(..),
   Solution,
   Constraints,
   Direction(..),
   Objective,
   Bounds,
   (.*),
   objectiveFromTerms,
   simplex,
   simplexMulti,
   simplexSuccessive,
   exact,
   exactMulti,
   exactSuccessive,
   interior,
   interiorMulti,
   interiorSuccessive,

   solveSuccessive,

   FormatIdentifier,
   formatMathProg,
   ) where

import qualified Math.Programming.Glpk.Header as FFI
import Numeric.GLPK.Private

import qualified Data.Array.Comfort.Storable.Mutable as Mutable
import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.NonEmpty as NonEmpty
import qualified Data.List as List
import Data.Array.Comfort.Storable (Array)
import Data.Tuple.HT (mapFst, mapSnd)
import Data.Traversable (for)
import Data.Foldable (for_)

import Text.Printf (printf)

import qualified Control.Monad.Trans.Except as ME
import qualified Control.Monad.Trans.State as MS
import Control.Monad (void, when)
import Control.Applicative (liftA2)
import Control.Exception (bracket)

import System.IO.Unsafe (unsafePerformIO)

import qualified Foreign
import Foreign.Ptr (nullPtr)


{- $setup
>>> import qualified Test.Numeric.GLPK.Generator as TestLP
>>> import qualified Numeric.GLPK as LP
>>> import Numeric.GLPK ((.*), (<=.), (==.))
>>>
>>> import qualified Data.Array.Comfort.Storable as Array
>>> import qualified Data.Array.Comfort.Shape as Shape
>>>
>>> import Data.Tuple.HT (mapPair, mapSnd)
>>>
>>> import qualified Test.QuickCheck as QC
>>> import Test.QuickCheck ((===), (.&&.), (.||.))
>>>
>>> type X = Shape.Element
>>> type PairShape = Shape.NestedTuple Shape.TupleIndex (X,X)
>>> type TripletShape = Shape.NestedTuple Shape.TupleIndex (X,X,X)
>>>
>>> pairShape :: PairShape
>>> pairShape = Shape.static
>>>
>>> tripletShape :: TripletShape
>>> tripletShape = Shape.static
>>>
>>> round3 :: Double -> Double
>>> round3 x = fromInteger (round (1000*x)) / 1000
-}


infix 7 .*

{- |
Attach a coefficient to a variable index,
yielding one 'Term' of a linear combination, e.g. @2 .* x1@.
-}
(.*) :: Double -> ix -> Term ix
coeff .* ix = Term coeff ix


infix 4 <=., >=., >=<., ==.

{- |
Relate a left-hand side
(a variable index in 'Bounds', a list of 'Term's in 'Constraints')
to a single limit: at most, at least, or exactly @bnd@.
-}
(<=.), (>=.), (==.) :: x -> Double -> Inequality x
(<=.) x = Inequality x . LessEqual
(>=.) x = Inequality x . GreaterEqual
(==.) x = Inequality x . Equal

{- |
Restrict a left-hand side to the range given as @(lower, upper)@.
-}
(>=<.) :: x -> (Double,Double) -> Inequality x
x >=<. bnd = Inequality x (Between lower upper)
   where
      -- a @where@ pattern binding is lazy, so the pair is only inspected
      -- when one of the limits is actually demanded
      (lower, upper) = bnd

{- |
Declare @x@ as not limited in either direction.
Given as a bound for a variable,
this replaces the default non-negativity.
-}
free :: x -> Inequality x
free = flip Inequality Free



{- |
Turn a sparse list of 'Term's into an objective array of shape @sh@.
The array is built with 'Array.fromAssociations' using 0 as initial element,
so indices without a term are meant to end up with coefficient 0.
-}
objectiveFromTerms ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) => sh -> [Term ix] -> Objective sh
objectiveFromTerms sh terms =
   Array.fromAssociations 0 sh [(ix, coeff) | Term coeff ix <- terms]


{- |
>>> case Shape.indexTupleFromShape tripletShape of (x1,x2,x3) -> mapSnd (mapSnd Array.toTuple) <$> LP.simplex [] [[2.*x1, 1.*x2] <=. 10, [1.*x2, 5.*x3] <=. 20] (LP.Maximize, Array.fromTuple (4,-3,2) :: Array.Array TripletShape Double)
Right (Optimal,(28.0,(5.0,0.0,4.0)))

prop> \target -> case Shape.indexTupleFromShape pairShape of (pos,neg) -> case mapSnd (mapSnd Array.toTuple) <$> LP.simplex [] [[1.*pos, (-1).*neg] ==. target] (LP.Minimize, Array.fromTuple (1,1) :: Array.Array PairShape Double) of (Right (LP.Optimal,(absol,(posResult,negResult)))) -> QC.property (TestLP.approxReal 0.001 absol (abs target)) .&&. (posResult === 0 .||. negResult === 0); _ -> QC.property False
prop> \(QC.Positive posWeight) (QC.Positive negWeight) target -> case Shape.indexTupleFromShape pairShape of (pos,neg) -> case mapSnd (mapSnd Array.toTuple) <$> LP.simplex [] [[1.*pos, (-1).*neg] ==. target] (LP.Minimize, Array.fromTuple (posWeight,negWeight) :: Array.Array PairShape Double) of (Right (LP.Optimal,(absol,(posResult,negResult)))) -> QC.property (absol>=0) .&&. (posResult === 0 .||. negResult === 0); _ -> QC.property False
prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAll (TestLP.genProblem origin) $ \(bounds, constrs) -> QC.forAll (TestLP.genObjective origin) $ \(dir,obj) -> case LP.simplex bounds constrs (dir,obj) of Right (LP.Optimal, _) -> True; _ -> False
-}
simplex ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> Solution sh
simplex =
   -- 'nullPtr' is passed as the second argument of @glp_simplex@
   -- (presumably selecting GLPK's default control parameters).
   solve $ \lp -> FFI.glp_simplex lp nullPtr

{-# DEPRECATED simplexMulti "use GLPK.Monad instead" #-}
{-# DEPRECATED exactMulti "use GLPK.Monad instead" #-}
{-# DEPRECATED interiorMulti "run 'interior' in Either monad instead" #-}
{- |
Optimize for one objective after another.
That is, if the first optimization succeeds
then the optimum is fixed as constraint
and the optimization is continued with respect to the second objective and so on.
The iteration fails if one optimization fails.
The obtained objective values are returned as well.
Their number equals the number of attempted optimizations.

The last objective value is included in the Solution value.
This is a bit inconsistent,
but this way you have a guarantee that there is an objective value
if the optimization is successful.

The objectives are expected as 'Term's
because after a successful optimization step
they are used as (sparse) constraints.
It's also easy to assert that the same array shape
is used for all objectives.

The function does not work reliably,
because an added objective can make the system infeasible
due to rounding errors.
E.g. a non-negative objective can become very small but negative.


prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAllShrink (TestLP.genProblem origin) TestLP.shrinkProblem $ \(bounds, constrs) -> QC.forAllShrink (TestLP.genObjectives origin) TestLP.shrinkObjectives $ \objs -> case LP.simplexMulti bounds constrs (Array.shape origin) objs of (_, Right (LP.Optimal, _)) -> QC.property True; result -> QC.counterexample (show result) False

The same property fails for 'exactMulti' and 'interiorMulti'.
I guess, due to rounding errors.
-}
simplexMulti, exactMulti, interiorMulti ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   sh -> NonEmpty.T [] (Direction, [Term ix]) -> ([Double], Solution sh)
-- Each variant fixes the bounds in the respective single-objective solver
-- and hands the remaining function to 'solveMulti'.
simplexMulti bounds = solveMulti (simplex bounds)
exactMulti bounds = solveMulti (exact bounds)
interiorMulti bounds = solveMulti (interior bounds)

{- |
Shared driver of the @Multi@ functions.

The first objective is solved under the initial constraints.
As long as the latest result is 'Optimal' and further objectives remain,
the objective just solved is added as an equality constraint
pinned to its optimum and the next objective is solved.
The optima of all steps but the last attempted one are collected in the list;
the result of the last attempted step is returned as the 'Solution'.
-}
solveMulti ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   (Constraints ix -> (Direction, Objective sh) -> Solution sh) ->
   Constraints ix ->
   sh -> NonEmpty.T [] (Direction, [Term ix]) -> ([Double], Solution sh)
solveMulti solver constrs0 sh (NonEmpty.Cons ~(dir0,terms0) remaining0) =
   continue constrs0 terms0 remaining0 $
      solver constrs0 (dir0, objectiveFromTerms sh terms0)
   where
      -- The list of pending objectives is inspected before the outcome,
      -- so the final outcome is passed on without being forced.
      continue constrs prevTerms ((dir,terms):remaining)
            (Right (Optimal, (opt,_))) =
         let pinned = (prevTerms ==. opt) : constrs
             -- lazy pattern binding: earlier optima can be consumed
             -- before later steps have been computed
             (opts, final) =
               continue pinned terms remaining $
               solver pinned (dir, objectiveFromTerms sh terms)
         in (opt : opts, final)
      continue _ _ _ outcome = ([], outcome)


{-# DEPRECATED simplexSuccessive "use GLPK.Monad instead" #-}
{-# DEPRECATED exactSuccessive "use GLPK.Monad instead" #-}
{-# DEPRECATED interiorSuccessive "run 'interior' in Either monad instead" #-}
{- |
Like the @Multi@ functions,
but instead of only fixing the previously
found optimum as an exact constraint,
it also allows constraints with a tolerance.
This is necessary in the presence of rounding errors.

prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAllShrink (TestLP.genProblem origin) TestLP.shrinkProblem $ \(bounds, constrs) -> QC.forAllShrink (TestLP.genObjectives origin) TestLP.shrinkObjectives $ \objs -> case uncurry (LP.simplexSuccessive bounds constrs) $ TestLP.successiveObjectives origin 0.01 objs of result -> QC.counterexample (show result) $ case result of Right results -> all (\r -> case r of (LP.Optimal, _) -> True; _ -> False) results; _ -> False
prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAllShrink (TestLP.genProblem origin) TestLP.shrinkProblem $ \(bounds, constrs) -> QC.forAllShrink (TestLP.genObjectives origin) TestLP.shrinkObjectives $ \objs -> case uncurry (LP.exactSuccessive bounds constrs) $ TestLP.successiveObjectives origin 0.01 objs of result -> QC.counterexample (show result) $ case result of Right results -> all (\r -> case r of (LP.Optimal, _) -> True; _ -> False) results; _ -> False
-}
simplexSuccessive, exactSuccessive, interiorSuccessive ::
   (Traversable f, Eq sh, Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) ->
   f ((SolutionType, (Double, Array sh Double)) -> Constraints ix,
      (Direction, Objective sh)) ->
   Either NoSolutionType
      (NonEmpty.T f (SolutionType, (Double, Array sh Double)))
-- The simplex based variants keep working on one GLPK problem object,
-- the interior point variant solves every step from scratch.
simplexSuccessive =
   solveSuccessiveInPlace $ \lp -> FFI.glp_simplex lp nullPtr
exactSuccessive =
   solveSuccessiveInPlace $ \lp -> FFI.glp_exact lp nullPtr
interiorSuccessive bounds = solveSuccessive (interior bounds)

{-# DEPRECATED solveSuccessive
      "run simple solvers in GLPK.Monad or Either monad instead" #-}
{- |
Allows for a generic implementation of 'simplexSuccessive' et al.
without reuse of interim results.

prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAll (TestLP.genProblem origin) $ \(bounds, constrs) -> QC.forAll (TestLP.genObjectives origin) $ (. TestLP.successiveObjectives origin 0.01) $ \(obj,objs) -> case (LP.simplexSuccessive bounds constrs obj objs, LP.solveSuccessive (LP.simplex bounds) constrs obj objs) of (resultA,resultB) -> TestLP.approxSuccession 0.01 resultA resultB
prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAll (TestLP.genProblem origin) $ \(bounds, constrs) -> QC.forAll (TestLP.genObjectives origin) $ (. TestLP.successiveObjectives origin 0.01) $ \(obj,objs) -> case (LP.exactSuccessive bounds constrs obj objs, LP.solveSuccessive (LP.exact bounds) constrs obj objs) of (resultA,resultB) -> TestLP.approxSuccession 0.01 resultA resultB
-}
solveSuccessive ::
   (Traversable f, Eq sh, Shape.Indexed sh, Shape.Index sh ~ ix) =>
   (Constraints ix -> (Direction, Objective sh) -> Solution sh) ->
   Constraints ix ->
   (Direction, Objective sh) ->
   f ((SolutionType, (Double, Array sh Double)) -> Constraints ix,
      (Direction, Objective sh)) ->
   Either NoSolutionType
      (NonEmpty.T f (SolutionType, (Double, Array sh Double)))
solveSuccessive solver constrs0 obj0 objs =
   solver constrs0 obj0 >>= \sol0 ->
      NonEmpty.cons sol0 <$>
         MS.evalStateT (for objs step) (constrs0, sol0)
   where
      refShape = Array.shape (snd obj0)

      -- The shape test is part of the objective value itself,
      -- thus it only fires when the solver inspects the objective.
      verified obj
         | refShape == Array.shape obj = obj
         | otherwise =
            error "GLPK.solveSuccessive: objective shapes mismatch"

      -- The state carries the constraints accumulated so far
      -- together with the most recent solution.
      step (makeConstrs, (dir,obj)) =
         MS.StateT $ \(constrs, prevSol) ->
            let extended = makeConstrs prevSol ++ constrs
            in (\sol -> (sol, (extended, sol))) <$>
                  solver extended (dir, verified obj)


{- |
>>> case Shape.indexTupleFromShape tripletShape of (x1,x2,x3) -> mapSnd (mapSnd Array.toTuple) <$> LP.exact [] [[2.*x1, 1.*x2] <=. 10, [1.*x2, 5.*x3] <=. 20] (LP.Maximize, Array.fromTuple (4,-3,2) :: Array.Array TripletShape Double)
Right (Optimal,(28.0,(5.0,0.0,4.0)))

prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAll (TestLP.genProblem origin) $ \(bounds, constrs) -> QC.forAll (TestLP.genObjective origin) $ \(dir,obj) -> case (LP.simplex bounds constrs (dir,obj), LP.exact bounds constrs (dir,obj)) of (Right (LP.Optimal, (optSimplex,_)), Right (LP.Optimal, (optExact,_))) -> TestLP.approx "optimum" 0.001 optSimplex optExact; _ -> QC.property False
-}
exact ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> Solution sh
exact =
   -- 'nullPtr' is passed as the second argument of @glp_exact@
   -- (presumably selecting GLPK's default control parameters).
   solve $ \lp -> FFI.glp_exact lp nullPtr


{-# INLINE solve #-}
{- |
Common implementation of 'simplex' and 'exact'.
A fresh GLPK problem object is created, filled by 'storeProblem',
processed by the given solver routine and read out by 'peekSimplexSolution'.
'bracket' makes sure that the problem object is deleted afterwards.
The return status of the solver routine is discarded.
-}
solve ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   (Foreign.Ptr FFI.Problem -> IO FFI.GlpkSimplexStatus) ->
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> Solution sh
solve solver bounds constrs problem@(_dir,obj) =
   unsafePerformIO $
   bracket FFI.glp_create_prob FFI.glp_delete_prob $ \lp -> do
      storeProblem bounds constrs problem lp
      _status <- solver lp
      peekSimplexSolution (Array.shape obj) lp

{-# INLINE solveSuccessiveInPlace #-}
{- |
Solve a sequence of objectives on one and the same GLPK problem object.

The initial problem is stored and solved first.
For every further element of @objs@ the direction and the objective
are replaced, the constraints computed from the previous solution
are appended as new rows and the solver routine is run again.
The first step that yields a 'NoSolutionType' ends the whole computation.

All objectives must have the same shape as the first one,
otherwise an 'error' is raised.

A step that contributes no new constraints adds no rows:
@glp_add_rows@ must not be called with a row count below 1.
-}
solveSuccessiveInPlace ::
   (Traversable f, Eq sh, Shape.Indexed sh, Shape.Index sh ~ ix) =>
   (Foreign.Ptr FFI.Problem -> IO FFI.GlpkSimplexStatus) ->
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) ->
   f ((SolutionType, (Double, Array sh Double)) -> Constraints ix,
      (Direction, Objective sh)) ->
   Either NoSolutionType
      (NonEmpty.T f (SolutionType, (Double, Array sh Double)))
solveSuccessiveInPlace solver bounds constrs0 (dir0,obj0) objs =
      unsafePerformIO $
   bracket FFI.glp_create_prob FFI.glp_delete_prob $ \lp -> ME.runExceptT $ do
   let shape = Array.shape obj0
   sol0 <- ME.ExceptT $ do
      storeProblem bounds constrs0 (dir0,obj0) lp
      void $ solver lp
      peekSimplexSolution shape lp
   NonEmpty.cons sol0 <$>
      MS.evalStateT
         (for objs $
            -- the state is the solution of the preceding step
            \(makeNewConstrs,(dir,obj)) -> MS.StateT $ \sol ->
                  fmap (\sol1 -> (sol1, sol1)) $ ME.ExceptT $ do
               setDirection lp dir
               when (shape /= Array.shape obj) $
                  error "GLPK.solveSuccessiveInplace: objective shapes mismatch"
               setObjective lp obj
               let newConstrs = makeNewConstrs sol
               -- GLPK rejects a request for less than one new row,
               -- thus skip the row handling if there is nothing to add.
               when (not $ null newConstrs) $ do
                  newRow <-
                     FFI.glp_add_rows lp $ fromIntegral $ length newConstrs
                  for_ (zip [newRow..] (map prepareBounds newConstrs)) $
                        \(row, (terms,(bnd,lo,up))) -> do
                     FFI.glp_set_row_bnds lp row bnd lo up
                     let numTerms = length terms
                     allocaArray numTerms $ \indicesPtr ->
                        allocaArray numTerms $ \coeffsPtr -> do
                        -- GLPK arrays are addressed starting at index 1
                        for_ (zip [1..] terms) $
                           \(k, Term c x) -> do
                              pokeElem indicesPtr k (columnIndex shape x)
                              pokeElem coeffsPtr k (realToFrac c)
                        FFI.glp_set_mat_row lp row
                           (fromIntegral numTerms) indicesPtr coeffsPtr
               void $ solver lp
               peekSimplexSolution shape lp)
         sol0



{- |
>>> case Shape.indexTupleFromShape tripletShape of (x1,x2,x3) -> mapSnd (mapPair (round3, Array.toTuple . Array.map round3)) <$> LP.interior [] [[2.*x1, 1.*x2] <=. 10, [1.*x2, 5.*x3] <=. 20] (LP.Maximize, Array.fromTuple (4,-3,2) :: Array.Array TripletShape Double)
Right (Optimal,(28.0,(5.0,0.0,4.0)))

prop> QC.forAllShrink TestLP.genOrigin TestLP.shrinkOrigin $ \origin -> QC.forAll (TestLP.genProblem origin) $ \(bounds, constrs) -> QC.forAll (TestLP.genObjective origin) $ \(dir,obj) -> case (LP.simplex bounds constrs (dir,obj), LP.interior bounds constrs (dir,obj)) of (Right (LP.Optimal, (optSimplex,_)), Right (LP.Optimal, (optExact,_))) -> TestLP.approx "optimum" 0.001 optSimplex optExact; _ -> QC.property False
-}
interior ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> Solution sh
interior bounds constrs problem@(_dir,obj) = unsafePerformIO $
   bracket FFI.glp_create_prob FFI.glp_delete_prob $ \lp -> do
      storeProblem bounds constrs problem lp
      void $ FFI.glp_interior lp nullPtr
      status <- FFI.glp_ipt_status lp
      -- objective value and variables are only fetched
      -- if the status denotes a usable solution
      case analyzeStatus status of
         Left noSolution -> return (Left noSolution)
         Right typ -> do
            optimum <- realToFrac <$> FFI.glp_ipt_obj_val lp
            solution <-
               readGLPArray (Array.shape obj) $ \arr ix -> do
                  value <- FFI.glp_ipt_col_prim lp (deferredColumnIndex ix)
                  Mutable.write arr ix (realToFrac value)
            return (Right (typ, (optimum, solution)))


{- |
Transfer a complete problem description into the GLPK problem object @lp@.

In this order: GLPK's terminal output is switched off,
the optimization direction is set,
one row per constraint is added and equipped with its bounds,
the variable bounds and the objective are stored,
and finally all constraint coefficients are loaded as one sparse matrix.

An empty constraint list is supported: in this case no rows are added,
because @glp_add_rows@ must not be called with a row count below 1.
-}
storeProblem ::
   (Shape.Indexed sh, Shape.Index sh ~ ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> Foreign.Ptr FFI.Problem -> IO ()
storeProblem bounds constrs (dir,obj) lp = do
   void $ FFI.glp_term_out FFI.glpkOff
   let shape = Array.shape obj
   -- split every constraint once into its terms and its row bounds
   let prepared = map prepareBounds constrs
   setDirection lp dir
   -- row identifiers of the new constraints, starting with the first new row
   rows <-
      if null constrs
         then return []
         else
            enumFrom <$>
               FFI.glp_add_rows lp (fromIntegral $ length constrs)
   for_ (zip rows prepared) $ \(row,(_terms,(bnd,lo,up))) ->
      FFI.glp_set_row_bnds lp row bnd lo up
   storeBounds lp shape bounds
   setObjective lp obj
   -- all non-zero matrix entries, each tagged with the row it belongs to
   let entries =
         concat $
         zipWith (\row (terms,_) -> map ((,) row) terms) rows prepared
   let numTerms = length entries
   allocaArray numTerms $ \ia ->
      allocaArray numTerms $ \ja ->
      allocaArray numTerms $ \ar -> do
         -- GLPK arrays are addressed starting at index 1
         for_ (zip [1..] entries) $ \(k, (row, Term c x)) -> do
            pokeElem ia k row
            pokeElem ja k (columnIndex shape x)
            pokeElem ar k (realToFrac c)
         FFI.glp_load_matrix lp (fromIntegral numTerms) ia ja ar




{- |
Index types that can be rendered as variable names
in the output of 'formatMathProg'.
-}
class FormatIdentifier ix where
   formatIdentifier :: ix -> String

-- | A character is used as is.
instance FormatIdentifier Char where
   formatIdentifier = (:[])

-- | The identifiers of the elements are concatenated without separator.
instance FormatIdentifier c => FormatIdentifier [c] where
   formatIdentifier parts = concat [formatIdentifier part | part <- parts]

-- | The number is prefixed with @x@.
instance FormatIdentifier Int where
   formatIdentifier n = printf "x%d" n

-- | The number is prefixed with @x@.
instance FormatIdentifier Integer where
   formatIdentifier n = printf "x%d" n


{- |
Render the bound of one variable as a @var@ statement.
A 'Free' variable is declared without any limit.
-}
formatBound :: (FormatIdentifier ix) => Inequality ix -> String
formatBound (Inequality ix bnd) =
   "var " ++ formatIdentifier ix ++ limits ++ ";"
   where
      limits =
         case bnd of
            Free -> ""
            Equal x -> printf ", =%f" x
            LessEqual up -> printf ", <=%f" up
            GreaterEqual lo -> printf ", >=%f" lo
            Between lo up -> printf ", >=%f, <=%f" lo up


{- |
Render a linear combination as products joined by @+@.
An empty list of terms is rendered as @0@.
-}
formatSum :: (FormatIdentifier ix) => [Term ix] -> String
formatSum terms =
   case terms of
      [] -> "0"
      _ ->
         List.intercalate "+"
            [printf "%f*%s" c (formatIdentifier ix) | Term c ix <- terms]

{- |
Render one constraint: the sum of its terms
together with the limits on the appropriate sides.
A 'Free' constraint consists of the sum only.
-}
formatConstraint :: (FormatIdentifier ix) => Inequality [Term ix] -> String
formatConstraint (Inequality terms bnd) =
   case bnd of
      Free -> lhs
      Equal x -> lhs ++ printf " = %f" x
      LessEqual up -> lhs ++ printf " <= %f" up
      GreaterEqual lo -> printf "%f <= " lo ++ lhs
      Between lo up -> printf "%f <= " lo ++ lhs ++ printf " <= %f" up
   where
      lhs = formatSum terms

-- | The keyword that introduces the objective for the given direction.
formatDirection :: Direction -> String
formatDirection dir =
   case dir of
      Minimize -> "minimize"
      Maximize -> "maximize"

{- |
Render the objective as a sum
over all (index, coefficient) associations of the array.
-}
formatObjective ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, FormatIdentifier ix) =>
   Objective sh -> String
formatObjective obj =
   formatSum [Term c ix | (ix,c) <- Array.toAssociations obj]

{- |
Render a problem as the lines of a MathProg model:
one @var@ line per element of @bounds@,
the objective named @value@,
the constraints named @constr0@, @constr1@, ... after @subject to@,
and a closing @end;@.

Only variables that occur in @bounds@ get a @var@ line.
-}
formatMathProg ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, FormatIdentifier ix) =>
   Bounds ix -> Constraints ix ->
   (Direction, Objective sh) -> [String]
formatMathProg bounds constrs (dir,obj) =
   concat
      [ map formatBound bounds
      , [ ""
        , formatDirection dir
        , "value: " ++ formatObjective obj ++ ";"
        , ""
        , "subject to"
        ]
      , zipWith constraintLine [(0::Int)..] constrs
      , [ "", "end;" ]
      ]
   where
      constraintLine k constr =
         printf "constr%d: %s;" k (formatConstraint constr)