{-# LANGUAGE RecordWildCards #-}

module Numeric.Recommender.ALS where

import Control.Parallel.Strategies
import Data.Bifunctor
import Data.Default.Class
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.List (sortBy)
import Data.Maybe
import Data.Tuple
import qualified Data.Vector.Storable
import Numeric.LinearAlgebra
import System.Random

import Prelude hiding ((<>))

-- | Tunable parameters consumed by 'buildModel'.
data ALSParams = ALSParams
  { lambda :: Double    -- ^ Regularization weight applied to the
                        --   diagonal of each least-squares system;
                        --   also the upper bound of the range the
                        --   initial random item factors are drawn from
  , alpha :: Double     -- ^ Multiplier applied to the 0/1 selection
                        --   matrix to obtain the per-entry weights
  , seed :: Int         -- ^ Seed for the generator that draws the
                        --   initial item factors
  , nFactors :: Int     -- ^ Hidden features dimension
  , nIterations :: Int  -- ^ Number of alternating update rounds
  } deriving (Show)

-- | Baseline parameters, spelled out per field so that the meaning of
-- each number is visible at the definition site.
instance Default ALSParams where
  def = ALSParams
    { lambda = 0.1
    , alpha = 10
    , seed = 0
    , nFactors = 10
    , nIterations = 10
    }

-- | A trained model as returned by 'buildModel', together with the
-- conversions between caller supplied user/item values and the dense
-- integer keys used internally.
data ALSModel u i = ALSModel
  { encodeUser :: u -> Maybe Int        -- ^ User to dense representation;
                                        --   'Nothing' for users absent
                                        --   from the training pairs
  , decodeUser :: Int -> u              -- ^ User from dense representation;
                                        --   calls 'error' on unknown keys
  , encodeItem :: i -> Maybe Int        -- ^ Item to dense representation;
                                        --   'Nothing' for items absent
                                        --   from the training pairs
  , decodeItem :: Int -> i              -- ^ Item from dense representation;
                                        --   calls 'error' on unknown keys
  , feature :: !(Matrix Double)         -- ^ Product of the user and item
                                        --   factor matrices: one row per
                                        --   dense user key, one column
                                        --   per dense item key
  -- | Per user recommendations, best match first.  The second
  -- component is True if the user hasn't selected this item.  Use
  -- 'encodeUser' to get a key to this map.
  , recommend :: IntMap.IntMap [(i, Bool)]
  }

{-
-- Useful for debugging
costFunction
  :: Matrix Double -> Matrix Double -> Matrix Double -> Matrix Double
  -> Double -> Vector Double -> Vector Double -> Double
costFunction r u m w l nui nmj = let rum = r - (u <> m)
  in sumElements ((w + 1) * (rum * rum)) +
     (l * (sumElements (nui <# (u^2)) + sumElements ((m^2) #> nmj)))
-}

-- | Build recommendations based on users' unrated item choices.
--
-- Takes conversion functions to/from Int representation for user
-- supplied data types.  Use 'id' if you're already based on them.
--
-- Every distinct user and item is first mapped to a dense key.  The
-- pairs are turned into a 0/1 selection matrix, which is factorized
-- by alternately solving for the user factors and the item factors
-- 'nIterations' times.  The product of the two factor matrices is
-- stored in 'feature' and ranked per user into 'recommend'.
--
-- Calls 'error' when given no user-item pairs.  'nIterations' has to
-- be at least 1, since the result is picked with '!!' at index
-- @nIterations - 1@.
--
-- The implementation follows the one in the recommenderlab library in
-- CRAN.  For further details, see "Large-scale Parallel Collaborative
-- Filtering for the Netflix Prize" by Yunhong Zhou, Dennis Wilkinson,
-- Robert Schreiber and Rong Pan.
buildModel
  :: (Functor f, Foldable f)
  => ALSParams
  -> (u -> Int)
  -> (Int -> u)
  -> (i -> Int)
  -> (Int -> i)
  -> f (u, i)      -- ^ User-item pairs
  -> ALSModel u i
buildModel ALSParams{..} fromUser toUser fromItem toItem xs
  -- Without this guard the failure would surface as an anonymous
  -- "maximum: empty list" from the dimension computation below.
  | null xs = error "Numeric.Recommender.ALS.buildModel: no user-item pairs given"
  | otherwise = ALSModel encUser decUser encItem decItem feat recommendations
  where
    rnd = mkStdGen seed
    (encUser, decUser) = bimap (. fromUser) (toUser .) .
      compact $ fmap (fromUser . fst) xs
    (encItem, decItem) = bimap (. fromItem) (toItem .) .
      compact $ fmap (fromItem . snd) xs
    -- The encoders were built from these very pairs, so every lookup
    -- succeeds and 'fromJust' cannot fail here.
    xs' = fmap (bimap (fromJust . encUser) (fromJust . encItem)) xs
    -- Dense user key -> set of dense item keys selected by that user.
    usrIt = foldr
      (\(k,v) -> IntMap.insertWith IntSet.union k (IntSet.singleton v)) mempty xs'
    nU = 1 + maximum (fmap fst xs')
    nM = 1 + maximum (fmap snd xs')
    -- Selected cells as flat row-major offsets into the nU x nM matrix.
    selections = foldr (\(u,c) -> IntSet.insert (c+(nM*u))) mempty xs'
    ratings = (nU><nM) $
      map (\k -> if IntSet.member k selections then 1 else 0) [0..(nU*nM)-1]
    weighted = scalar alpha * ratings
    -- NOTE(review): the ones occupy the first nFactors elements in
    -- fill order rather than a whole row or column; the referenced
    -- paper appears to initialize the first row instead.  Left as is
    -- to keep results unchanged -- confirm the intent.
    mIni = (nFactors><nM) $ replicate nFactors 1 ++
           take (nFactors*(nM-1)) (randomRs (0,lambda) rnd)
    -- Number of selections per user and per item.
    sumsU = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toRows ratings
    sumsM = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toColumns ratings

    -- Factor vector of user i, given the item factors m and m <> tr m.
    solveUser :: Matrix Double -> Matrix Double -> Int -> Vector Double
    solveUser m mtm i = let
      -- Only the items this user has selected take part
      chosen = filter (\j -> atIndex ratings (i,j) > 0.1) [0..nM-1]
      m' = m ¿ chosen
      -- Read the cells directly instead of slicing the whole matrix
      -- into rows for every user.
      w' = vector [atIndex weighted (i,j) | j <- chosen]
      r' = vector [atIndex ratings (i,j) | j <- chosen]
      m'' = tr $ tr m' * asColumn w'
      x1 = mtm + (m'' <> tr m' +
                  (scalar lambda * scalar (atIndex sumsU i) * ident nFactors))
      x2 = asColumn $ (m'' + m') #> r'
      -- Fall back to the SVD based solver when the direct one gives up
      in flatten . fromMaybe (linearSolveSVD x1 x2) $ linearSolve x1 x2

    -- Factor vector of item j, given the user factors u and tr u <> u.
    solveItem :: Matrix Double -> Matrix Double -> Int -> Vector Double
    solveItem u tuu j = let
      -- Only the users that have selected this item take part
      chosen = filter (\i -> atIndex ratings (i,j) > 0.1) [0..nU-1]
      u' = u ? chosen
      -- Read the cells directly instead of transposing and slicing
      -- the whole matrix into columns for every item.
      w' = vector [atIndex weighted (i,j) | i <- chosen]
      r' = vector [atIndex ratings (i,j) | i <- chosen]
      u'' = tr $ asColumn w' * u'
      x1 = tuu + u'' <> u' +
           (scalar lambda * scalar (atIndex sumsM j) * ident nFactors)
      x2 = asColumn $ (u'' + tr u') #> r'
      in flatten . fromMaybe (linearSolveSVD x1 x2) $ linearSolve x1 x2

    -- One alternating round: solve the user factors against the given
    -- item factors, then the item factors against the new user factors.
    alsRound :: Matrix Double -> (Matrix Double, Matrix Double)
    alsRound m = let
      mtm = m <> tr m
      u = fromRows $ parMap rdeepseq (solveUser m mtm) [0..nU-1]
      tuu = tr u <> u
      m2 = fromColumns $ parMap rdeepseq (solveItem u tuu) [0..nM-1]
      in (u, m2)

    (userFeature, itemFeature) = iterate
      ((\x -> x `seq` alsRound x) . snd) (alsRound mIni) !! (nIterations - 1)
    feat = userFeature <> itemFeature

    -- One entry per user.  The selection map already has exactly the
    -- right keys and carries the set needed for the "not yet
    -- selected" flag, so no lookup can fail.
    recommendations = IntMap.mapWithKey rankItems usrIt
    -- All items for user u, highest predicted score first.
    rankItems u inUsr =
      map ((\x -> (decItem x, not $ IntSet.member x inUsr)) . fst) $
      sortBy (\(_,a) (_,b) -> compare b a) $
      zip [0..] $ toList $ flatten $ feat ? [u]

    -- |Build to/from functions from a sparse set to a dense 0..n-1
    -- range.
    --
    -- The reverse function is total for convenience since the inputs
    -- for it are better controlled.
    compact
      :: Foldable f
      => f Int
      -> (Int -> Maybe Int, Int -> Int)
    compact ys = let
      -- A value seen again keeps the key it was first given
      mp = foldr (\x a -> IntMap.insertWith (flip const) x (IntMap.size a) a) mempty ys
      pm = IntMap.fromList . map swap $ IntMap.toList mp
      in ( flip IntMap.lookup mp
         , fromMaybe (error "missing value") . flip IntMap.lookup pm
         )