{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
-- | Helpers for converting between plain values and one-hot encoded
--   1D shapes (@S ('D1 n)@), and for decoding / sampling back from them.
module Grenade.Utils.OneHot (
    oneHot
  , hotMap
  , makeHot
  , unHot
  , sample
  ) where

import qualified Control.Monad.Random as MR

import           Data.List ( group, sort )

import           Data.Map ( Map )
import qualified Data.Map as M

import           Data.Proxy
import           Data.Singletons.TypeLits

import           Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS

import           Numeric.LinearAlgebra ( maxIndex )
import           Numeric.LinearAlgebra.Devel
import           Numeric.LinearAlgebra.Static

import           Grenade.Core.Shape

-- | From an int which is hot, create a 1D Shape
--   with one index hot (1) with the rest 0.
--   Returns Nothing if the hot index is negative or
--   not smaller than the length of the vector.
oneHot :: forall n. (KnownNat n)
       => Int -> Maybe (S ('D1 n))
oneHot hot =
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  -- Both bounds are checked here: writeVector range-checks its index and
  -- raises an error (rather than returning Nothing) when it is out of range.
  in if 0 <= hot && hot < len
       then fmap S1D . create $ runSTVector $ do
         vec <- newVector 0 len
         writeVector vec hot 1
         return vec
       else Nothing

-- | Create a one hot map from any enumerable.
--   Returns a map from each distinct value to its index, and the
--   ordered vector of distinct values for the reverse transformation.
--   Returns Nothing if there are more distinct values than @n@.
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
hotMap n as =
  let len  = fromIntegral $ natVal n
      -- distinct values in ascending order
      uniq = [ c | (c:_) <- group $ sort as]
      hotl = length uniq
  in if hotl <= len
       then Just (M.fromList $ zip uniq [0..], V.fromList uniq)
       else Nothing

-- | From a map and value, create a 1D Shape
--   with one index hot (1) with the rest 0.
--   Returns Nothing if the map doesn't contain the value, or if the
--   mapped index is negative or not smaller than the length of the vector.
makeHot :: forall a n. (Ord a, KnownNat n)
        => Map a Int -> a -> Maybe (S ('D1 n))
makeHot m x = M.lookup x m >>= oneHot

-- | Decode a 1D shape back to a value: take the index of its largest
--   component and look it up in @v@ (e.g. the vector produced by 'hotMap').
--   Returns Nothing if that index is out of range for @v@.
unHot :: forall a n. KnownNat n
      => Vector a -> S ('D1 n) -> Maybe a
unHot v (S1D xs) = v V.!? maxIndex (extract xs)

-- | Randomly pick an element of @v@, weighting index @i@ by the
--   @i@-th component @p@ of the shape raised to @1 / temperature@
--   (computed as @exp (log p / temperature)@).
--
--   NOTE(review): this is partial. It calls 'V.!', so it fails if the drawn
--   index is out of range for @v@. It presumably also fails inside
--   'MR.fromList' when every weight is zero (confirm for the MonadRandom
--   version in use). It assumes non-negative components, e.g. a softmax
--   output; @log@ of a negative component is NaN.
sample :: forall a n m. (KnownNat n, MR.MonadRandom m)
       => Double -> Vector a -> S ('D1 n) -> m a
sample temperature v (S1D xs) = do
  ix <- MR.fromList . zip [0..] . fmap reweight . VS.toList . extract $ xs
  return $ v V.! ix
  where
    -- Done in log space: for positive temperature, log 0 = -Infinity maps
    -- back to a weight of exactly 0, so zero entries are never drawn.
    reweight :: Double -> Rational
    reweight = toRational . exp . (/ temperature) . log