module Numeric.Sibe.Word2Vec
( word2vec
, Word2Vec (..)
, W2VMethod (..)
) where
import Numeric.Sibe
import Numeric.Sibe.Utils
import Debug.Trace
import Data.Char
import Data.Maybe
import Data.List
import Numeric.LinearAlgebra as H hiding (find)
import qualified Data.Vector.Storable as V
import Data.Default.Class
import Data.Function (on)
import Control.Monad
import System.Random
import Graphics.Rendering.Chart as Chart
import Graphics.Rendering.Chart.Backend.Cairo
import Control.Lens
data W2VMethod = SkipGram | CBOW
data Word2Vec = Word2Vec { docs :: [String]
, window :: Int
, dimensions :: Int
, method :: W2VMethod
, w2vChartName :: String
, w2vDrawChart :: Bool
}
instance Default Word2Vec where
def = Word2Vec { docs = []
, window = 2
, w2vChartName = "w2v.png"
, w2vDrawChart = False
}
word2vec w2v session = do
seed <- newStdGen
let s = session { training = trainingData
, network = randomNetwork 0 (1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy'))
}
when (debug s) $ do
putStr "vocabulary size: "
print v
putStr "trainingData length: "
print . length $ trainingData
newses <- run (sgd . ignoreBiases) s
let (hidden@(Layer biases nodes _) :- _) = network newses
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
when (w2vDrawChart w2v) $ do
let m = fromRows . map snd $ computedVocVec
twoDimensions = pca m 2
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toRows twoDimensions)
chart = toRenderable layout
where
textP = plot_annotation_values .~ textData
$ def
layout = layout_title .~ "word vectors"
$ layout_plots .~ [toPlot textP]
$ def
renderableToFile def (w2vChartName w2v) chart
return ()
return (computedVocVec, vocvec)
where
ds = map cleanText (docs w2v)
wd = map (words . (++ " ") . (map toLower)) ds
ws = words (concatMap ((++ " ") . map toLower) ds)
vocabulary = ordNub ws
v = length vocabulary
vocvec = zip vocabulary $ map (onehot v) [0..v 1]
trainingData = concatMap (\wds -> concatMap (iter wds) $ zip [0..] wds) wd
where
iter wds (i, w) =
let v = snd . fromJust . find ((==w) . fst) $ vocvec
before = take (window w2v) . drop (i window w2v) $ wds
after = take (window w2v) . drop (i + 1) $ wds
ns
| i == 0 = after
| i == length vocvec 1 = before
| otherwise = before ++ after
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
new = foldl1 (+) vectorized
in
if length wds <= 1
then []
else
case method w2v of
SkipGram -> [(v, average new)]
CBOW -> [(average new, v)]
_ -> error "unsupported word2vec method"
cleanText :: String -> String
cleanText string =
let puncs = filter (`notElem` ['!', '"', '#', '$', '%', '(', ')', '.', '?', '\'']) (trim string)
spacify = foldl (\acc x -> replace x ' ' acc) puncs [',', '/', '-', '\n', '\r']
nonumber = filter (not . isNumber) spacify
lower = map toLower nonumber
in (unwords . words) lower
where
trim = f . f
where
f = reverse . dropWhile isSpace
replace needle replacement =
map (\c -> if c == needle then replacement else c)