-- | Collapsed Gibbs sampling for Latent Dirichlet Allocation.
-- Mutable count tables live in the 'ST' monad while sampling; an
-- immutable 'Finalized' snapshot is produced for querying.
module NLP.SwiftLDA
  ( -- * Sampling passes
    pass
  , passOne
    -- * Types
  , LDA
  , Doc
  , D
  , W
  , Z
  , Table2D
  , Table1D
  , Finalized (..)
    -- * Construction and freezing
  , initial
  , finalize
    -- * Weights and counts
  , docTopicWeights_
  , priorDocTopicWeights_
  , docTopicWeights
  , wordTopicWeights
  , docCounts
  )
where
import Prelude hiding (read, exponent)
import Data.Array.ST
import Data.STRef
import Control.Applicative
import System.Random.MWC
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Word
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import qualified Data.IntMap as IntMap
import qualified Data.List as List
import qualified Data.Foldable as Fold
import GHC.Generics (Generic)
import Debug.Trace
import NLP.SwiftLDA.UnboxedMaybeVector ()
-- | Mutable dense 2D count table (unboxed, ST), indexed @(row, topic)@.
type Array2D s = STUArray s (Int,Int) Double

-- | Mutable dense 1D count table (unboxed, ST), indexed by topic.
type Array1D s = STUArray s Int Double

-- | Sparse nested table: outer key -> inner key -> count.
type Table2D = IntMap.IntMap Table1D

-- | Sparse table: key -> count.
type Table1D = IntMap.IntMap Double

-- | Document id.
type D = Int

-- | Topic id.
type Z = Int

-- | Word id.
type W = Int

-- | A document: its id paired with tokens @(word, maybe assigned topic)@.
type Doc = (D, U.Vector (W, Maybe Z))
-- | Mutable LDA sampler state living in the 'ST' monad.  The count
-- tables are dense unboxed arrays held behind 'STRef's so they can be
-- swapped for larger arrays when 'grow' resizes them.
data LDA s =
  LDA
  { _docTopics  :: !(STRef s (Array2D s))
    -- ^ document\/topic counts, indexed @(doc, topic)@
  , _wordTopics :: !(STRef s (Array2D s))
    -- ^ word\/topic counts, indexed @(word, topic)@
  , _topics     :: !(Array1D s)
    -- ^ total count per topic
  , _alphasum   :: !Double
    -- ^ sum of the document-topic prior over all topics
  , _beta       :: !Double
    -- ^ word-topic prior
  , _topicNum   :: !Int
    -- ^ number of topics K
  , _wSize      :: !(STRef s Int)
    -- ^ highest word id seen so far plus one (vocabulary size)
  , weights     :: !(Array1D s)
    -- ^ scratch buffer of per-topic sampling weights
  , weightSum   :: !(STRef s Double)
    -- ^ total mass currently stored in 'weights'
  , gen         :: !(Gen (PrimState (ST s)))
    -- ^ random number generator state
  , _exponent   :: !(Maybe Double)
    -- ^ learning-rate exponent; 'Nothing' means plain unit updates
  }
-- | Immutable snapshot of a sampler's state, stored as sparse maps.
data Finalized =
  Finalized
  { docTopics  :: !Table2D        -- ^ document -> topic -> count
  , wordTopics :: !Table2D        -- ^ word -> topic -> count
  , topics     :: !Table1D        -- ^ topic -> total count
  , topicDocs  :: !Table2D        -- ^ 'docTopics' inverted: topic -> doc -> count
  , topicWords :: !Table2D        -- ^ 'wordTopics' inverted: topic -> word -> count
  , alphasum   :: !Double         -- ^ sum of the document-topic prior
  , beta       :: !Double         -- ^ word-topic prior
  , topicNum   :: !Int            -- ^ number of topics K
  , wSize      :: !Int            -- ^ vocabulary size at freeze time
  , exponent   :: !(Maybe Double) -- ^ learning-rate exponent, if any
  } deriving (Generic)
-- | Total token count of each document, summed over its topics.
docCounts :: Finalized -> Table1D
docCounts = IntMap.map (Fold.foldl' (+) 0) . docTopics
-- | Freeze the mutable sampler state into an immutable 'Finalized'
-- snapshot, converting the dense count arrays into sparse maps and
-- precomputing the inverted (topic-major) views.
finalize :: LDA s -> ST s Finalized
finalize model = do
  docArr   <- read (_docTopics model)
  wordArr  <- read (_wordTopics model)
  docTab   <- freezeArray2D docArr
  wordTab  <- freezeArray2D wordArr
  topicTab <- freezeArray1D (_topics model)
  vocab    <- read (_wSize model)
  return $! Finalized
    { docTopics  = docTab
    , wordTopics = wordTab
    , topics     = topicTab
    , topicDocs  = invert docTab    -- topic -> document view
    , topicWords = invert wordTab   -- topic -> word view
    , alphasum   = _alphasum model
    , beta       = _beta model
    , topicNum   = _topicNum model
    , wSize      = vocab
    , exponent   = _exponent model
    }
-- | Initial highest document index the document\/topic table can hold;
-- the table is doubled on demand by 'grow'.
iDSIZE :: Int
iDSIZE = 1000

-- | Initial highest word index the word\/topic table can hold; the
-- table is doubled on demand by 'grow'.
iWSIZE :: Int
iWSIZE = 1000
-- | @initial seed k a b e@ creates a fresh sampler with @k@ topics,
-- document-topic prior sum @a@, word-topic prior @b@ and optional
-- learning-rate exponent @e@, seeding the RNG from @seed@.
--
-- Fixed: the topic-dimension upper bounds used the unbound name @k1@;
-- for @k@ topics indexed @0..k-1@ the bound must be @k - 1@.
initial :: U.Vector Word32 -> Int -> Double -> Double -> Maybe Double
        -> ST s (LDA s)
initial s k a b e = do
  dta <- newArray ((0, 0), (iDSIZE, k - 1)) 0
  wta <- newArray ((0, 0), (iWSIZE, k - 1)) 0
  LDA <$>
    new dta            <*>
    new wta            <*>
    newArray (0, k - 1) 0 <*>   -- per-topic totals
    pure a             <*>
    pure b             <*>
    pure k             <*>
    new 0              <*>      -- vocabulary size watermark
    newArray (0, k - 1) 0 <*>   -- scratch weights buffer
    new 0              <*>      -- weight sum
    initialize s       <*>      -- RNG state
    pure e
-- | @rho e t@ is the learning rate at pass @t@ for decay exponent @e@:
-- @1 \/ (1 + t) ** e@, so it decays from 1 toward 0 as @t@ grows
-- (for positive @e@).
--
-- Fixed: the division operator had been lost, leaving the ill-formed
-- expression @1 (1 + fromIntegral t)**(e)@.
rho :: Double -> Int -> Double
rho e t = 1 / (1 + fromIntegral t) ** e
-- | Run one Gibbs sampling pass (numbered @t@) over a batch of
-- documents, returning them with updated topic assignments.
pass :: Int -> LDA s -> V.Vector Doc -> ST s (V.Vector Doc)
pass t model docs = V.mapM (passOne t model) docs
-- | Run one collapsed-Gibbs sweep over a single document at pass @t@:
-- for each token, subtract its previous topic assignment (if any) from
-- the counts, sample a new topic from the current conditional
-- distribution, and add the new assignment back.  Returns the document
-- with every token's topic filled in.
passOne :: Int -> LDA s -> Doc -> ST s Doc
passOne t m doc@(!d, wz) = do
  -- Make sure the count tables can index this document and its words.
  grow m doc
  zs <- U.mapM one wz
  return (d, U.zip (U.map fst wz) (U.map Just zs))
  where
    -- Update magnitude for this pass; 1 when no decay exponent is set.
    r = maybe 1 (flip rho t) . _exponent $ m
    one (w, mz) = do
      case mz of
        Nothing -> return ()
        -- Remove the token's previous assignment before resampling.
        Just z -> update (negate r) m d w z
      !z <- randomZ m d w
      update r m d w z
      return z
-- | Sample a topic for word @w@ in document @d@ from the current
-- conditional distribution (filling the shared 'weights' buffer as a
-- side effect).
randomZ :: LDA s -> Int -> Int -> ST s Int
randomZ m !d !w = do
  wordTopicWeights_ m d w
  !total <- read (weightSum m)
  sample (weights m) total (gen m)
-- | Fill the reusable 'weights' buffer with the unnormalized sampling
-- weight of every topic for word @w@ in document @d@, and accumulate
-- their total mass into 'weightSum'.
wordTopicWeights_ :: LDA s -> Int -> Int -> ST s ()
wordTopicWeights_ m !d !w = do
  let k = _topicNum m
      a = _alphasum m / fromIntegral k   -- symmetric per-topic alpha
      b = _beta m
  v  <- fromIntegral <$> read (_wSize m)
  dt <- read (_docTopics m)
  wt <- read (_wordTopics m)
  let buf = weights m
  write (weightSum m) 0
  (lo, hi) <- getBounds buf
  forM_ [lo .. hi] $ \z -> do
    nzd <- readArray dt (d, z)
    nzw <- readArray wt (w, z)
    nz  <- readArray (_topics m) z
    -- (n_zd + a) (n_zw + b) / (n_z + V b): the collapsed Gibbs weight.
    let !n = (nzd + a) * (nzw + b) / (nz + v * b)
    !total <- read (weightSum m)
    write (weightSum m) (total + n)
    writeArray buf z n
-- | Unnormalized topic distribution for a whole document under a
-- finalized model: the per-topic weights of each token, summed over
-- all tokens.
docTopicWeights :: Finalized -> Doc -> U.Vector Double
docTopicWeights m (d, ws) =
  let zeros    = U.replicate (topicNum m) 0
      perToken = U.concatMap (U.indexed . wordTopicWeights m d) (U.map fst ws)
  in U.accumulate (+) zeros perToken
-- | Current raw document\/topic counts for document @d@, growing the
-- table first so the document index is guaranteed to be in bounds.
priorDocTopicWeights_ :: LDA s -> D -> ST s (U.Vector Double)
priorDocTopicWeights_ m d = do
  grow m (d, U.empty)
  table <- read (_docTopics m)
  -- The topic dimension is 0-based; its upper bound gives K - 1.
  ((_, 0), (_, zMax)) <- getBounds table
  U.generateM (zMax + 1) $ \z -> readArray table (d, z)
-- | Unnormalized topic distribution for a document under the current
-- mutable model: per-topic sampling weights of each token, summed over
-- all tokens.  Clobbers the shared 'weights' scratch buffer.
docTopicWeights_ :: LDA s -> Doc -> ST s (U.Vector Double)
docTopicWeights_ m doc@(d, ws) = do
  grow m doc
  -- Scratch buffer is 0-based, so u is K - 1.
  (0,u) <- getBounds (weights m)
  let r = U.replicate (_topicNum m) 0
  -- Weight vector for one word: fill the shared buffer, then copy it out.
  let one w = do
        wordTopicWeights_ m d w
        U.generateM (u+1) (readArray (weights m))
  U.foldM' (\z w -> do y <- one w
                       return $! U.zipWith (+) z y) r
    . U.map fst
    $ ws
-- | Unnormalized per-topic sampling weights for word @w@ in document
-- @d@ under a finalized model; the result has one entry per topic.
--
-- Fixed: the topic range used the unbound name @k1@; topics are
-- indexed @0 .. k - 1@.
wordTopicWeights :: Finalized -> D -> W -> U.Vector Double
wordTopicWeights m d w =
  let k  = topicNum m
      a  = alphasum m / fromIntegral k   -- symmetric per-topic alpha
      b  = beta m
      dt = IntMap.findWithDefault IntMap.empty d . docTopics $ m
      wt = IntMap.findWithDefault IntMap.empty w . wordTopics $ m
      v  = fromIntegral . wSize $ m
      -- (n_zd + a) (n_zw + b) / (n_z + V b) for each topic z.
      weights = [ (count z dt + a)
                  * (count z wt + b)
                  * (1 / (count z (topics m) + v * b))
                | z <- [0 .. k - 1] ]
  in U.fromList weights
-- | Add @c@ (which may be negative) to the count cells for topic @z@
-- in document @d@ and word @w@, and raise the vocabulary-size
-- watermark to cover @w@.
update :: Double -> LDA s -> Int -> Int -> Int -> ST s ()
update c m d w z = do
  docTab  <- read (_docTopics m)
  wordTab <- read (_wordTopics m)
  -- Track the highest word id seen so far.
  oldSize <- read (_wSize m)
  write (_wSize m) (max (w + 1) oldSize)
  totZ <- readArray (_topics m) z
  writeArray (_topics m) z (totZ + c)
  docZ <- readArray docTab (d, z)
  writeArray docTab (d, z) (docZ + c)
  wordZ <- readArray wordTab (w, z)
  writeArray wordTab (w, z) (wordZ + c)
-- | Ensure the document\/topic and word\/topic tables are large enough
-- to index document @d@ and every word in @wz@.
--
-- Fixed: the original resized at most once, but 'resize' only doubles
-- the first dimension — an index more than twice the current bound
-- would still be out of range afterwards.  Doubling now repeats until
-- the index fits.
grow :: LDA s -> Doc -> ST s ()
grow m (d, wz) = do
  let w = if U.null wz then 0 else U.maximum (U.map fst wz)
      -- Keep doubling the table held in @ref@ until index @i@ fits.
      ensure ref i = do
        arr <- read ref
        (_, (hi, _)) <- getBounds arr
        when (i > hi) $ do
          write ref =<< resize arr
          ensure ref i
  ensure (_docTopics m) d
  ensure (_wordTopics m) w
-- | Return a copy of the array with the upper bound of its first
-- dimension doubled; new cells are zero-initialized.
--
-- Fixed: removed a leftover 'Debug.Trace.trace' that printed the old
-- bounds to stderr on every resize.
resize :: Array2D s -> ST s (Array2D s)
resize a = do
  bs@((l1, l2), (u1_old, u2)) <- getBounds a
  -- Only the first (row) dimension grows; the topic dimension is fixed.
  let u1  = u1_old * 2
      bs' = ((l1, l2), (u1, u2))
  b <- newArray bs' 0
  -- Copy every existing cell into the enlarged array.
  forM_ (range bs) $ \i -> readArray a i >>= writeArray b i
  return b
-- | Draw an index proportionally to the weights stored in @ws@, whose
-- total mass is @total@.
sample :: Array1D s -> Double -> Gen s -> ST s Int
sample ws total g = do
  !x <- uniformR (0, total) g
  findEvent x ws
-- | Inverse-CDF lookup: walk the weight array subtracting each weight
-- from @r@ until the remainder drops to (or below) zero, and return the
-- index of the weight that absorbed it.
--
-- Fixed: lost @-@ characters had turned @i - 1@ into the unbound name
-- @i1@ and @n - v@ into @nv@.
findEvent :: Double -> Array1D s -> ST s Int
findEvent !r ws = do
  (l, u) <- getBounds ws
  let go !i !_n | i > u = return (i - 1)   -- ran off the end: last index
      go !i !n
        | n > 0.0 = do
            v <- readArray ws i
            go (i + 1) (n - v)
        | otherwise = return (i - 1)       -- previous weight absorbed r
  go l r
-- | Shorthand for 'readSTRef'.
read :: STRef s a -> ST s a
read ref = readSTRef ref

-- | Shorthand for 'writeSTRef'.
write :: STRef s a -> a -> ST s ()
write ref x = writeSTRef ref x

-- | Shorthand for 'newSTRef'.
new :: a -> ST s (STRef s a)
new x = newSTRef x
-- | Swap the two key levels of a nested table: @invert t ! a ! b ==
-- t ! b ! a@, summing counts that land on the same cell.
invert :: Table2D -> Table2D
invert outer = List.foldl' insertOne IntMap.empty swapped
  where
    insertOne acc (a, b, v) = upd v acc a b
    -- Every (outerKey, innerKey, value) triple with the keys flipped.
    swapped = [ (k', k, v)
              | (k, inner) <- IntMap.toList outer
              , (k', v)    <- IntMap.toList inner ]
-- | @upd c m k k'@ adds @c@ to cell @(k, k')@ of the nested table,
-- creating the inner map when @k@ is absent.
--
-- Fixed: 'IntMap.insertWith'' was deprecated and then removed in
-- containers >= 0.6; the explicit combining function below is the
-- equivalent of @flip (unionWith (+))@ applied to a singleton.
upd :: Double -> Table2D -> Int -> Int -> Table2D
upd c m k k' =
  IntMap.insertWith (\_new old -> IntMap.insertWith (+) k' c old)
                    k
                    (IntMap.singleton k' c)
                    m
-- | Convert a mutable dense 2D array into a sparse nested map,
-- dropping cells that are not strictly positive.
freezeArray2D :: Array2D s -> ST s Table2D
freezeArray2D a = do
  cells <- getAssocs a
  return $! List.foldl' keep IntMap.empty cells
  where
    keep acc ((i, j), v)
      | v > 0     = upd v acc i j
      | otherwise = acc
-- | Convert a mutable dense 1D array into a sparse map, dropping cells
-- that are not strictly positive.
freezeArray1D :: Array1D s -> ST s Table1D
freezeArray1D a = do
  cells <- getAssocs a
  return (IntMap.fromList [ p | p@(_, v) <- cells, v > 0 ])
-- | Look up the count for key @z@, defaulting to 0 when absent.
-- Aborts if the stored count is negative, which would indicate a
-- bookkeeping bug in the update logic.
count :: Int -> IntMap.IntMap Double -> Double
count z t
  | n < 0     = error "NLP.SwiftLDA.count: negative count"
  | otherwise = n
  where n = IntMap.findWithDefault 0 z t