module NLP.Hext.NaiveBayes (FrequencyList,
Labeled(..),
Classified(..),
BayesModel(..),
emptyModel,
teach,
runBayes,
) where
import qualified Data.HashMap.Lazy as H
import qualified Data.Set as S
import Data.Maybe
import Data.Char
import Data.Function
import Data.List
import qualified Data.Text.Lazy as T
-- | Bag-of-words representation of a text: each token is mapped to the
-- number of times it occurs.
type FrequencyList
  = H.HashMap T.Text Int
-- | A training document: its token counts paired with the label it was
-- taught with.
data Labeled a = Labeled
  { hash  :: FrequencyList -- ^ token frequencies of the document
  , label :: a             -- ^ class the document belongs to
  }
-- | A candidate class together with the numeric score it was given.
data Classified a = Classified
  { _class      :: a      -- ^ the candidate class
  , probability :: Double -- ^ score attached to that class
  }
  deriving (Eq)
-- | Training state of the naive Bayes classifier.
data BayesModel a = BayesModel
  { classes  :: S.Set a       -- ^ every label added by 'teach'
  , vocab    :: FrequencyList -- ^ tokens seen in the training documents
  , material :: [Labeled a]   -- ^ the training documents themselves
  }
-- | Renders the label set followed by the vocabulary, separated by a
-- single space. The training documents are not included.
instance (Show a) => Show (BayesModel a) where
  show (BayesModel labels vocabulary _) =
    concat [show labels, " ", show vocabulary]
-- | Orders results by their score alone; the class is ignored.
--
-- NOTE(review): this disagrees with the derived 'Eq', which also compares
-- '_class'. Two results with equal scores compare 'EQ' yet are not '=='.
-- A 'S.Set' therefore treats them as duplicates and keeps only one.
-- Tie-breaking on '_class' would need an @Ord a@ context.
instance (Eq a) => Ord (Classified a) where
  compare x y = probability x `compare` probability y
-- | Rendered as a @(class, score)@ tuple.
instance (Show a) => Show (Classified a) where
  show (Classified cls score) = show (cls, score)
-- | A model with no labels, no vocabulary and no training documents.
emptyModel :: BayesModel a
emptyModel =
  BayesModel { classes = S.empty, vocab = H.empty, material = [] }
-- | Add one labelled training document to a model.
--
-- The text is tokenised with 'vectorize'. The label is inserted into the
-- set of known classes, and the document is prepended to the training
-- material.
--
-- Token counts are summed into the vocabulary, so 'vocab' holds each
-- token's total frequency over all training documents. A left-biased
-- union would keep only the most recent document's count for a repeated
-- token. The vocabulary's key set (and hence its size) is the same either
-- way.
teach :: (Ord a) => T.Text -> a -> BayesModel a -> BayesModel a
teach source c model =
  BayesModel
    { classes  = S.insert c (classes model)
    , vocab    = H.unionWith (+) fl (vocab model)
    , material = Labeled fl c : material model
    }
  where
    fl = vectorize source
-- | Classify a sample string, returning the highest-scoring label.
--
-- The sample is tokenised the same way 'vectorize' tokenises training text:
-- punctuation is stripped, then the text is split on whitespace. Without
-- this, a token such as @"great!"@ would never match the trained token
-- @"great"@.
--
-- Partial: when the model has no classes (e.g. 'emptyModel'), the result
-- set is empty and 'argmax' fails.
runBayes :: (Ord a, Eq a) => BayesModel a -> String -> a
runBayes model sample =
  argmax . classify model . T.words . removePunctuation $ T.pack sample
-- | Score every known class for a tokenised sample.
--
-- Multinomial naive Bayes with add-one (Laplace) smoothing:
--
-- > score(c) = log P(c) + sum [ log ((count(w, c) + 1) / (N_c + |V|)) | w <- ws ]
--
-- where @N_c@ is the total token count of class @c@ and @|V|@ is the
-- vocabulary size.
--
-- Scores are accumulated in log space. Multiplying many probabilities
-- below 1 underflows to 0 for longer samples, which makes every class tie.
-- Because log is monotonic, the best class is unchanged. The 'probability'
-- field of each result therefore holds a log-probability.
classify :: (Ord a, Eq a) => BayesModel a -> [T.Text] -> S.Set (Classified a)
classify model = f
  where
    cs = classes model
    lengthVocab = H.size $ vocab model
    mat = material model
    logProb c ws =
      let caseC = unions . vecs $ filter ((== c) . label) mat
          n = totalWords caseC
          denom = n + lengthVocab
          addWord acc word = acc + log (pWordGivenClass word denom caseC)
      in foldl' addWord (log (pClass c mat)) ws
    f wrds = S.map (\c -> Classified c $ logProb c wrds) cs
-- | Prior of a class: the fraction of training documents carrying that
-- label. An empty document list yields 0 rather than dividing by zero.
pClass :: (Eq a) => a -> [Labeled a] -> Double
pClass _ [] = 0
pClass cl docs = matching / total
  where
    matching = fromIntegral . length $ filter ((== cl) . label) docs
    total = fromIntegral (length docs)
-- | Smoothed likelihood of a token given a class. It is the token's count
-- in the class's combined frequencies, plus one, divided by the supplied
-- denominator.
pWordGivenClass :: T.Text -> Int -> FrequencyList -> Double
pWordGivenClass w denom currentCase =
  let smoothedCount = totalOfWord w currentCase + 1
  in fromIntegral smoothedCount / fromIntegral denom
-- | The class of the greatest element under the 'Ord' instance of
-- 'Classified'.
--
-- Partial: 'S.findMax' raises an error on an empty set.
argmax :: (Eq a) => S.Set (Classified a) -> a
argmax results = _class (S.findMax results)
-- | Delete every character that 'isPunctuation' accepts. Nothing is put
-- in its place, so @"well-known"@ becomes @"wellknown"@.
removePunctuation :: T.Text -> T.Text
removePunctuation = T.filter keep
  where
    keep ch = not (isPunctuation ch)
-- | Turn raw text into token counts. Punctuation is stripped, the text is
-- split on whitespace, and the occurrences of each token are counted.
vectorize :: T.Text -> FrequencyList
vectorize source = H.fromListWith (+) [ (token, 1) | token <- tokens ]
  where
    tokens = T.words (removePunctuation source)
-- | The frequency list of each labelled document, in input order.
vecs :: [Labeled a] -> [FrequencyList]
vecs docs = [ hash doc | doc <- docs ]
-- | Merge frequency lists, adding the counts of tokens that occur in more
-- than one of them.
unions :: [FrequencyList] -> FrequencyList
unions = foldl' (flip (H.unionWith (+))) H.empty
-- | Sum of all counts in a frequency list, i.e. its total number of tokens.
totalWords :: FrequencyList -> Int
totalWords freqs = foldl' (+) 0 (H.elems freqs)
-- | How often a token occurs in a frequency list; 0 when it is absent.
totalOfWord :: T.Text -> FrequencyList -> Int
totalOfWord word doc = fromMaybe 0 (H.lookup word doc)