{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-------------------------------------------------------------------------------
-- |
-- Module    : Reinforce.Spaces.Action
-- Copyright : (c) Sentenai 2017
-- License   : BSD3
-- Maintainer: sam@sentenai.com
-- Stability : experimental
-- Portability: non-portable
--
-- Typeclass for a discrete action space, as well as helper functions for
-- enumerating, one-hot encoding, and uniformly sampling its actions.
-------------------------------------------------------------------------------
module Reinforce.Spaces.Action
  ( DiscreteActionSpace(..)
  , oneHot
  , oneHot'
  , allActions
  , randomChoice
  ) where

import Reinforce.Prelude
import Control.MonadMWCRandom
import Numeric.LinearAlgebra.Static (R)
import qualified Numeric.LinearAlgebra.Static as LA
import qualified Data.Vector as V


-- | Mostly tags around an 'Enum', but includes information about the size of
-- an action space and is used in helper functions.
class (Bounded a, Enum a) => DiscreteActionSpace a where
  -- | Type-level size of the action space. 'oneHot' hands 'LA.vector' a list
  -- with one entry per value from 'minBound' to 'maxBound' inclusive, so this
  -- should equal that count.
  type Size a :: Nat

  -- | Convert an 'Int' into an action (defaults to 'toEnum').
  toAction :: Int -> a
  toAction = toEnum

  -- | Convert an action into an 'Int' (defaults to 'fromEnum').
  fromAction :: a -> Int
  fromAction = fromEnum


-- | One-hot encode a bounded enumerable as a statically-sized vector.
--
-- The hot position is counted from 'minBound', so this works whether
-- @fromEnum minBound@ is zero, positive, or negative. The list given to
-- 'LA.vector' has @fromEnum maxBound - fromEnum minBound + 1@ entries, which
-- is expected to match @'Size' a@.
oneHot :: forall a . (KnownNat (Size a), DiscreteActionSpace a) => a -> R (Size a)
oneHot = LA.vector . V.toList . oneHotVector


-- | One-hot encode a bounded enumerable as a 'Vector' with one entry per
-- value from 'minBound' to 'maxBound' inclusive: 1 at the given value's
-- position (counted from 'minBound'), 0 everywhere else.
oneHot' :: forall a . (DiscreteActionSpace a) => a -> Vector Double
oneHot' = oneHotVector


-- | Number of values from 'minBound' to 'maxBound' inclusive.
enumSize :: forall a . (Enum a, Bounded a) => Proxy a -> Int
enumSize _ = fromEnum (maxBound :: a) - fromEnum (minBound :: a) + 1


-- | Zero-based position of a value, counted from 'minBound'.
enumIndex :: forall a . (Enum a, Bounded a) => a -> Int
enumIndex e = fromEnum e - fromEnum (minBound :: a)


-- | Build a one-hot 'Vector': 1 at the value's 'enumIndex', 0 elsewhere.
-- 'V.generate' is used instead of an unchecked update so that no write can
-- land outside the vector.
oneHotVector :: forall a . (Enum a, Bounded a) => a -> Vector Double
oneHotVector e = V.generate (enumSize (Proxy :: Proxy a)) cell
  where
    hot = enumIndex e
    cell i = if i == hot then 1 else 0


-- | Every action in a discrete action space, from 'minBound' to 'maxBound'.
allActions :: DiscreteActionSpace a => [a]
allActions = [minBound..maxBound]


-- | Make a uniform-random selection of an action in a discrete action space:
-- every action from 'minBound' to 'maxBound' gets the same weight.
randomChoice :: forall m a . (MonadIO m , MonadMWCRandom m, DiscreteActionSpace a) => m a
randomChoice = fromIndex . fst <$> sampleFrom uniformDist
  where
    n :: Int
    n = enumSize (Proxy :: Proxy a)

    -- One weight of 1/n per action (weights no longer depend on the
    -- actions' 'fromEnum' values).
    uniformDist :: [Double]
    uniformDist = replicate n (1 / fromIntegral n)

    -- NOTE(review): assumes 'sampleFrom' returns the zero-based position of
    -- the sampled weight as its first component; confirm in
    -- "Control.MonadMWCRandom".
    fromIndex :: Int -> a
    fromIndex i = toEnum (i + fromEnum (minBound :: a))