{-# LANGUAGE ScopedTypeVariables #-}

module Numeric.Sibe.Word2Vec
  ( word2vec
  , Word2Vec (..)
  , W2VMethod (..)
  ) where
    import Numeric.Sibe
    import Numeric.Sibe.Utils
    import Debug.Trace
    import Data.Char
    import Data.Maybe
    import Data.List
    import Numeric.LinearAlgebra as H hiding (find)
    import qualified Data.Vector.Storable as V
    import Data.Default.Class
    import Data.Function (on)
    import Control.Monad
    import System.Random

    import Graphics.Rendering.Chart as Chart
    import Graphics.Rendering.Chart.Backend.Cairo
    import Control.Lens

    -- | How 'word2vec' arranges each training pair built from a word and the
    -- words surrounding it.
    data W2VMethod
      = SkipGram -- ^ pair is (vector of the word, averaged sum of its context vectors)
      | CBOW     -- ^ pair is (averaged sum of the context vectors, vector of the word)
    -- | Settings consumed by 'word2vec'.
    data Word2Vec = Word2Vec
      { docs         :: [String]   -- ^ raw documents; each one is passed through 'cleanText' before use
      , window       :: Int        -- ^ how many words are taken on each side of a word as its context
      , dimensions   :: Int        -- ^ size given in the hidden-layer spec handed to @randomNetwork@
      , method       :: W2VMethod  -- ^ selects the order of each (input, output) training pair
      , w2vChartName :: String     -- ^ file path the chart is rendered to
      , w2vDrawChart :: Bool       -- ^ render the chart of the computed vectors when 'True'
      }
    -- | Default settings: no documents, a window of 2, skip-gram pairs and no
    -- chart.
    --
    -- Every field is initialised so that reading any of them from 'def' (or
    -- from a partial override such as @def { docs = ds }@) cannot fail with a
    -- "Missing field in record construction" error.
    instance Default Word2Vec where
      def = Word2Vec { docs = []
                     , window = 2
                     -- NOTE(review): 30 is an arbitrary modest size picked only
                     -- so the field is defined; callers are expected to
                     -- override it for their corpus.
                     , dimensions = 30
                     , method = SkipGram
                     , w2vChartName = "w2v.png"
                     , w2vDrawChart = False
                     }

    -- | Compute word vectors for the documents in @w2v@, using @session@ for
    -- everything except its training data and network, which are replaced.
    --
    -- Steps:
    --
    -- 1. every document is cleaned with 'cleanText' and split into words;
    -- 2. each distinct word is paired with a one-hot vector;
    -- 3. one (input, output) pair is generated per word from that word and the
    --    words at most @window w2v@ positions before or after it in the same
    --    document; 'method' decides which side is the input;
    -- 4. the session is run with a network built by @randomNetwork@ from a
    --    single @(dimensions w2v, (id, one))@ layer spec followed by a
    --    @(vocabulary size, (softmax, crossEntropy'))@ spec;
    -- 5. every one-hot vector is run through the first layer of the resulting
    --    network.
    --
    -- Returns @(computed, onehots)@: the vocabulary paired with the outputs of
    -- step 5, and the vocabulary paired with its one-hot vectors.  When
    -- @w2vDrawChart w2v@ is 'True' the computed vectors are passed through
    -- @pca _ 2@ and rendered as labelled points to the file @w2vChartName w2v@.
    --
    -- Words that have no context at all (a one-word document, or a window
    -- that is zero or negative) contribute no training pair.
    word2vec w2v session = do
      -- NOTE(review): the generator returned here is never used; the network
      -- below is built with the literal 0 as the first argument of
      -- randomNetwork. The call is kept only because it advances the global
      -- generator, which callers may observe.
      _ <- newStdGen

      let s = session { training = trainingData
                      , network = randomNetwork 0 (-1, 1) vocSize [(dimensions w2v, (id, one))] (vocSize, (softmax, crossEntropy'))
                      }

      when (debug s) $ do
        putStr "vocabulary size: "
        print vocSize

        putStr "trainingData length: "
        print . length $ trainingData

      -- biases are not used in skipgram/cbow
      newses <- run (sgd . ignoreBiases) s

      -- export the hidden layer; the network built above always has a layer
      -- in front of its last one, so this pattern is expected to match
      let (hidden :- _) = network newses
      -- run words through the hidden layer alone to get the word vector
      let computedVocVec = map (\(w, oneHot) -> (w, runLayer' oneHot hidden)) vocvec

      when (w2vDrawChart w2v) $ do
        let m = fromRows . map snd $ computedVocVec
            twoDimensions = pca m 2
            -- first and last component of each reduced row become the
            -- coordinates of the word's label
            textData = zipWith (\label point -> (V.head point, V.last point, label)) (map fst computedVocVec) (toRows twoDimensions)

            chart = toRenderable layout
              where
                textP = plot_annotation_values .~ textData
                      $ def
                layout = layout_title .~ "word vectors"
                       $ layout_plots .~ [toPlot textP]
                       $ def

        renderableToFile def (w2vChartName w2v) chart
        return ()

      return (computedVocVec, vocvec)
      where
        -- clean documents
        ds = map cleanText (docs w2v)

        -- words of each document
        wd = map (words . map toLower) ds

        -- the vocabulary is every distinct word of every document, taken from
        -- the same token lists the training pairs are built from so that each
        -- lookup in `vectorOf` is guaranteed to succeed
        vocabulary = ordNub (concat wd)
        vocSize = length vocabulary

        -- generate one-hot vectors for each word of vocabulary
        vocvec = zip vocabulary $ map (onehot vocSize) [0..vocSize - 1]

        -- one-hot vector of a vocabulary word
        vectorOf w =
          fromMaybe (error ("word2vec: impossible, word missing from vocabulary: " ++ w))
                    (lookup w vocvec)

        -- training data: generate input and output pairs for each word and the words in its window
        trainingData = concatMap pairsForDoc wd
          where
            win = window w2v

            pairsForDoc wds = concatMap (pairFor wds) (zip [0..] wds)

            pairFor wds (i, w)
              -- nothing to learn from a word without neighbours
              | null context = []
              | otherwise =
                  case method w2v of
                    SkipGram -> [(target, average summed)]
                    CBOW     -> [(average summed, target)]
              where
                target = vectorOf w
                -- `take i` first, so the window can never reach the word
                -- itself (or anything after it) when fewer than `win` words
                -- precede it; a negative drop count simply drops nothing
                before = drop (i - win) (take i wds)
                after = take win (drop (i + 1) wds)
                context = before ++ after
                -- only evaluated when `context` is non-empty
                summed = foldl1' (+) (map vectorOf context)

    -- | Normalise raw text into lower-case words separated by single spaces.
    --
    -- Applied in order: surrounding whitespace is stripped, the characters
    -- @! " # $ % ( ) . ? '@ are deleted, @,@ @/@ @-@ and line breaks are turned
    -- into spaces, numeric characters are removed, the text is lower-cased and
    -- every run of whitespace collapses to one space.
    cleanText :: String -> String
    cleanText =
      squeeze . map toLower . filter (not . isNumber) . map separate . filter keep . strip
      where
        -- characters that are deleted outright
        keep c = c `notElem` "!\"#$%().?'"

        -- characters that act as word separators
        separate c
          | c `elem` ",/-\n\r" = ' '
          | otherwise = c

        -- drop leading whitespace, then trailing whitespace
        strip = reverse . dropWhile isSpace . reverse . dropWhile isSpace

        -- remove unnecessary spaces
        squeeze = unwords . words