module Linear.Grammar where
import Data.List (find)
import Data.String
-- | Surface syntax for linear expressions, as written by the user before
-- any normalisation.
data LinAst
  = EVar String
    -- ^ A variable, identified by its name.
  | ELit Double
    -- ^ A numeric literal.
  | ECoeff LinAst Double
    -- ^ A sub-expression paired with a multiplier.
  | EAdd LinAst LinAst
    -- ^ The sum of two sub-expressions.
  deriving (Show, Eq)
-- | A string literal denotes the variable of that name, so with
-- @OverloadedStrings@ enabled @"x"@ can be used where a 'LinAst' is expected.
instance IsString LinAst where
  fromString name = EVar name
infixr 8 .+.

-- | Infix addition of two expressions; builds an 'EAdd' node.
(.+.) :: LinAst -> LinAst -> LinAst
lhs .+. rhs = EAdd lhs rhs
infixr 9 .*.

-- | Pairs of operand types that can be combined with '.*.' to give a
-- 'LinAst'. The instances in this module let the numeric factor appear on
-- either side of the expression.
class Coefficient x y where
  -- | Combine the two operands into a 'LinAst'.
  (.*.) :: x -> y -> LinAst
-- | Expression on the left, numeric factor on the right: @e .*. k@.
instance Coefficient LinAst Double where
  expr .*. factor = ECoeff expr factor
-- | Numeric factor on the left, expression on the right: @k .*. e@.
-- The arguments are swapped so the result is still @ECoeff e k@.
instance Coefficient Double LinAst where
  factor .*. expr = ECoeff expr factor
-- | Push every coefficient down to the leaves of the expression.
--
-- In the result an 'ECoeff' only ever wraps an 'EVar': a coefficient on a
-- literal is multiplied into the literal, nested coefficients are multiplied
-- together, and a coefficient on a sum is distributed over both summands.
multLin :: LinAst -> LinAst
multLin ast = case ast of
  EVar _ -> ast
  ELit _ -> ast
  ECoeff inner k -> scaleBy k (multLin inner)
  EAdd l r -> EAdd (multLin l) (multLin r)
  where
    -- Apply the factor @k@ to an already-normalised sub-expression.
    scaleBy :: Double -> LinAst -> LinAst
    scaleBy k (ELit v) = ELit (v * k)
    scaleBy k (EVar n) = ECoeff (EVar n) k
    scaleBy k (ECoeff e c) = ECoeff e (c * k)
    scaleBy k (EAdd l r) = EAdd (multLin (ECoeff l k)) (multLin (ECoeff r k))
-- | A single linear term: a named variable together with its coefficient.
data LinVar = LinVar
  { varName :: String
    -- ^ Name of the variable.
  , varCoeff :: Double
    -- ^ Coefficient attached to the variable.
  }
  deriving (Show, Eq)
-- | Terms are ordered by variable name only; the coefficient is ignored.
-- Note that the derived 'Eq' instance does compare coefficients, so two terms
-- can compare as 'EQ' here while being unequal under '=='.
instance Ord LinVar where
  compare a b = compare (varName a) (varName b)
-- | 'True' when both terms refer to the same variable name, whatever their
-- coefficients are.
hasName :: LinVar -> LinVar -> Bool
hasName a b = varName a == varName b
-- | A flattened linear expression: a list of variable terms plus a constant.
data LinExpr = LinExpr
  { exprCoeffs :: [LinVar]
    -- ^ The variable terms of the expression.
  , exprConst :: Double
    -- ^ The constant part of the expression.
  }
  deriving (Show, Eq)
-- | Add two linear expressions: the term lists are concatenated (left operand
-- first, duplicates are kept) and the constants are summed.
mergeLinExpr :: LinExpr -> LinExpr -> LinExpr
mergeLinExpr
  LinExpr { exprCoeffs = leftTerms, exprConst = leftConst }
  LinExpr { exprCoeffs = rightTerms, exprConst = rightConst } =
    LinExpr
      { exprCoeffs = leftTerms ++ rightTerms
      , exprConst = leftConst + rightConst
      }
-- | Flatten a 'LinAst' into a 'LinExpr' by collecting its variable terms and
-- summing its literals.
--
-- Terms for the same variable are /not/ combined here; see 'removeDupLin'.
-- The input is normally the output of 'multLin', where every 'ECoeff' wraps a
-- bare 'EVar'. Other 'ECoeff' shapes used to hit a pattern-match failure;
-- they are now normalised with 'multLin' on the fly, so the function is total.
addLin :: LinAst -> LinExpr
addLin = go (LinExpr [] 0)
  where
    -- The first argument is the seed the leaf is added to. It is always the
    -- empty expression, which is why merging both branches of an 'EAdd' does
    -- not count anything twice.
    go :: LinExpr -> LinAst -> LinExpr
    go (LinExpr vs c) (EVar n) = LinExpr (LinVar n 1 : vs) c
    go (LinExpr vs c) (ELit x) = LinExpr vs (c + x)
    go (LinExpr vs c) (ECoeff (EVar n) x) = LinExpr (LinVar n x : vs) c
    -- A coefficient on anything but a bare variable: distribute it first.
    -- 'multLin' only returns 'ELit', 'ECoeff' around an 'EVar', or 'EAdd' of
    -- such terms for this input, so the recursion terminates.
    go le e@(ECoeff _ _) = go le (multLin e)
    go le (EAdd e1 e2) = mergeLinExpr (go le e1) (go le e2)
-- | Combine all terms that mention the same variable into a single term whose
-- coefficient is the sum of the individual coefficients. The constant is left
-- untouched.
--
-- The previous version returned the accumulator unchanged when the current
-- variable was not yet present, which silently dropped every variable except
-- the last one in the list. A variable that is not yet present is now added.
removeDupLin :: LinExpr -> LinExpr
removeDupLin (LinExpr vs c) = LinExpr (foldr go [] vs) c
  where
    -- Fold one term into the list of already-merged terms.
    go :: LinVar -> [LinVar] -> [LinVar]
    go v@(LinVar _ x) merged =
      case find (hasName v) merged of
        -- Same variable seen before: replace it by the summed term.
        Just (LinVar m y) -> LinVar m (y + x) : filter (not . hasName v) merged
        -- First occurrence of this variable: keep it.
        Nothing -> v : merged
-- | Turn user-facing syntax into a flattened linear expression: distribute
-- coefficients ('multLin'), collect terms ('addLin'), then merge terms that
-- share a variable ('removeDupLin').
makeLinExpr :: LinAst -> LinExpr
makeLinExpr ast = removeDupLin (addLin (multLin ast))
-- | A relation between two linear expressions, as built by '.==.', '.<=.'
-- and '.=>.'.
data Ineq
  = Equ LinExpr LinExpr
    -- ^ Built by '.==.' from its left and right operands.
  | Lte LinExpr LinExpr
    -- ^ Built by '.<=.' from its left and right operands.
  deriving (Show, Eq)
infixl 7 .==.

-- | Build an 'Equ' relation; both sides are flattened with 'makeLinExpr'.
(.==.) :: LinAst -> LinAst -> Ineq
lhs .==. rhs = Equ left right
  where
    left = makeLinExpr lhs
    right = makeLinExpr rhs
infixl 7 .<=.

-- | Build an 'Lte' relation; both sides are flattened with 'makeLinExpr'.
(.<=.) :: LinAst -> LinAst -> Ineq
lhs .<=. rhs = Lte left right
  where
    left = makeLinExpr lhs
    right = makeLinExpr rhs
infixl 7 .=>.

-- | '.<=.' with its operands swapped: @a .=>. b@ is the same relation as
-- @b .<=. a@.
(.=>.) :: LinAst -> LinAst -> Ineq
lhs .=>. rhs = rhs .<=. lhs
-- | A relation with all variable terms collected in one list and a single
-- constant, as produced by 'standardForm'.
data IneqStdForm
  = EquStd [LinVar] Double
    -- ^ Produced by 'standardForm' from an 'Equ'.
  | LteStd [LinVar] Double
    -- ^ Produced by 'standardForm' from an 'Lte'.
  deriving (Show, Eq)
-- | Rearrange a relation so that variables and constants no longer share a
-- side.
--
-- * If the left side has no variables, its constant absorbs the right
--   constant: @xc - yc (rel) ys@.
-- * If the right side has no variables, its constant absorbs the left
--   constant: @xs (rel) yc - xc@.
-- * Otherwise the right-hand variables are moved to the left with negated
--   coefficients and merged with 'removeDupLin', leaving @yc - xc@ on the
--   right.
--
-- The subtractions and the @-1@ factor had lost their minus signs, which made
-- the function ill-typed (@xc yc@ applies a 'Double' as a function) and turned
-- the negation into a no-op. The 'Equ' and 'Lte' cases were identical apart
-- from the constructor, so they now share one helper.
standardize :: Ineq -> Ineq
standardize ineq = case ineq of
  Equ lhs rhs -> rebalance Equ lhs rhs
  Lte lhs rhs -> rebalance Lte lhs rhs
  where
    -- Rebuild the relation with the given constructor after moving terms.
    rebalance :: (LinExpr -> LinExpr -> Ineq) -> LinExpr -> LinExpr -> Ineq
    rebalance rel (LinExpr xs xc) (LinExpr ys yc)
      | null xs = rel (LinExpr [] (xc - yc)) (LinExpr ys 0)
      | null ys = rel (LinExpr xs 0) (LinExpr [] (yc - xc))
      | otherwise =
          rel (removeDupLin $ LinExpr (ys' ++ xs) 0) (LinExpr [] (yc - xc))
      where
        -- Right-hand terms with their sign flipped, ready to move left.
        ys' = map (\y -> y { varCoeff = varCoeff y * (-1) }) ys
-- | Convert a relation to standard form: all variable terms on one side and a
-- single constant on the other.
--
-- The relation is first passed through 'standardize'. For 'Lte' the result
-- always reads @terms <= constant@. When 'standardize' leaves the variables
-- on the right (@xc <= ys@), both sides are negated to get @-ys <= -xc@.
-- Previously this case was emitted as @LteStd ys xc@, which reversed the
-- direction of the inequality. Equalities are symmetric, so they need no
-- such adjustment.
--
-- Calls 'error' if 'standardize' returns a shape that is not in standard
-- form.
standardForm :: Ineq -> IneqStdForm
standardForm = go . standardize
  where
    go :: Ineq -> IneqStdForm
    go (Equ (LinExpr xs xc) (LinExpr ys yc))
      | null xs && yc == 0 = EquStd ys xc
      | null ys && xc == 0 = EquStd xs yc
    go (Lte (LinExpr xs xc) (LinExpr ys yc))
      -- xc <= ys  ==>  -ys <= yc - xc  (yc is 0 here)
      | null xs && yc == 0 = LteStd (map flipSign ys) (yc - xc)
      | null ys && xc == 0 = LteStd xs yc
    go _ = error "Non-standard Ineq"

    -- Negate the coefficient of a single term.
    flipSign :: LinVar -> LinVar
    flipSign v = v { varCoeff = negate (varCoeff v) }