{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
module RL.TDl where

import qualified Data.List as List
import qualified Data.HashMap.Strict as HashMap
import qualified Data.HashSet as HashSet

import RL.Imports
import RL.Types
import RL.Utils (eps_greedy_action)

data TDl_Opts = TDl_Opts {
    o_alpha :: TD_Number   -- ^ Learning rate
  , o_gamma :: TD_Number   -- ^ Discount factor
  , o_eps :: TD_Number     -- ^ Probability of taking an exploratory action
  , o_lambda :: TD_Number  -- ^ Eligibility-trace decay rate
  } deriving (Show)

type TD_Number = Double
-- | Action-value table Q(s,a)
type Q s a = M s a TD_Number
-- | Eligibility traces e(s,a)
type Z s a = M s a TD_Number
type V s a = HashMap s (a, TD_Number)

-- | Build an initial Q-table from a default entry value (via 'initM')
emptyQ :: TD_Number -> Q s a
emptyQ = initM

-- | Project a Q-table to a state-value map, keeping the maximal action value
-- in each state
toV :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> HashMap s TD_Number
toV = foldMap_s (\(s,l) -> HashMap.singleton s (snd $ layer_s_max l))

data TDl_State s a = TDl_State {
    _tdl_q :: Q s a  -- ^ Current action-value estimates
  , _tdl_z :: Z s a  -- ^ Current eligibility traces
  }

$(makeLenses ''TDl_State)

-- | Fresh learner state: the given Q-table and all-zero eligibility traces
initialState :: Q s a -> TDl_State s a
initialState q0 = TDl_State q0 (initM 0)

-- | Interface a problem has to provide in order to be solved by the
-- TD(lambda)-family algorithms in this module.
class (Eq s, Hashable s, Show s, Eq a, Hashable a, Enum a, Bounded a, Show a) =>
    TDl_Problem pr m s a | pr -> m, pr -> s, pr -> a where
  -- | Is the state terminal?
  td_is_terminal :: pr -> s -> Bool
  -- | Action selection hook handed to 'eps_greedy_action'
  td_greedy :: pr -> Bool -> a -> a
  -- | Perform action @a@ in state @s@ (the learner state is passed along),
  -- yielding the successor state
  td_transition :: pr -> s -> a -> TDl_State s a -> m s
  -- | Reward for taking @a@ in @s@ and arriving at @s'@
  td_reward :: pr -> s -> a -> s -> TD_Number
  -- | Problem-specific hook invoked with the learner state after every trace
  -- sweep (see 'listZ')
  td_modify :: pr -> s -> a -> TDl_State s a -> m ()

-- Helpers running in @StateT (TDl_State s a) m@.

-- | All @(action, value)@ pairs known for state @s@
queryQ s = HashMap.toList . get_s s <$> use tdl_q
-- | Modify the Q-value of @(s,a)@ with @f@
modifyQ pr s a f = tdl_q %= modify_s_a s a f
-- | Run @f@ on every traced @(state, action, trace)@ triple, then pass the
-- learner state to 'td_modify'
listZ pr s a f = (list <$> use tdl_z) >>= mapM_ f >> get >>= lift . td_modify pr s a
-- | Modify the eligibility trace of @(s,a)@ with @f@
modifyZ pr s a f = tdl_z %= modify_s_a s a f
-- | Epsilon-greedy action selection for state @s@, returning the chosen
-- action together with its Q-value
action pr s eps = queryQ s >>= eps_greedy_action eps (td_greedy pr)
-- | Run the problem's transition from @s@ under action @a@
transition pr s a = get >>= lift . td_transition pr s a
-- | Current Q-value of @(s,a)@
getQ s a = get_s_a s a <$> use tdl_q

-- | TD(lambda) learning, aka Sarsa(lambda), pg 171
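--
-- One step of the update, given the current pair @(s,a)@ and the
-- epsilon-greedy successor pair @(s',a')@:
--
-- > delta   = r + gamma * Q(s',a') - Q(s,a)
-- > e(s,a)  = e(s,a) + 1
-- > Q(x,b) += alpha * delta * e(x,b)    -- for every traced pair (x,b)
-- > e(x,b)  = gamma * lambda * e(x,b)   -- for every traced pair (x,b)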
tdl_learn :: (MonadRnd g m, TDl_Problem pr m s a)
  => TDl_Opts -> Q s a -> s -> pr -> m (s, Q s a)
tdl_learn TDl_Opts{..} q0 s0 pr =
  (view _1 *** view tdl_q) <$> do
  flip runStateT (initialState q0) $ do
    (a0,_) <- action pr s0 o_eps
    loopM (s0,a0) (not . td_is_terminal pr . view _1) $ \(s,a) -> do
      q <- getQ s a                        -- current Q(s,a)
      s' <- transition pr s a              -- take action a, observe s'
      let r = td_reward pr s a s'
      (a',q') <- action pr s' o_eps        -- epsilon-greedy a' and Q(s',a')
      let delta = r + o_gamma * q' - q     -- TD error
      modifyZ pr s a (+1)                  -- accumulating trace for (s,a)
      listZ pr s a $ \(s,a,z) -> do        -- for every traced pair:
        modifyQ pr s a (\q -> q + o_alpha * delta * z)
        modifyZ pr s a (\z -> o_gamma * o_lambda * z)
      return (s',a')


-- | Watkins's Q(lambda) learning algorithm, pg 174
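--
-- Off-policy variant of Sarsa(lambda): the TD error bootstraps from the
-- greedy value of the successor state, and the traces are cut whenever the
-- behaviour policy takes an exploratory (non-greedy) action:
--
-- > delta   = r + gamma * max_b Q(s',b) - Q(s,a)
-- > e(s,a)  = e(s,a) + 1
-- > Q(x,b) += alpha * delta * e(x,b)    -- for every traced pair (x,b)
-- > e(x,b)  = gamma * lambda * e(x,b)   -- if a' was greedy, else 0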
qlw_learn :: (MonadRnd g m, TDl_Problem pr m s a)
  => TDl_Opts -> Q s a -> s -> pr -> m (s, Q s a)
qlw_learn TDl_Opts{..} q0 s0 pr =
  (view _1 *** view tdl_q) <$> do
  flip runStateT (initialState q0) $ do
    (a0,_) <- action pr s0 o_eps
    loopM (s0,a0) (not . td_is_terminal pr . view _1) $ \(s,a) -> do
      q <- getQ s a                        -- current Q(s,a)
      s' <- transition pr s a              -- take action a, observe s'
      let r = td_reward pr s a s'
      (a',q') <- action pr s' o_eps        -- epsilon-greedy a' and Q(s',a')
      (_,q'') <- maximumBy (compare`on`snd) <$> queryQ s'  -- greedy value max_b Q(s',b)
      let delta = r + o_gamma * q'' - q    -- TD error w.r.t. the greedy value
      modifyZ pr s a (+1)                  -- accumulating trace for (s,a)
      listZ pr s a $ \(s,a,z) -> do        -- for every traced pair:
        modifyQ pr s a (\q -> q + o_alpha * delta * z)
        modifyZ pr s a (\z -> if q' >= q'' -- a' attains the greedy value
                                then o_gamma*o_lambda*z
                                else 0)    -- exploratory step: cut the trace
      return (s',a')
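
{- Usage sketch: @GridWorld@, @Pos@ and @Move@ are hypothetical stand-ins for
   a real 'TDl_Problem' instance defined elsewhere. Given such an instance,
   one learning episode starting from @s0@ could look like:

   > episode :: (MonadRnd g m, TDl_Problem GridWorld m Pos Move)
   >         => GridWorld -> Pos -> m (Pos, Q Pos Move)
   > episode pr s0 =
   >   tdl_learn TDl_Opts{ o_alpha = 0.1, o_gamma = 1.0
   >                     , o_eps = 0.1, o_lambda = 0.9 }
   >             (emptyQ 0) s0 pr
-}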