module Grenade.Utils.OneHot (
oneHot
, hotMap
, makeHot
, unHot
, sample
) where
import qualified Control.Monad.Random as MR
import Data.List ( group, sort )
import Data.Map ( Map )
import qualified Data.Map as M
import Data.Proxy
import Data.Singletons.TypeLits
import Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import Numeric.LinearAlgebra ( maxIndex )
import Numeric.LinearAlgebra.Devel
import Numeric.LinearAlgebra.Static
import Grenade.Core.Shape
-- | Build a one-hot vector of length @n@: a 1 at position @hot@ and 0
--   everywhere else.
--
--   Returns 'Nothing' when @hot@ is not a valid position, i.e. when it is
--   negative or not smaller than @n@.
oneHot :: forall n. (KnownNat n)
       => Int -> Maybe (S ('D1 n))
oneHot hot
  -- Both bounds are checked up front: 'writeVector' is bounds-checked and
  -- fails with an error on an out-of-range position, which would escape the
  -- 'Maybe' this function promises.
  | hot >= 0 && hot < len =
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
  | otherwise = Nothing
  where
    -- The type-level length @n@ reified as a value.
    len = fromIntegral $ natVal (Proxy :: Proxy n)
-- | Build the lookup tables for one-hot encoding the values of a list.
--
--   The distinct values are taken in ascending order and numbered from 0.
--   The result pairs a map from value to index with a vector holding the
--   values at those same indices.
--
--   Returns 'Nothing' when there are more distinct values than the
--   type-level size @n@ allows.
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
hotMap n as
  | length distinct <= capacity = Just (indexOf, V.fromList distinct)
  | otherwise                   = Nothing
  where
    -- The type-level size reified as a value.
    capacity = fromIntegral (natVal n)
    -- Sorting first makes equal values adjacent, so keeping one
    -- representative per group yields the distinct values in order.
    distinct = concatMap (take 1) (group (sort as))
    indexOf  = M.fromList (zip distinct [0 ..])
-- | Look up @x@ in the index map and one-hot encode the index stored for it
--   as a vector of length @n@.
--
--   Returns 'Nothing' when @x@ is not a key of the map, or when the stored
--   index is negative or not smaller than @n@.
makeHot :: forall a n. (Ord a, KnownNat n)
        => Map a Int -> a -> Maybe (S ('D1 n))
makeHot m x = do
  hot <- M.lookup x m
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  -- The map is supplied by the caller and may hold any 'Int', so both bounds
  -- are checked here: 'writeVector' is bounds-checked and fails with an error
  -- on an out-of-range position instead of yielding 'Nothing'.
  if hot >= 0 && hot < len
    then
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
    else Nothing
-- | Decode a vector back into a value: find the position of its largest
--   component and return the element of @v@ stored at that position.
--
--   Returns 'Nothing' when that position lies outside @v@.
unHot :: forall a n. KnownNat n
      => Vector a -> S ('D1 n) -> Maybe a
unHot v (S1D xs) = v V.!? hottest
  where
    -- Index of the maximum component of the underlying vector.
    hottest = maxIndex (extract xs)
-- | Randomly pick an element of @v@, using the components of the given
--   vector as (temperature-adjusted) weights.
--
--   The component at index @i@ contributes the weight
--   @exp (log p / temperature)@ for index @i@; an index is then drawn with
--   'MR.fromList' and the element of @v@ at that index is returned.
--
--   The final lookup uses the partial 'V.!', so @v@ needs an element for
--   every index that can be drawn.
sample :: forall a n m. (KnownNat n, MR.MonadRandom m)
       => Double -> Vector a -> S ('D1 n) -> m a
sample temperature v (S1D xs) = do
  let -- Temperature-adjusted weight of a single component.
      rescale p = toRational (exp (log p / temperature))
      -- Candidate indices paired with their weights.
      weighted  = zip [0 ..] (map rescale (VS.toList (extract xs)))
  chosen <- MR.fromList weighted
  return (v V.! chosen)