{-# LANGUAGE RecordWildCards #-}
module Numeric.Recommender.ALS where
import Control.Parallel.Strategies
import Data.Bifunctor
import Data.Default.Class
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.List (sortBy)
import Data.Maybe
import Data.Tuple
import qualified Data.Vector.Storable
import Numeric.LinearAlgebra
import System.Random
import Prelude hiding ((<>))
-- | Tunable parameters read by 'buildModel'.
data ALSParams = ALSParams
  { lambda :: Double
    -- ^ Regularisation weight. 'buildModel' multiplies it by the number of
    -- selections of the user (or item) being solved for and adds the result
    -- to the diagonal of that solve's system matrix. It is also used as the
    -- upper bound of the range from which the random part of the initial
    -- item matrix is drawn.
  , alpha :: Double
    -- ^ Factor applied to the 0/1 selection matrix to obtain the weights
    -- used in each solve.
  , seed :: Int
    -- ^ Seed passed to 'mkStdGen' for the random initialisation.
  , nFactors :: Int
    -- ^ Number of latent factors, i.e. the inner dimension of the user and
    -- item matrices.
  , nIterations :: Int
    -- ^ Number of alternating update rounds. 'buildModel' indexes an
    -- iteration list at @nIterations - 1@, so this is expected to be at
    -- least 1.
  } deriving (Show)
-- | Default parameters, spelled out field by field.
instance Default ALSParams where
  def = ALSParams
    { lambda = 0.1
    , alpha = 10
    , seed = 0
    , nFactors = 10
    , nIterations = 10
    }
-- | Result of 'buildModel': index translation functions together with the
-- computed score matrix and per-user rankings.
data ALSModel u i = ALSModel
  { encodeUser :: u -> Maybe Int
    -- ^ Dense user index for a user. As populated by 'buildModel' this is
    -- 'Nothing' for users that did not occur in the input.
  , decodeUser :: Int -> u
    -- ^ User for a dense user index. As populated by 'buildModel' this
    -- calls 'error' (\"missing value\") for an index that was never assigned.
  , encodeItem :: i -> Maybe Int
    -- ^ Dense item index for an item; 'Nothing' for items not in the input.
  , decodeItem :: Int -> i
    -- ^ Item for a dense item index; calls 'error' for an unassigned index.
  , feature :: !(Matrix Double)
    -- ^ Strict score matrix. 'buildModel' stores the product of the user and
    -- item factor matrices here: one row per dense user index, one column
    -- per dense item index.
  , recommend :: IntMap.IntMap [(i, Bool)]
    -- ^ Keyed by dense user index. Each list holds every item ordered by
    -- descending score for that user; the 'Bool' is 'True' when the
    -- (user, item) pair was /not/ part of the input.
  }
-- | Build an 'ALSModel' from observed (user, item) selections using
-- alternating least squares.
--
-- Users and items are first mapped to integers by the supplied functions and
-- then renumbered densely. A 0/1 selection matrix is built from the input,
-- the user and item factor matrices are re-solved alternately, and the final
-- score matrix is their product. For every user that occurs in the input the
-- items are ranked by descending score.
--
-- Calls 'error' when the input contains no pairs (a model needs at least one
-- user and one item). An 'nIterations' below 1 is treated as 1.
buildModel
  :: (Functor f, Foldable f)
  => ALSParams
  -> (u -> Int)  -- ^ Integer identity of a user
  -> (Int -> u)  -- ^ User for an integer identity (used when decoding)
  -> (i -> Int)  -- ^ Integer identity of an item
  -> (Int -> i)  -- ^ Item for an integer identity (used when decoding)
  -> f (u, i)    -- ^ Observed (user, item) selections
  -> ALSModel u i
buildModel ALSParams{..} fromUser toUser fromItem toItem xs = let
    rnd = mkStdGen seed

    -- Dense renumbering of the user and item identities.
    (encUser, decUser) = bimap (. fromUser) (toUser .) .
      compact $ fmap (fromUser . fst) xs
    (encItem, decItem) = bimap (. fromItem) (toItem .) .
      compact $ fmap (fromItem . snd) xs

    -- Every identity in xs was fed to 'compact' above, so the lookups
    -- behind 'fromJust' always succeed.
    xs' = fmap (bimap (fromJust . encUser) (fromJust . encItem)) xs

    -- Dense user index -> set of dense item indices selected by that user.
    usrIt = foldr
      (\(k, v) -> IntMap.insertWith IntSet.union k (IntSet.singleton v))
      mempty xs'

    -- Number of distinct users / items (largest dense index plus one).
    dimOf proj
      | null xs = error "buildModel: no (user, item) pairs given"
      | otherwise = 1 + maximum (fmap proj xs')
    nU = dimOf fst
    nM = dimOf snd

    -- Row-major positions of the selected cells in the nU x nM matrix.
    selections = foldr (\(u, c) -> IntSet.insert (c + (nM * u))) mempty xs'
    ratings = (nU><nM) $
      map (\k -> if IntSet.member k selections then 1 else 0) [0..(nU*nM)-1]
    weighted = scalar alpha * ratings

    -- Initial item matrix: the first nFactors cells in row-major order are 1,
    -- the remainder is drawn from the range (0, lambda).
    -- NOTE(review): a full first row of ones would need `replicate nM 1`;
    -- layout kept as in the original to preserve results - confirm intent.
    mIni = (nFactors><nM) $ replicate nFactors 1 ++
      (take (nFactors*(nM-1)) $ randomRs (0,lambda) rnd)

    -- Number of selections per user (row sums) and per item (column sums).
    sumsU = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toRows ratings
    sumsM = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toColumns ratings

    -- One alternating round: solve all user rows against the item matrix m,
    -- then all item columns against the fresh user matrix.
    f m = let
        mtm = m <> tr m
        u = fromRows $ parMap rdeepseq userRow [0..nU-1]
        userRow i = let
            -- Items selected by user i; computed once and shared below.
            sel = filter (\j -> atIndex ratings (i,j) > 0.1) [0..nM-1]
            m' = m ¿ sel
            w' = vector $ map (\j -> atIndex weighted (i,j)) sel
            r' = vector $ map (\j -> atIndex ratings (i,j)) sel
            m'' = tr $ (tr m') * asColumn w'
            x1 = mtm + (m'' <> tr m' +
              (scalar lambda * scalar (atIndex sumsU i) * ident nFactors))
            x2 = asColumn $ (m'' + m') #> r'
          in solve x1 x2
        tuu = tr u <> u
        m2 = fromColumns $ parMap rdeepseq itemCol [0..nM-1]
        itemCol j = let
            -- Users that selected item j; computed once and shared below.
            sel = filter (\i -> atIndex ratings (i,j) > 0.1) [0..nU-1]
            u' = u ? sel
            w' = vector $ map (\i -> atIndex weighted (i,j)) sel
            r' = vector $ map (\i -> atIndex ratings (i,j)) sel
            u'' = tr $ asColumn w' * u'
            x1 = tuu + u'' <> u' +
              (scalar lambda * scalar (atIndex sumsM j) * ident nFactors)
            x2 = asColumn $ (u'' + tr u') #> r'
          in solve x1 x2
      in (u, m2)

    -- At least one round is needed to obtain a user matrix at all.
    rounds = max 1 nIterations
    (userFeature, itemFeature) = iterate
      ((\x -> x `seq` f x) . snd) (f mIni) !! (rounds - 1)
    feat = userFeature <> itemFeature

    -- All items ordered by descending score for one user; the flag is True
    -- for items the user has not selected in the input.
    ranked usr chosen =
      map (\(x, _) -> (decItem x, not $ IntSet.member x chosen)) $
      sortBy (\(_,a) (_,b) -> compare b a) $
      zip [0..] $ toList $ flatten $ feat ? [usr]
  in ALSModel encUser decUser encItem decItem feat $
       IntMap.mapWithKey ranked usrIt
  where
    -- Solve @a x = b@, falling back to the SVD-based solver when the direct
    -- solver reports failure.
    solve :: Matrix Double -> Matrix Double -> Vector Double
    solve a b = flatten $ fromMaybe (linearSolveSVD a b) (linearSolve a b)

    -- Renumber arbitrary Ints to 0..n-1. Returns the forward lookup and the
    -- reverse lookup; the reverse lookup calls 'error' on an unknown index.
    compact
      :: Foldable f
      => f Int
      -> (Int -> Maybe Int, Int -> Int)
    compact ys = let
        -- Keep the index already assigned when a value repeats.
        mp = foldr (\x a -> IntMap.insertWith (flip const) x (IntMap.size a) a) mempty ys
        pm = IntMap.fromList . map swap $ IntMap.toList mp
      in ( flip IntMap.lookup mp
         , fromMaybe (error $ "missing value") . flip IntMap.lookup pm
         )