module Dvda.MultipleShooting.MSCoctave ( msCoctave
, run
) where
import qualified Control.Monad.State as State
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.List ( zipWith6 )
import Data.Maybe ( fromMaybe )
import Dvda.AD ( rad )
import Dvda.CGen ( showMex )
import Dvda.CSE ( cse )
import Dvda.Codegen.WriteFile ( writeSourceFile )
import Dvda.Expr ( Expr(..), sym, substitute )
import Dvda.FunGraph ( (:*)(..), toFunGraph, countNodes )
import Dvda.HashMap ( HashMap )
import qualified Dvda.HashMap as HM
import Dvda.MultipleShooting.CoctaveTemplates
import Dvda.MultipleShooting.MSMonad
import Dvda.MultipleShooting.Types
-- | A discretisation of the dynamics between two consecutive timesteps.
--
-- 'msCoctave' calls it with, in order: the states at timestep @k@, the
-- actions at timestep @k@, the states at timestep @k+1@, the actions at
-- timestep @k+1@, a function evaluating the state derivatives from a
-- (states, actions) pair, and the timestep length. Every residual in the
-- returned list is constrained to equal zero.
type Integrator a =
  [Expr a]
  -> [Expr a]
  -> [Expr a]
  -> [Expr a]
  -> ([Expr a] -> [Expr a] -> [Expr a])
  -> Expr a
  -> [Expr a]
-- | Build bound-lookup functions from the user's bound specifications.
--
-- Returns @(lookupAll, lookupParam)@:
--
-- * @lookupAll x k@ is the @(lower, upper)@ bound of @x@ at timestep @k@.
--   A bound registered with @TIMESTEP k@ takes precedence over one registered
--   with @ALWAYS@. Calls 'error' if @k@ is outside @[0, nSteps)@ or if no
--   bound applies.
-- * @lookupParam x@ is the @ALWAYS@ bound of @x@; calls 'error' if none was
--   set.
--
-- Registering two @ALWAYS@ bounds for the same variable, or two @TIMESTEP k@
-- bounds for the same variable and the same @k@, is an 'error' (raised when
-- the corresponding map is first consulted).
setupBounds :: (Eq a, Hashable a, Show a)
            => [(Expr a, (a, a, BCTime))]
            -> Int
            -> (Expr a -> Int -> (a, a), Expr a -> (a, a))
setupBounds userBounds nSteps = (lookupAll, lookupParam)
  where
    lookupAll x k
      | k < 0 =
          error $ "don't ask for bounds at negative timestep " ++ show k
      | k >= nSteps =
          error "don't ask for bounds at timestep >= number of total timesteps"
      | otherwise =
          case HM.lookup (x, k) specificTimestepBounds of
            Just bnd -> bnd
            Nothing ->
              fromMaybe
                (error $ "need to set bounds for " ++ quoted x
                         ++ " at timestep " ++ show k)
                (HM.lookup x everyTimestepBounds)

    lookupParam x =
      fromMaybe (error $ "need to set bounds for " ++ quoted x)
                (HM.lookup x everyTimestepBounds)

    -- Bounds tagged ALWAYS, keyed by variable.
    everyTimestepBounds =
      fromListUnique quoted
        [(e, (lb, ub)) | (e, (lb, ub, ALWAYS)) <- userBounds]

    -- Bounds tagged TIMESTEP k, keyed by (variable, k).
    specificTimestepBounds =
      fromListUnique (\(e, k) -> quoted e ++ " at timestep " ++ show k)
        [((e, k), (lb, ub)) | (e, (lb, ub, TIMESTEP k)) <- userBounds]

    -- Build a map from key/bound pairs, failing loudly on a repeated key.
    -- @describe@ renders the offending key for the error message.
    fromListUnique describe = foldr insertUnique HM.empty
      where
        insertUnique (key, bnd) hm
          | HM.member key hm =
              error $ "you set bounds twice for " ++ describe key
          | otherwise = HM.insert key bnd hm

    -- Render a variable wrapped in double quotes for error messages.
    quoted e = "\"" ++ show e ++ "\""
-- | Flatten the decision variables into one vector: every inner list of the
-- first argument in order, then every inner list of the second, then the
-- third argument. ('msCoctave' passes per-timestep states, per-timestep
-- actions and the parameters.)
vectorizeDvs :: [[a]] -> [[a]] -> [a] -> [a]
vectorizeDvs stateRows actionRows ps = concat (stateRows ++ actionRows) ++ ps
-- | Generate Octave/MATLAB and MEX C sources for a multiple-shooting optimal
-- control problem.
--
-- The problem description @userStep'@ is run on an empty 'Step'. The model is
-- then discretised over @n@ timesteps: each state and action gets one
-- symbolic decision variable per timestep (named @"<var>__<k>"@), and the
-- decision vector is all states, then all actions, then the parameters (see
-- 'vectorizeDvs'). Consecutive timesteps are linked by the residuals returned
-- by @odeError@, each constrained to equal zero.
--
-- Files written to @funDir@, each prefixed with @name@: @_mex_all.m@,
-- @_setup.m@, @_struct.m@, @_unstructConstants.m@, @_unstruct.m@, @_plot.m@,
-- @_time.c@, @_outputs.c@, @_sim.c@, @_cost.c@ and @_constraints.c@. The node
-- counts of the generated function graphs are printed to stdout.
--
-- Fails with 'error' if @n < 1@, or if a required piece of the problem
-- (states, actions, dxdt, dt, a cost term, a bound) was never set.
msCoctave
  :: State (Step Double) b  -- ^ problem description
  -> Integrator Double      -- ^ discretisation of the dynamics between timesteps
  -> Int                    -- ^ number of timesteps @n@
  -> FilePath               -- ^ output directory @funDir@
  -> String                 -- ^ base @name@ of generated files and functions
  -> IO ()
msCoctave userStep' odeError n funDir name
  | n < 1 = error $ "msCoctave: need at least one timestep, got n = " ++ show n
  | otherwise = do
      -- Cost and constraints are CSE'd; the pre-CSE graphs are kept only to
      -- report how many nodes CSE removed.
      (costSource, costFg0, costFg) <- do
        let costGrad = rad cost dvs
        fg0 <- toFunGraph (dvs :* constants) (cost :* costGrad)
        let fg = cse fg0
        return (showMex (name ++ "_cost") fg, fg0, fg)
      (constraintsSource, constraintsFg0, constraintsFg) <- do
        let cineqJacob = map (flip rad dvs) cineq
            ceqJacob = map (flip rad dvs) ceq
        fg0 <- toFunGraph (dvs :* constants)
                          (cineq :* ceq :* cineqJacob :* ceqJacob)
        let fg = cse fg0
        return (showMex (name ++ "_constraints") fg, fg0, fg)
      (timeSource, timeFg) <- do
        fg <- toFunGraph (dvs :* constants) (take n $ scanl (+) 0 (repeat dt))
        return (showMex (name ++ "_time") fg, fg)
      (outputSource, outputFg) <- do
        fg <- toFunGraph (dvs :* constants) (HM.elems outputMap)
        return (showMex (name ++ "_outputs") fg, fg)
      (simSource, simFg) <- do
        fg <- toFunGraph (states :* actions :* params :* constants) dxdt
        return (showMex (name ++ "_sim") fg, fg)

      _ <- writeSourceFile mexAllSource funDir $ name ++ "_mex_all.m"
      _ <- writeSourceFile setupSource funDir $ name ++ "_setup.m"
      _ <- writeSourceFile structSource funDir $ name ++ "_struct.m"
      _ <- writeSourceFile unstructConstsSource funDir $ name ++ "_unstructConstants.m"
      _ <- writeSourceFile unstructSource funDir $ name ++ "_unstruct.m"
      _ <- writeSourceFile plotSource funDir $ name ++ "_plot.m"
      _ <- writeSourceFile timeSource funDir $ name ++ "_time.c"
      _ <- writeSourceFile outputSource funDir $ name ++ "_outputs.c"
      _ <- writeSourceFile simSource funDir $ name ++ "_sim.c"
      _ <- writeSourceFile costSource funDir $ name ++ "_cost.c"
      _ <- writeSourceFile constraintsSource funDir $ name ++ "_constraints.c"

      putStrLn $ "nodes in time: " ++ show (countNodes timeFg)
      putStrLn $ "nodes in output: " ++ show (countNodes outputFg)
      putStrLn $ "nodes in sim: " ++ show (countNodes simFg)
      putStrLn $ "nodes in cost: " ++ show (countNodes costFg) ++
        " (" ++ show (countNodes costFg0) ++ " before CSE)"
      putStrLn $ "nodes in constraints: " ++ show (countNodes constraintsFg) ++
        " (" ++ show (countNodes constraintsFg0) ++ " before CSE)"
  where
    -- The user's problem description, run on a Step with nothing set.
    step = State.execState userStep' $
      Step { stepStates = Nothing
           , stepActions = Nothing
           , stepDxdt = Nothing
           , stepDt = Nothing
           , stepLagrangeTerm = Nothing
           , stepMayerTerm = Nothing
           , stepBounds = []
           , stepConstraints = []
           , stepParams = HS.empty
           , stepConstants = HS.empty
           , stepOutputs = HM.empty
           , stepPeriodic = HS.empty
           }

    getWithErr :: String -> (Step Double -> Maybe c) -> c
    getWithErr fieldName f = case f step of
      Nothing -> error $ "need to set " ++ fieldName
      Just ret -> ret

    actions = getWithErr "actions" stepActions
    dt = getWithErr "dt" stepDt

    -- A Lagrange (running) cost is handled by appending an extra state whose
    -- derivative is the Lagrange term; that state and the term are also
    -- exposed as outputs.
    (states, outputs, dxdt, lagrangeState) =
      case stepLagrangeTerm step of
        Nothing -> (states', outputs', dxdt', Nothing)
        Just (lagrangeTerm, (lb, ub)) ->
          ( states' ++ [lagrangeState']
          , HM.union outputs' $ HM.fromList
              [(lagrangeStateName, lagrangeState'), (lagrangeTermName, lagrangeTerm)]
          , dxdt' ++ [lagrangeTerm]
          , Just (lagrangeState', (lb, ub))
          )
      where
        states' = getWithErr "states" stepStates
        dxdt' = getWithErr "dxdt" stepDxdt
        outputs' = stepOutputs step
        lagrangeState' = sym lagrangeStateName

    params = HS.toList (stepParams step)
    constants = HS.toList (stepConstants step)

    -- One symbolic copy of every state and action per timestep.
    timesteps = [0 .. n - 1]
    allStates = [[sym $ show x ++ "__" ++ show k | x <- states] | k <- timesteps]
    allActions = [[sym $ show u ++ "__" ++ show k | u <- actions] | k <- timesteps]
    dvs = vectorizeDvs allStates allActions params

    -- Replace the generic states/actions in an expression by one timestep's copies.
    subStatesActions f x u = substitute f (zip states x ++ zip actions u)

    subAllTimesteps :: Expr Double -> [Expr Double]
    subAllTimesteps e = zipWith (subStatesActions e) allStates allActions

    outputMap :: HashMap String [Expr Double]
    outputMap = HM.map subAllTimesteps outputs

    (lbs, ubs) = unzip $ vectorizeDvs stateBounds actionBounds paramBounds
      where
        (getAllBounds, getParamBounds) = setupBounds bounds n
        stateBounds = [[getAllBounds x k | x <- states] | k <- timesteps]
        actionBounds = [[getAllBounds u k | u <- actions] | k <- timesteps]
        paramBounds = map getParamBounds params

    -- The Lagrange state starts at 0 and is otherwise held to its user bounds.
    bounds = stepBounds step ++ lagrangeBound
      where
        lagrangeBound = case lagrangeState of
          Nothing -> []
          Just (ls, (lb, ub)) -> [(ls, (0, 0, TIMESTEP 0)), (ls, (lb, ub, ALWAYS))]

    -- Objective: Mayer term plus accumulated Lagrange state, at the last timestep.
    cost = subStatesActions finalCost (last allStates) (last allActions)
      where
        finalCost = case (stepMayerTerm step, lagrangeState) of
          (Just mc, Nothing) -> mc
          (Nothing, Just (ls, _)) -> ls
          (Just mc, Just (ls, _)) -> mc + ls
          (Nothing, Nothing) -> error "need to set cost function"

    -- Equalities become x - y; inequalities become a residual that is
    -- nonpositive when the constraint holds (x - y for LT, y - x for GT).
    -- Comprehensions keep the original constraint order in linear time.
    ceq = [x - y | Constraint x EQ y <- allConstraints]
    cineq = concatMap toIneq allConstraints
      where
        toIneq (Constraint x LT y) = [x - y]
        toIneq (Constraint x GT y) = [y - x]
        toIneq _ = []

    execDxdt x u = map (\f -> subStatesActions f x u) dxdt

    -- One residual list per pair of consecutive timesteps, each residual == 0.
    dodeConstraints = map (Constraint 0 EQ) $ concat $
      zipWith6 odeError (init allStates) (init allActions)
                        (tail allStates) (tail allActions)
                        (repeat execDxdt) (repeat dt)

    allConstraints =
      dodeConstraints
        ++ concatMap (expand . fmap subAllTimesteps) (stepConstraints step)
        ++ periodicConstraints
      where
        -- A user constraint instantiated at every timestep: one constraint
        -- per timestep.
        expand (Constraint xs ord ys) = zipWith (\x y -> Constraint x ord y) xs ys

    -- Periodic variables must take the same value at the first and last timestep.
    periodicConstraints = map lookup' $ HS.toList (stepPeriodic step)
      where
        lookup' x =
          fromMaybe (error $ "couldn't find periodic thing \"" ++ show x ++ "\" in hashmap")
            $ HM.lookup x xuMap
        xuMap = HM.fromList $
          zip states (zipWith setEqual (head allStates) (last allStates)) ++
          zip actions (zipWith setEqual (head allActions) (last allActions))
        setEqual x y = Constraint x EQ y

    setupSource = writeSetupSource name dvs lbs ubs
    mexAllSource = writeMexAll name
    unstructConstsSource = writeUnstructConsts name constants
    structSource = writeToStruct name dvs params constants outputMap
    unstructSource = writeUnstruct name dvs params states allStates actions allActions
    plotSource = writePlot name outputMap
-- | Example problem: a unit-mass spring/damper driven by the action @u@.
--
-- States are position @x@ and velocity @v@, with @dx/dt = v@ and
-- @dv/dt = -k*x - b*v + u@. @k@ and @b@ are constants (presumably stiffness
-- and damping; their values are not fixed here). The Lagrange cost penalises
-- position, velocity and control effort. The trajectory starts at @x = 5@
-- with @v = 0@, ends with @v = 0@, and @x@ and @u@ are periodic.
spring :: State (Step Double) ()
spring = do
  -- Irrefutable patterns: 'State' has no 'MonadFail' instance, so a failable
  -- list pattern in this do block would not compile on GHC >= 8.8.
  ~[x, v] <- setStates ["x", "v"]
  ~[u] <- setActions ["u"]
  ~[k, b] <- addConstants ["k", "b"]
  let cost = 2*x*x + 3*v*v + 10*u*u
  setDxdt [v, -k*x - b*v + u]
  -- n' timesteps span n' - 1 intervals of length dt.
  setDt (tEnd / (fromIntegral n' - 1))
  -- bounds on the accumulated Lagrange cost
  setLagrangeTerm cost (-1, 2000)

  -- initial conditions
  setBound x (5, 5) (TIMESTEP 0)
  setBound v (0, 0) (TIMESTEP 0)

  -- path bounds
  setBound x (-5, 5) ALWAYS
  setBound v (-10, 10) ALWAYS
  setBound u (-200, 200) ALWAYS

  -- terminal condition and periodicity
  setBound v (0, 0) (TIMESTEP (n' - 1))
  setPeriodic x
  setPeriodic u
-- | End time used by the 'spring' example; 'spring' derives its timestep
-- length (passed to @setDt@) from this value and 'n''.
tEnd :: Expr Double
tEnd = 1.5

-- | Number of timesteps in the 'spring' example. 'spring' uses it for the
-- timestep length and the terminal-timestep bound, and 'run' passes it to
-- 'msCoctave'.
n' :: Int
n' = 18
-- | Generate the sources for the 'spring' example into
-- @../Documents/MATLAB/@ (relative to the working directory), using
-- @simpsonsRuleError'@ as the integrator and 'n'' timesteps.
run :: IO ()
run = msCoctave spring simpsonsRuleError' n' outputDir problemName
  where
    outputDir = "../Documents/MATLAB/"
    problemName = "spring"