{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Cache.LRU
 ( LRUCache
 , Strategy(..)
 ) where

import           Control.Monad.Ref
import           Data.Cache.Trace
import           Data.Cache.Type
import           Data.Hashable
import qualified Data.HashPSQ as PSQ
import qualified Data.DList as DList

-- | A mutable cache handle: a reference (of type @'Ref' m@) to the pure
-- cache state 'LRUCache''.
--
-- Type parameters: @m@ is the monad the reference belongs to, @s@ is the
-- 'Strategy' tag, @p@ is the priority type, @k@ is the key type, and @t@ and
-- @v@ are the two components stored with every entry.
newtype LRUCache m (s :: Strategy) p k t v
  = LRUCache (Ref m (LRUCache' s p k t v))

-- | The eviction policy of a cache. It is used as a type-level tag (via
-- @DataKinds@) on 'LRUCache' and 'LRUCache''.
data Strategy
 = FIFO -- ^ Priority is set on entry and not changed.
 -- LIFO was considered and left out: a generally bad policy where the most
 -- recent thing is evicted. (Plain comment on purpose: a Haddock @-- |@ here
 -- would attach to the next constructor.)
 | LRU  -- ^ Every lookup generates a new priority.
 | LFU  -- ^ Priority is number of times looked up (tie breaker of order inserted?).


-- | The pure state of a cache, stored behind the reference in 'LRUCache'.
--
-- Entries live in 'PSQ.HashPSQ' queues keyed by @k@, with priority @p@ and
-- payload @(t, v)@. The strategy @s@ does not appear in any field; it is only
-- a type-level tag.
data LRUCache' (s::Strategy) p k t v
 = LRUCache1'
   { lcCapacity :: !p
   -- ^ The maximum number of elements in the queue.
   , lcSize     :: !p
   -- ^ The current number of elements in the queue.
   , lcGen      :: !p
   -- ^ The next priority.
   , lcPSQueue  :: !(PSQ.HashPSQ k p (t, v))
   -- ^ The actual cache.
   }
 -- | We need a variant with 2 queues to handle rollover of the priority.
 | LRUCache2'
   { lcCapacity :: !p
   -- ^ The maximum number of elements between both queues.
   , lcSize     :: !p
   -- ^ The current number of elements across both queues (not just
   -- 'lcPSQueue'); it is compared directly against 'lcCapacity'.
   , lcGen      :: !p
   -- ^ The next priority.
   , lcPSQueue  :: !(PSQ.HashPSQ k p (t, v))
   -- ^ The cache we're currently inserting to.
   , lcPSQueueOverflow  :: !(PSQ.HashPSQ k p (t, v))
   -- ^ The cache we're drawing down. This field exists only in this
   -- constructor, so the 'lcPSQueueOverflow' selector is partial.
   }
 deriving (Eq, Show)

-- | Evict minimum-priority entries from a queue until a size counter drops to
-- a target, or the queue runs out.
--
-- @shrinkQueue c s q@ treats @s@ as the current size and @c@ as the target
-- size. Each step pops the minimum-priority entry of @q@ (via 'PSQ.minView')
-- and decrements the size by one. It stops once the size is @<= c@ or @q@ is
-- empty.
--
-- Returns @(removed, events, q')@:
--
-- * @removed@ is the number of entries popped;
-- * @events@ holds one 'CacheEvict' per popped entry, in the order they were
--   popped;
-- * @q'@ is the remaining queue.
shrinkQueue :: forall k p t v . (Hashable k, Ord k, Enum p, Num p, Ord p)
            => p -> p -> PSQ.HashPSQ k p (t, v) -> (p, [CacheEvent k t v], PSQ.HashPSQ k p (t, v))
shrinkQueue c s' q' =
  go s' 0 [] q'
  where
    -- Since the number removed is only the distance between c and s', both
    -- of which are values of type p, the counter shouldn't roll over.
    -- (This analysis fails if negative ranges are taken into account.)
    go :: p -> p -> [CacheEvent k t v] -> PSQ.HashPSQ k p (t, v) -> (p, [CacheEvent k t v], PSQ.HashPSQ k p (t, v))
    go s cnt acc q
      | s <= c = done cnt acc q
      | otherwise =
          case PSQ.minView q of
            -- Can't remove anything more, even though we're still over target.
            Nothing -> done cnt acc q
            -- Continue with the queue that no longer holds the popped entry.
            Just (k, _, (t, v), rest) ->
              let cnt' = cnt + 1
              -- Force the counter so it doesn't build a chain of thunks.
              in cnt' `seq` go (pred s) cnt' (CacheEvict k t v : acc) rest

    -- Events are accumulated newest-first; restore eviction order.
    done :: p -> [CacheEvent k t v] -> PSQ.HashPSQ k p (t, v) -> (p, [CacheEvent k t v], PSQ.HashPSQ k p (t, v))
    done cnt acc q = (cnt, reverse acc, q)

-- | Evict entries until the cache's 'lcSize' is no larger than its
-- 'lcCapacity', tracing the eviction events produced by 'shrinkQueue'.
--
-- A cache already within capacity is returned unchanged.
--
-- * For 'LRUCache1'', entries are evicted from its single queue.
-- * For 'LRUCache2'', entries are evicted from the overflow queue first, and
--   only then from the main queue.
--
-- In the two-queue case, the result is an 'LRUCache1'' whenever the overflow
-- queue ends up empty.
trim :: forall s k p t v (trc::Bool) m
      . (Hashable k, Ord k, Enum p, Num p, Ord p, MonadTrace trc, Applicative m, Monad (Tracable trc (CacheTrace k t v) m))
     => LRUCache' s p k t v -> Tracable trc (CacheTrace k t v) m (LRUCache' s p k t v)
-- Skip the cases where we're the right size already. Both constructors carry
-- 'lcCapacity' and 'lcSize', so one guard covers them.
trim c | lcCapacity c >= lcSize c = pure c
-- Now we must have something to remove.
trim (c@(LRUCache1' {lcCapacity=cap, lcSize=sz, lcPSQueue=q})) = do
  let (removed, trc, nq) = shrinkQueue cap sz q
  trace $ DList.fromList trc
  pure (c {lcSize=sz-removed, lcPSQueue=nq})
-- We remove as much as we want, up to the limit of availability, from the
-- overflow queue. If we still need to remove things, we start removing from
-- the main queue, possibly switching to the 1-constructor variant.
trim (c@(LRUCache2' {lcCapacity=cap, lcSize=sz, lcPSQueue=q, lcPSQueueOverflow=qo})) = do
  -- lcSize counts both queues, so the overflow queue is shrunk against the total.
  let (removedo, trco, nqo) = shrinkQueue cap sz qo
  trace $ DList.fromList trco
  let newsz = sz-removedo
  if cap >= newsz
    then pure $ reconstructCache newsz q nqo
    else do
      -- Still over capacity after draining the overflow queue: evict from
      -- the main queue as well.
      let (removed, trc, nq) = shrinkQueue cap newsz q
      trace $ DList.fromList trc
      pure $ reconstructCache (newsz-removed) nq nqo
  where
    -- One might argue that because we only decrement from the current queue
    -- after the overflow queue is exhausted, the overflow queue must then be
    -- empty. I choose not to assume that, since I haven't worked through all
    -- the various priority strategies. Instead, pick the constructor based on
    -- whether the overflow queue really is empty.
    reconstructCache :: p -> PSQ.HashPSQ k p (t, v) -> PSQ.HashPSQ k p (t, v) -> LRUCache' s p k t v
    reconstructCache resSz resQ resQO | PSQ.null resQO =
          LRUCache1' { lcCapacity=cap
                     , lcSize = resSz
                     , lcPSQueue = resQ
                     , lcGen = lcGen c
                     }
    reconstructCache resSz resQ resQO =
          LRUCache2' { lcCapacity=cap
                     , lcSize = resSz
                     , lcPSQueue = resQ
                     , lcPSQueueOverflow = resQO
                     , lcGen = lcGen c
                     }