-- | 'State'-monad actions that fill in the fields of a 'Step' record one
-- call at a time: states, actions, params, constants, outputs, dt, dxdt,
-- Lagrange\/Mayer terms, periodic variables, bounds and constraints.
--
-- The 'State' type is re-exported from "Control.Monad.State".
-- 'lagrangeStateName' and 'lagrangeTermName' are names that the name
-- validation in this module rejects as reserved.
module Dvda.MultipleShooting.MSMonad ( State
, setStates
, setActions
, addParam
, addParams
, addConstant
, addConstants
, setDxdt
, setLagrangeTerm
, setMayerTerm
, setDt
, addOutput
, setPeriodic
, addConstraint
, setBound
, lagrangeStateName
, lagrangeTermName
) where
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.List ( nub, sort )
import Data.Maybe ( isJust, fromMaybe )
import Data.Monoid ( mappend )
import Control.Monad ( when, zipWithM_ )
import Control.Monad.State ( State )
import qualified Control.Monad.State as State
import qualified Dvda.HashMap as HM
import Dvda.Expr ( Expr(..), sym )
import Dvda.MultipleShooting.Types
-- | Reserved variable name; user-supplied names equal to it are rejected
-- by 'checkOctaveName'.
lagrangeStateName :: String
lagrangeStateName = "lagrangeState"

-- | Reserved variable name; user-supplied names equal to it are rejected
-- by 'checkOctaveName'.
lagrangeTermName :: String
lagrangeTermName = "lagrangeTerm"
-- | Return the list unchanged if all names are distinct; otherwise abort
-- with an 'error' that lists the (sorted) names.
--
-- Duplicates are detected by sorting once and comparing neighbours, which
-- avoids the quadratic cost of 'nub' and reuses the sorted list for the
-- error message.
failDuplicates :: [String] -> [String]
failDuplicates names
  | hasDuplicate = error $ "ERROR: saw duplicate names in: " ++ show sorted
  | otherwise = names
  where
    sorted = sort names
    -- equal names end up adjacent after sorting
    hasDuplicate = or $ zipWith (==) sorted (drop 1 sorted)
-- | Validate a user-supplied variable name and return it unchanged.
--
-- Calls 'error' when the name contains any character from @badChars@ or
-- when it equals one of the reserved names ('lagrangeStateName',
-- 'lagrangeTermName').  No other checks are performed.
checkOctaveName :: String -> String
checkOctaveName name
  | any (`elem` badChars) name =
      error $ "ERROR: saw illegal octave variable character in string: \"" ++ name ++
              "\", illegal characters: " ++ badChars
  -- The whole message must be built before it is handed to 'error':
  -- @error "a" ++ b@ parses as @(error "a") ++ b@ and truncates the text.
  | name `elem` reservedNames =
      error $ "don't call your variable \"" ++ name ++ "\", it's reserved"
  | otherwise = name
  where
    reservedNames = [lagrangeStateName, lagrangeTermName]
    badChars = "\"'~!@#$%^&*()+`-=[]{}\\|;:,.<>/?"
-- | Declare the state variables of the step.
--
-- Each name is validated with 'checkOctaveName' and the list is checked
-- for duplicates, then one symbol per name is created, stored in
-- 'stepStates', and registered as an output under its own name via
-- 'addOutput'.  Returns the created symbols in the order of the names.
--
-- Calls 'error' if the states were already set.
setStates :: [String] -> State (Step a) [Expr a]
setStates names' = do
  step <- State.get
  case stepStates step of
    Just _ -> error "states already set, don't call setStates twice"
    Nothing -> do
      -- validated once; a second duplicate check on the same list is redundant
      let names = failDuplicates (map checkOctaveName names')
          syms = map sym names
      State.put $ step {stepStates = Just syms}
      zipWithM_ addOutput syms names
      return syms
-- | Declare the action variables of the step.
--
-- Each name is validated with 'checkOctaveName' and the list is checked
-- for duplicates, then one symbol per name is created, stored in
-- 'stepActions', and registered as an output under its own name via
-- 'addOutput'.  Returns the created symbols in the order of the names.
--
-- Calls 'error' if the actions were already set.
setActions :: [String] -> State (Step a) [Expr a]
setActions names' = do
  step <- State.get
  case stepActions step of
    Just _ -> error "actions already set, don't call setActions twice"
    Nothing -> do
      -- validated once; a second duplicate check on the same list is redundant
      let names = failDuplicates (map checkOctaveName names')
          syms = map sym names
      State.put $ step {stepActions = Just syms}
      zipWithM_ addOutput syms names
      return syms
-- | Add a single parameter by name and return its symbol.
--
-- Thin wrapper around 'addParams' for the one-name case.
addParam :: (Eq a, Hashable a) => String -> State (Step a) (Expr a)
addParam name = do
  syms <- addParams [name]
  -- Matched with 'case' rather than a refutable @[x] <- ...@ bind, which
  -- would require a MonadFail instance that 'State' does not provide.
  case syms of
    [s] -> return s
    _ -> error "addParam: impossible, addParams returns one symbol per name"
-- | Add a single constant by name and return its symbol.
--
-- Thin wrapper around 'addConstants' for the one-name case.
addConstant :: (Eq a, Hashable a) => String -> State (Step a) (Expr a)
addConstant name = do
  syms <- addConstants [name]
  -- Matched with 'case' rather than a refutable @[x] <- ...@ bind, which
  -- would require a MonadFail instance that 'State' does not provide.
  case syms of
    [s] -> return s
    _ -> error "addConstant: impossible, addConstants returns one symbol per name"
-- | Add parameters by name.
--
-- Every name is passed through 'checkOctaveName' and turned into a symbol;
-- the symbols are merged into the 'stepParams' set and returned in the
-- order of the given names.
addParams :: (Eq a, Hashable a) => [String] -> State (Step a) [Expr a]
addParams names = do
  State.modify $ \step ->
    step {stepParams = stepParams step `HS.union` HS.fromList newSyms}
  return newSyms
  where
    newSyms = map (sym . checkOctaveName) names
-- | Add constants by name.
--
-- Every name is passed through 'checkOctaveName' and turned into a symbol;
-- the symbols are merged into the 'stepConstants' set and returned in the
-- order of the given names.
addConstants :: (Eq a, Hashable a) => [String] -> State (Step a) [Expr a]
addConstants names = do
  State.modify $ \step ->
    step {stepConstants = stepConstants step `HS.union` HS.fromList newSyms}
  return newSyms
  where
    newSyms = map (sym . checkOctaveName) names
-- | Register an expression as a named output in 'stepOutputs'.
--
-- The name is validated with 'checkOctaveName' before being used as key.
addOutput :: Expr a -> String -> State (Step a) ()
addOutput var name = State.modify register
  where
    register step =
      step {stepOutputs = HM.insertWith clash (checkOctaveName name) var (stepOutputs step)}
    -- Passed as the combining function of 'HM.insertWith'; presumably it is
    -- only applied when the key already exists, turning a name clash into
    -- an error (depends on Dvda.HashMap, not visible here -- TODO confirm).
    clash = error $ "ERROR: already have an output with name: \"" ++ name ++ "\""
-- | Set the time-step expression ('stepDt').
--
-- Calls 'error' if it was already set.
--
-- The check is a 'case' whose result /is/ the new state.  In the lazy
-- State monad a separate @when ... $ error ...@ statement followed by an
-- independent 'State.put' is never evaluated, so the guard would silently
-- do nothing and the old value would be overwritten.
setDt :: Expr a -> State (Step a) ()
setDt expr = do
  step <- State.get
  case stepDt step of
    Just _ -> error "dt already set, don't call setDt twice"
    Nothing -> State.put $ step {stepDt = Just expr}
-- | Mark a variable as periodic by inserting it into 'stepPeriodic'.
--
-- The new set is an 'error' value if the variable is already in the set,
-- or if it is not among the currently declared states and actions.
setPeriodic :: (Eq a, Hashable a, Show a) => Expr a -> State (Step a) ()
setPeriodic var = do
  step <- State.get
  let alreadyPeriodic = stepPeriodic step
      -- states followed by actions; unset fields contribute nothing
      designVars = fromMaybe [] (stepStates step `mappend` stepActions step)
      updated
        | HS.member var alreadyPeriodic =
            error $ "you called setPeriodic twice on \"" ++ show var ++ "\""
        | var `notElem` designVars =
            error $ "you can only make states or actions periodic, you can't make \"" ++ show var ++ "\" periodic"
        | otherwise = HS.insert var alreadyPeriodic
  State.put $ step {stepPeriodic = updated}
-- | Set the list of state-derivative expressions ('stepDxdt').
--
-- Calls 'error' if it was already set.
--
-- The check is a 'case' whose result /is/ the new state.  In the lazy
-- State monad a separate @when ... $ error ...@ statement followed by an
-- independent 'State.put' is never evaluated, so the guard would silently
-- do nothing and the old value would be overwritten.
setDxdt :: [Expr a] -> State (Step a) ()
setDxdt vars = do
  step <- State.get
  case stepDxdt step of
    Just _ -> error "dxdt already set, don't call setDxdt twice"
    Nothing -> State.put $ step {stepDxdt = Just vars}
-- | Set the Lagrange term ('stepLagrangeTerm') together with a pair of
-- values stored alongside it (named @lb@\/@ub@ here, presumably lower and
-- upper bounds -- confirm against the consumer of 'stepLagrangeTerm').
--
-- Calls 'error' if it was already set.
--
-- The check is a 'case' whose result /is/ the new state.  In the lazy
-- State monad a separate @when ... $ error ...@ statement followed by an
-- independent 'State.put' is never evaluated, so the guard would silently
-- do nothing and the old value would be overwritten.
setLagrangeTerm :: Expr a -> (a,a) -> State (Step a) ()
setLagrangeTerm var (lb,ub) = do
  step <- State.get
  case stepLagrangeTerm step of
    Just _ -> error "Lagrange term already set, don't call setLagrangeTerm twice"
    Nothing -> State.put $ step {stepLagrangeTerm = Just (var,(lb,ub))}
-- | Set the Mayer term ('stepMayerTerm').
--
-- Calls 'error' if it was already set.
--
-- The check is a 'case' whose result /is/ the new state.  In the lazy
-- State monad a separate @when ... $ error ...@ statement followed by an
-- independent 'State.put' is never evaluated, so the guard would silently
-- do nothing and the old value would be overwritten.
setMayerTerm :: Expr a -> State (Step a) ()
setMayerTerm var = do
  step <- State.get
  case stepMayerTerm step of
    Just _ -> error "Mayer term already set, don't call setMayerTerm twice"
    Nothing -> State.put $ step {stepMayerTerm = Just var}
-- | Record a bound on a symbolic variable.
--
-- For an 'ESym' argument the entry @(var, (lb, ub, bctime))@ is prepended
-- to 'stepBounds'.  Any other expression is rejected with 'error'.
setBound :: (Show a, Eq a, Hashable a)
            => Expr a -> (a, a) -> BCTime -> State (Step a) ()
setBound var@(ESym _) (lb, ub) bctime =
  State.modify $ \step ->
    let entry = (var, (lb, ub, bctime))
    in step {stepBounds = entry : stepBounds step}
setBound _ _ _ =
  error "WARNING - setBound called on non-design variable, try addConstraint instead"
-- | Append the constraint @x \`ordering\` y@ to the end of
-- 'stepConstraints', so constraints are kept in the order they were added.
addConstraint :: Expr a -> Ordering -> Expr a -> State (Step a) ()
addConstraint x ordering y = State.modify appendConstraint
  where
    appendConstraint step =
      step {stepConstraints = stepConstraints step ++ [Constraint x ordering y]}