{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Evaluation metrics for categorical predictions.
module Torch.Data.Metrics where

import Data.Function (on)
import Data.List (genericLength)

#ifdef CUDA
import Torch.Cuda.Double
import qualified Torch.Cuda.Long as Long
#else
import Torch.Double
import qualified Torch.Long as Long
#endif

-- | Fraction of pairs whose predicted class index matches the true label.
--
-- Each element pairs a predicted class index with the true label. The index
-- is converted to the label type with 'toEnum' and compared using '=='; the
-- result is the number of matching pairs divided by the total number of pairs.
--
-- An empty input yields @0@ rather than the @NaN@ that @0 / 0@ would produce.
--
-- NOTE(review): 'toEnum' may raise an error for an index outside the label
-- type's range, depending on the 'Enum' instance -- confirm that callers only
-- pass valid class indices.
catAccuracy
  :: forall c sz . (Eq c, Enum c) -- , sz ~ FromEnum (MaxBound c), KnownDim sz, KnownNat sz)
  => [(Int, c)] -- [(Tensor '[FromEnum (MaxBound c)], c)]
  -> Double
catAccuracy [] = 0
catAccuracy xs = ratioOfLengths (filter isCorrect xs) xs
  where
    -- Divide the length of the first list by the length of the second.
    ratioOfLengths = (/) `on` genericLength

    -- A prediction is correct when its index denotes the true label.
    isCorrect (p, y) = toEnum p == y