module Data.QLearn
  ( QLearner
  , State(..)
  , Action(..)
  , Reward(..)
  , Environment(..)
  , initQLearner
  , initEnvironment
  , moveLearner
  , moveLearnerAndPrint
  , testGrid
  , possibleGrid
  , executeGrid
  , moveLearnerPrintRepeat
  , gridFromList
  ) where

import qualified Data.Vector as V
import Numeric
import Data.List
import System.Random

-- | Data type specifying the parameters and Q table for a particular Q learner.
--
-- * 'qAlpha' is the learning rate applied on each iterative update.
-- * 'qGamma' is the discount rate on future rewards.
-- * 'qEpsilon' maps the number of iterations left to epsilon, the probability
--   of taking a random action in the epsilon-greedy strategy. A function that
--   returns 1 everywhere makes every step a random (exploratory) action.
-- * 'qGrid' is a matrix (number of states by number of actions) holding the
--   Q(s,a) function learned so far.
data QLearner = QLearner
  { qAlpha   :: Double
  , qGamma   :: Double
  , qEpsilon :: Int -> Double
  , qGrid    :: V.Vector (V.Vector Double)
  }

-- | A state index ('State'), or the terminal 'Stop' state.
--
-- Note that 'getStateValue' is partial: applying it to 'Stop' is an error.
data State = State { getStateValue :: Int } | Stop
  deriving (Show, Eq)

-- | Wrapper around 'Int', specifying an action index.
newtype Action = Action { getActionValue :: Int }
  deriving (Show, Eq)

-- | Wrapper around 'Double', specifying a reward value.
newtype Reward = Reward { getRewardValue :: Double }
  deriving (Show, Eq)

-- | The environment in which the Q learner operates.
--
-- * 'envExecute' executes an action at a particular state, returning the new
--   state and the reward associated with that state/action pair.
--   NOTE(review): the learner currently reads the returned state with
--   'getStateValue', so returning 'Stop' from here causes a runtime error.
-- * 'envPossible' returns the actions possible at a given state.
data Environment = Environment
  { envExecute  :: State -> Action -> (State, Reward)
  , envPossible :: State -> [Action]
  }

-- | Given alpha (learning rate), gamma (discount rate), an epsilon schedule
-- (number of steps left -> epsilon), the number of states and the maximum
-- number of actions possible at any state, returns a 'QLearner' initialized
-- with a zero Q-table.
initQLearner :: Double -> Double -> (Int -> Double) -> Int -> Int -> QLearner
initQLearner alpha gamma epsilon numStates numActions =
  QLearner alpha gamma epsilon (createZeroQ numStates numActions)

-- | Given the envExecute and envPossible functions, constructs an
-- 'Environment'. This exists purely for uniformity of the API; it is
-- equivalent to applying the 'Environment' data constructor directly.
initEnvironment
  :: (State -> Action -> (State, Reward))
  -> (State -> [Action])
  -> Environment
initEnvironment = Environment

-- | Adapts a typed execute function to raw 'Int' state/action indices and a
-- raw 'Double' reward.
--
-- The new state is read with 'getStateValue', so this errors (when the state
-- is forced) if the wrapped function returns 'Stop'.
unwrapExecute
  :: (State -> Action -> (State, Reward)) -> Int -> Int -> (Int, Double)
unwrapExecute execute state action =
  -- Lazy pattern binding: nothing is forced until a component is demanded.
  let (newState, reward) = execute (State state) (Action action)
  in (getStateValue newState, getRewardValue reward)

-- | Adapts a typed "possible actions" function to raw 'Int' indices.
unwrapPossible :: (State -> [Action]) -> Int -> [Int]
unwrapPossible possible = map getActionValue . possible . State

-- | Given an Environment, a Q learner and the state the Q learner is on,
-- returns the Q learner with an updated Q table and the new state of the Q
-- learner within the Environment. Also takes the number of time steps left for
-- the epsilon computation.
--
-- Exploration vs. exploitation is decided by 'checkEpsilon' using the
-- learner's epsilon schedule. The Q update uses this learner's own 'qAlpha'
-- and 'qGamma'. A 'Stop' state leaves the learner and generator unchanged.
moveLearner
  :: Int -> StdGen -> Environment -> QLearner -> State
  -> ((QLearner, State), StdGen)
moveLearner _ g _ qlearner Stop = ((qlearner, Stop), g)
moveLearner times g (Environment execute' possible') (QLearner alpha gamma eps qtable) (State s) =
  let execute = unwrapExecute execute'
      possible = unwrapPossible possible'
      (doRandom, g') = checkEpsilon g eps times
      -- Thread g' (not g) into the random choice so that the epsilon draw and
      -- the action draw come from distinct points of the generator stream.
      (action, g'') =
        if doRandom
          then randomAction g' possible s qtable
          else (maxAction possible s qtable, g')
      (s', reward) = execute s action
      -- updatedQ takes gamma before alpha.
      qtable' = updatedQ s action reward s' gamma alpha qtable
  in ((QLearner alpha gamma eps qtable', State s'), g'')

-- | Same as 'moveLearner', but prints the state reached and the updated Q table
-- after moving the learner.
moveLearnerAndPrint
  :: Int -> StdGen -> Environment -> QLearner -> State
  -> IO ((QLearner, State), StdGen)
moveLearnerAndPrint _ g _ qlearner Stop = do
  putStrLn "Stop state."
  return ((qlearner, Stop), g)
moveLearnerAndPrint times g env qlearner state = do
  let result@((qlearner', state'), _) = moveLearner times g env qlearner state
  putStrLn $ "Reached: " ++ show state'
  putStrLn $ prettyPrintQ $ qGrid qlearner'
  return result

-- | Repeatedly moves the learner (at most the given number of times), printing
-- the Q table after every move. It stops early at a 'Stop' state. A
-- non-positive count prints "Done." immediately.
moveLearnerPrintRepeat :: Int -> StdGen -> Environment -> QLearner -> State -> IO ()
moveLearnerPrintRepeat _ _ _ _ Stop = putStrLn "Stopped repeating due to stop state."
moveLearnerPrintRepeat numTimes g env qlearner state
  | numTimes <= 0 = putStrLn "Done."
  | otherwise = do
      ((qlearner', state'), g') <- moveLearnerAndPrint numTimes g env qlearner state
      moveLearnerPrintRepeat (numTimes - 1) g' env qlearner' state'

-- | Formats a Q value with two digits after the decimal point.
showQValue :: Double -> String
showQValue x = showGFloat (Just 2) x ""

-- | Returns the maximum number of characters needed to show an element of the
-- given vector (0 for an empty vector).
maxSpaceRow :: V.Vector Double -> Int
maxSpaceRow = V.foldl' (\widest x -> max widest (length (showQValue x))) 0

-- | Returns the maximum number of characters needed to show an element of the
-- given 2D matrix (0 for an empty matrix).
maxSpaceMat :: V.Vector (V.Vector Double) -> Int
maxSpaceMat = V.foldl' (\widest row -> max widest (maxSpaceRow row)) 0

-- | Right-pads a string with spaces up to the given length. Strings already at
-- least that long are returned unchanged.
padSpaces :: Int -> String -> String
padSpaces space str = str ++ replicate (space - length str) ' '

-- | Renders one row of the Q table. Each element is padded to @space@
-- characters and followed by a single space.
prettyPrintRow :: Int -> V.Vector Double -> String
prettyPrintRow space = concatMap (\x -> padSpaces space (showQValue x) ++ " ") . V.toList

-- | Renders the Q table one row per line, padding each element to @space@
-- characters.
prettyPrintQ' :: Int -> V.Vector (V.Vector Double) -> String
prettyPrintQ' space = concatMap (\row -> prettyPrintRow space row ++ "\n") . V.toList

-- | Pretty-prints the Q table with all columns padded to a common width.
prettyPrintQ :: V.Vector (V.Vector Double) -> String
prettyPrintQ mat = prettyPrintQ' (maxSpaceMat mat) mat

-- | Create a table of Q(s,a) values, each element representing the expected
-- value of a given state and action pair.
-- Takes the number of possible states and the number of actions as arguments.
createZeroQ :: Int -> Int -> V.Vector (V.Vector Double)
createZeroQ s a = V.replicate s (V.replicate a 0.0)

-- | Returns the given Q row with the entry at @index@ replaced by @value@.
updateQRow :: Int -> Double -> V.Vector Double -> V.Vector Double
updateQRow index value qRow = qRow V.// [(index, value)]

-- | Looks up Q(s, a).
indexQ :: Int -> Int -> V.Vector (V.Vector Double) -> Double
indexQ s a q = q V.! s V.! a

-- | Returns the element of @row@ at the first index as a singleton list.
-- Remaining indices are ignored; an empty index list yields @[]@.
-- NOTE(review): unused in this module and looks unfinished (probably meant to
-- return the element at every index) — confirm before relying on it.
multIndex :: V.Vector a -> [Int] -> [a]
multIndex _ [] = []
multIndex row (index:_) = [row V.! index]

-- | Extracts the value from a 'Just', defaulting to 0 for 'Nothing'.
unwrapMaybe :: Num a => Maybe a -> a
unwrapMaybe = maybe 0 id

-- | Figures out the highest-valued Q(s,a) action among the actions possible at
-- state @s@ and returns that action index. On ties, the action listed first by
-- @possible@ wins. Errors if no action is possible at @s@.
maxAction :: (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> Int
maxAction possible s q =
  let candidates = [(action, q V.! s V.! action) | action <- possible s]
      -- Strict '>' keeps the earlier candidate on ties.
      pickBetter best cand = if snd cand > snd best then cand else best
  in fst (foldl1 pickBetter candidates)

-- | Picks an action uniformly at random among those possible at state @s@.
-- Returns it together with the advanced generator. The Q table argument is
-- unused.
randomAction :: StdGen -> (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> (Int, StdGen)
randomAction g possible s _ =
  let actions = possible s
      (choice, g') = randomR (0, length actions - 1) g
  in (actions !! choice, g')

-- | Returns the largest Q(s,a) value in the row for state @s@.
-- NOTE(review): this maximizes over every column of the row, including
-- actions that are not possible at @s@ (those entries stay at their initial 0).
-- With negative rewards this can overestimate the target — confirm intent.
maxActionValue :: Int -> V.Vector (V.Vector Double) -> Double
maxActionValue s q = V.maximum (q V.! s)

-- | Applies the Q-learning update to Q(s, a):
--
-- > Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
--
-- Here @s@ is the state at which the action was executed, @a@ the action, @r@
-- the reward obtained, @s'@ the new state, @gamma@ the discount factor and
-- @alpha@ the learning rate. Note the argument order: gamma before alpha.
updatedQ :: Int -> Int -> Double -> Int -> Double -> Double -> V.Vector (V.Vector Double) -> V.Vector (V.Vector Double)
updatedQ s a r s' gamma alpha q = q V.// [(s, updateQRow a updatedValue (q V.! s))]
  where
    oldValue = indexQ s a q
    updatedValue = oldValue + alpha * (r + gamma * maxActionValue s' q - oldValue)

-- | Creates an @s@ by @a@ table of zero rewards (identical to 'createZeroQ').
createRewardTable :: Int -> Int -> V.Vector (V.Vector Double)
createRewardTable = createZeroQ

-- | Create an @s@ x @s@ grid of zero rewards. Used for grid searches.
createGrid :: Int -> V.Vector (V.Vector Double)
createGrid s = createRewardTable s s

-- | Takes a Q table and the current state, and returns the new Q table along
-- with the new state index, after taking the greedy ('maxAction') action.
--
-- @execute@ takes a state/action pair and returns the new state and the reward
-- for that pair. @possible@ gives the list of actions possible at a particular
-- state. For example, we can't go off the grid at the edge of a grid, so such
-- an action would not be in the list.
--
-- Uses a fixed gamma of 0.8 and alpha of 0.4; see 'qLearnIterWith' to choose
-- them.
qLearnIter :: (Int -> Int -> (Int, Double)) -> (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> (V.Vector (V.Vector Double), Int)
qLearnIter = qLearnIterWith 0.8 0.4

-- | 'qLearnIter' with an explicit discount factor (gamma) and learning rate
-- (alpha), in that order.
qLearnIterWith :: Double -> Double -> (Int -> Int -> (Int, Double)) -> (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> (V.Vector (V.Vector Double), Int)
qLearnIterWith gamma alpha execute possible state q =
  let action = maxAction possible state q
      (state', reward) = execute state action
  in (updatedQ state action reward state' gamma alpha q, state')

-- | Like 'qLearnIter', but takes a uniformly random possible action instead of
-- the greedy one, and also returns the advanced generator. Uses a fixed gamma
-- of 0.8 and alpha of 0.4; see 'qRandomIterWith' to choose them.
qRandomIter :: StdGen -> (Int -> Int -> (Int, Double)) -> (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> ((V.Vector (V.Vector Double), Int), StdGen)
qRandomIter = qRandomIterWith 0.8 0.4

-- | 'qRandomIter' with an explicit discount factor (gamma) and learning rate
-- (alpha), in that order.
qRandomIterWith :: Double -> Double -> StdGen -> (Int -> Int -> (Int, Double)) -> (Int -> [Int]) -> Int -> V.Vector (V.Vector Double) -> ((V.Vector (V.Vector Double), Int), StdGen)
qRandomIterWith gamma alpha g execute possible state q =
  let (action, g') = randomAction g possible state q
      (state', reward) = execute state action
  in ((updatedQ state action reward state' gamma alpha q, state'), g')

-- | Converts a row-major linear index into a (row, column) pair for a matrix
-- with the given number of rows and columns. The row count is not used.
linearTo2D :: Int -> Int -> Int -> (Int, Int)
linearTo2D _ cols linIndex = linIndex `divMod` cols

-- | Takes a 2D (row, column) coordinate and turns it into a linear coordinate.
-- The layout is row-major: index = row * cols + col. The row count is unused.
twoDToLinear :: Int -> Int -> (Int, Int) -> Int
twoDToLinear _ cols (row, col) = col + cols * row

-- | Computes the next state (a linear index) from the grid dimensions, the
-- current state (also a linear index) and an action code:
--
-- * 0: move up
-- * 1: move down
-- * 2: move left
-- * 3: move right
--
-- No bounds checking is performed. Any other action code yields state -1.
applyGridAction :: Int -> Int -> Int -> Int -> Int
applyGridAction rows cols state action =
  case offsetFor action of
    Nothing -> -1
    Just (dRow, dCol) ->
      let (row, col) = linearTo2D rows cols state
      in twoDToLinear rows cols (row + dRow, col + dCol)
  where
    -- (row delta, column delta) for each known action code.
    offsetFor 0 = Just (-1, 0)
    offsetFor 1 = Just (1, 0)
    offsetFor 2 = Just (0, -1)
    offsetFor 3 = Just (0, 1)
    offsetFor _ = Nothing

-- | Takes a grid describing reward values (often from environments that look
-- like grids), a state and an action, and returns the new state and its
-- reward. It wraps 'executeOnGrid' in the typed 'State'/'Action'/'Reward'
-- API.
executeGrid :: V.Vector (V.Vector Double) -> State -> Action -> (State, Reward)
executeGrid grid (State state) (Action action) =
  let (nextState, gained) = executeOnGrid grid state action
  in (State nextState, Reward gained)

-- | Takes a grid of reward values (each cell of the grid is a state, and each
-- state has a reward associated with it) and serves as an "execute" function
-- for 'qLearnIter'.
-- The reward returned is the one stored at the current (pre-move) cell.
executeOnGrid :: V.Vector (V.Vector Double) -> Int -> Int -> (Int, Double)
executeOnGrid grid state action =
  let rows = V.length grid
      cols = V.length (grid V.! 0)
      (row, col) = linearTo2D rows cols state
      reward = grid V.! row V.! col
      state' = applyGridAction rows cols state action
  in (state', reward)

-- | Create a @V.Vector (V.Vector Double)@ from a @[[Double]]@. Used to create
-- grid-based environments for the agent. An empty list yields an empty grid.
gridFromList :: [[Double]] -> V.Vector (V.Vector Double)
gridFromList = V.fromList . map V.fromList

-- | A grid of example reward values. It looks like this:
--
-- > [[1.0,2.0,3.0,4.0],
-- >  [5.0,6.0,7.0,8.0],
-- >  [12.0,11.0,10.0,9.0],
-- >  [13.0,14.0,15.0,16.0]]
testGrid :: V.Vector (V.Vector Double)
testGrid = gridFromList [ [1.0, 2.0, 3.0, 4.0]
                        , [5.0, 6.0, 7.0, 8.0]
                        , [12.0, 11.0, 10.0, 9.0]
                        , [13.0, 14.0, 15.0, 16.0]
                        ]

-- | Horizontal moves available at row @i@, column @j@: left (2) unless in the
-- first column, then right (3) unless in the last column. The column index is
-- compared with the column count, so non-square grids work.
gridPossibleX :: Int -> Int -> Int -> Int -> [Int]
gridPossibleX _ j _ cols = [2 | j > 0] ++ [3 | j < cols - 1]

-- | Vertical moves available at row @i@, column @j@: up (0) unless in the first
-- row, then down (1) unless in the last row. The row index is compared with the
-- row count, so non-square grids work.
gridPossibleY :: Int -> Int -> Int -> Int -> [Int]
gridPossibleY i _ rows _ = [0 | i > 0] ++ [1 | i < rows - 1]

-- | An "envPossible" function for the 'Environment' data type, for
-- environments that look like grids. No action is possible at 'Stop'.
possibleGrid :: V.Vector (V.Vector Double) -> State -> [Action]
possibleGrid _ Stop = []
possibleGrid grid (State state) = map Action (gridPossible grid state)

-- | The in-grid actions (see 'applyGridAction' codes) at a linear state index.
gridPossible :: V.Vector (V.Vector Double) -> Int -> [Int]
gridPossible grid state =
  let rows = V.length grid
      cols = V.length (grid V.! 0)
      (i, j) = linearTo2D rows cols state
  in gridPossibleX i j rows cols ++ gridPossibleY i j rows cols

-- | Debug helper: runs @times@ greedy Q-learning steps on a grid, printing the
-- state and Q table at each step, then prints "Done!".
qPrint :: V.Vector (V.Vector Double) -> Int -> Int -> V.Vector (V.Vector Double) -> IO ()
qPrint grid times s q
  | times <= 0 = putStrLn "Done!"
  | otherwise = do
      putStrLn $ "Original state: " ++ show s
      let (qgrid, state) = qLearnIter (executeOnGrid grid) (gridPossible grid) s q
      putStrLn $ prettyPrintQ qgrid
      putStrLn $ "State: " ++ show state
      qPrint grid (times - 1) state qgrid

-- | Draws a uniform value in [0, 1] and returns whether it is below
-- @epsilonAt times@ (i.e. whether to take a random action), together with the
-- advanced generator.
checkEpsilon :: StdGen -> (Int -> Double) -> Int -> (Bool, StdGen)
checkEpsilon g epsilonAt times =
  let (randVal, g') = randomR (0, 1 :: Double) g
  in (randVal < epsilonAt times, g')

-- | Selects the first component when the flag is 'True', else the second.
pick :: (a, a) -> Bool -> a
pick (x, y) v = if v then x else y

-- | Debug helper: runs @times@ epsilon-greedy Q-learning steps on a grid,
-- printing which kind of step was taken, the Q table and the new state, then
-- prints "Done!".
qEpsilonPrint :: StdGen -> (Int -> Double) -> V.Vector (V.Vector Double) -> Int -> Int -> V.Vector (V.Vector Double) -> IO ()
qEpsilonPrint g epsilonAt grid times s q
  | times <= 0 = putStrLn "Done!"
  | otherwise = do
      let execute = executeOnGrid grid
          possible = gridPossible grid
          (doRandom, g') = checkEpsilon g epsilonAt times
      ((qgrid, state), gNext) <-
        if doRandom
          then do
            putStrLn "Doing a random action!"
            return (qRandomIter g' execute possible s q)
          else do
            putStrLn "Doing a normal action"
            putStrLn $ "Original state: " ++ show s
            return (qLearnIter execute possible s q, g')
      putStrLn $ prettyPrintQ qgrid
      putStrLn $ "State: " ++ show state
      qEpsilonPrint gNext epsilonAt grid (times - 1) state qgrid

-- | Placeholder epsilon schedule: always 1, so 'checkEpsilon' (almost) always
-- picks a random action.
-- NOTE(review): the commented-out decaying schedule below evaluates to
-- 1/0 = Infinity on the first step (when timesLeft == totalTimes).
epsilon :: Int -> Int -> Double
-- epsilon totalTimes timesLeft = 1.0/(fromIntegral $ (totalTimes - timesLeft))
epsilon _totalTimes _timesLeft = 1