{-# LANGUAGE CPP                   #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE TemplateHaskell       #-}

-- | Quick-and-dirty, thread-unsafe, hash-based memoization.

module Data.Memoization (
    MemoCacheTag(..)

  , resetAllCaches
#ifdef PROFILE_CACHES
  , getAllCacheMetrics
  , printAllCacheMetrics
#endif

  , memoIO
  , memo
  , memo2
  ) where

import Data.Hashable ( Hashable )
import qualified Data.HashTable.IO as HT
import Data.Text ( Text )
import GHC.Generics ( Generic )
import System.IO.Unsafe ( unsafePerformIO )

import Data.HashTable.Extended

import Data.Text.Extended.Pretty

#ifdef PROFILE_CACHES
import Data.IORef ( IORef, newIORef, readIORef, writeIORef, modifyIORef, modifyIORef' )
import Data.List ( sort )
import Data.Memoization.Metrics ( CacheMetrics(CacheMetrics) )

import qualified Data.Text.IO as Text
#endif

-----------------------------------------------------------------

-------------------------------------------------------------
------------------ Caches and cache metrics -----------------
-------------------------------------------------------------

--------------
---- Memo cache
--------------

#ifdef PROFILE_CACHES
-- | Per-tag bookkeeping: a query counter, a miss counter, and every hash
--   table that was registered under the tag. Several memo tables can share
--   one tag, in which case they share these counters too.
data MemoCache = MemoCache
  { queryCount :: !(IORef Int)
  , missCount  :: !(IORef Int)
  , contents   :: ![AnyHashTable]
  }

-- | A fresh 'MemoCache' with both counters at zero, tracking just @ht@.
mkCache :: AnyHashTable -> IO MemoCache
mkCache ht = do
  queries <- newIORef 0
  misses  <- newIORef 0
  pure (MemoCache queries misses [ht])

-- | Zero both counters, then call 'resetHashTable' on each tracked table.
resetCache :: MemoCache -> IO ()
resetCache c = do
  mapM_ (\ref -> writeIORef ref 0) [queryCount c, missCount c]
  mapM_ resetHashTable (contents c)
#else
-- | Without profiling there is nothing to track.
type MemoCache = ()
#endif

-- | Count one lookup against the given cache.
--   With @PROFILE_CACHES@ this increments the query counter; otherwise it does
--   nothing. The strict 'modifyIORef'' avoids accumulating a chain of @(+1)@
--   thunks inside the 'IORef' between metric reads.
bumpQueryCount :: MemoCache -> IO ()
#ifdef PROFILE_CACHES
bumpQueryCount c = modifyIORef' (queryCount c) (+1)
#else
bumpQueryCount _ = return ()
#endif


-- | Count one cache miss (a lookup that had to compute the value).
--   With @PROFILE_CACHES@ this increments the miss counter; otherwise it does
--   nothing. The strict 'modifyIORef'' keeps the counter evaluated instead of
--   building up unevaluated @(+1)@ thunks.
bumpMissCount :: MemoCache -> IO ()
#ifdef PROFILE_CACHES
bumpMissCount c = modifyIORef' (missCount c) (+1)
#else
bumpMissCount _ = return ()
#endif

--------------
---- Tags
--------------

-- | Label identifying a memo cache. Every memo table created with the same
--   tag is registered under it (see 'initMetrics'), so their statistics are
--   aggregated.
data MemoCacheTag = NameTag Text
  deriving ( Eq, Ord, Show, Generic )

-- | Generic-derived hashing, so tags can be used as hash-table keys.
instance Hashable MemoCacheTag

-- | Derive the tag used for the inner (second-argument) tables of 'memo2'
--   by appending @"-inner"@ to the name.
mkInnerTag :: MemoCacheTag -> MemoCacheTag
mkInnerTag (NameTag t) = NameTag (t <> "-inner")

-- | A tag pretty-prints as its bare name.
instance Pretty MemoCacheTag where
  pretty (NameTag t) = t

--------------
---- Global metrics store
--------------

#ifdef PROFILE_CACHES
-- NOINLINE guarantees this top-level 'unsafePerformIO' runs exactly once,
-- so every caller observes the same process-wide table.
{-# NOINLINE memoCaches #-}
-- | Process-wide registry mapping each tag to its 'MemoCache'.
memoCaches :: HT.CuckooHashTable MemoCacheTag MemoCache
memoCaches = unsafePerformIO HT.new
#endif

-- | Register a newly allocated memo table under @tag@ and return the
--   'MemoCache' its lookups should be counted against.
--
--   With @PROFILE_CACHES@:
--
--   * the first table under a tag gets a new entry with zeroed counters;
--   * later tables are prepended to the existing entry's @contents@ and share
--     its counters, so statistics are aggregated per tag.
--
--   Without @PROFILE_CACHES@ nothing is recorded.
initMetrics :: MemoCacheTag -> AnyHashTable -> IO MemoCache
#ifdef PROFILE_CACHES
initMetrics tag ht = do
    fresh <- mkCache ht
    HT.mutate memoCaches tag $ \existing ->
      -- Reuse the existing entry's counters if the tag was seen before.
      let c' = maybe fresh (\c -> c { contents = ht : contents c }) existing
      in (Just c', c')
#else
initMetrics _ _ = return ()
#endif

-- | Reset every registered cache: zero its counters and call
--   @resetHashTable@ on each of its tables.
--
--   NOTE: tables are only registered when built with @PROFILE_CACHES@.
--   Otherwise this is a no-op, and memoized results are not discarded.
resetAllCaches :: IO ()
#ifdef PROFILE_CACHES
resetAllCaches = HT.mapM_ (\(_, c) -> resetCache c) memoCaches
#else
resetAllCaches = return ()
#endif

#ifdef PROFILE_CACHES
-- | Snapshot the query and miss counters of every registered tag.
--   The result is in hash-table fold order, with no particular meaning.
getAllCacheMetrics :: IO [(MemoCacheTag, CacheMetrics)]
getAllCacheMetrics = HT.foldM step [] memoCaches
  where
    step acc (tag, c) = do
      queries <- readIORef (queryCount c)
      misses  <- readIORef (missCount c)
      return ((tag, CacheMetrics queries misses) : acc)

-- | Print one line per tag, in sorted order, as @(tag)\<TAB\>metrics@.
printAllCacheMetrics :: IO ()
printAllCacheMetrics = getAllCacheMetrics >>= mapM_ printEntry . sort
  where
    printEntry (tag, cm) = Text.putStrLn ("(" <> pretty tag <> ")\t" <> pretty cm)
#endif

-------------------------------------------------------------
------------------------ Memoization ------------------------
-------------------------------------------------------------


-- | Build a memoized version of @f@, in 'IO'.
--
--   Allocates a fresh cuckoo hash table, registers it under @tag@ through
--   'initMetrics', and returns a function that:
--
--   * looks its argument up in that table;
--   * on a miss, stores @f x@ and returns it.
--
--   Only the unevaluated thunk @f x@ is stored. It is forced by whoever first
--   demands the result, and later hits share the evaluated value.
--
--   Not thread-safe: the lookup and the insert are separate, unsynchronized
--   operations on a mutable table.
memoIO :: forall a b. (Eq a, Hashable a) => MemoCacheTag -> (a -> b) -> IO (a -> IO b)
memoIO tag f = do
    ht :: HT.CuckooHashTable a b <- HT.new
    cache <- initMetrics tag (AnyHashTable ht)
    let lookupOrCompute :: a -> IO b
        lookupOrCompute x = do
          bumpQueryCount cache
          HT.lookup ht x >>= maybe (computeAndStore x) return

        computeAndStore :: a -> IO b
        computeAndStore x = do
          bumpMissCount cache
          let r = f x
          HT.insert ht x r
          return r
    return lookupOrCompute


-- | Pure memoization wrapper around 'memoIO'.
--
--   The table is created once per evaluation of @memo tag f@: the @let@ sits
--   outside the returned lambda, so the table is shared by all calls of that
--   partially applied function. Bind @memo tag f@ once and reuse it; evaluating
--   it again allocates a new, empty table.
--
--   Uses 'unsafePerformIO' for both the table allocation and each lookup.
--   See the module header: this is deliberately quick-and-dirty and not
--   thread-safe.
memo :: (Eq a, Hashable a) => MemoCacheTag -> (a -> b) -> (a -> b)
memo tag f = let f' = unsafePerformIO (memoIO tag f)
             in \x -> unsafePerformIO (f' x)

-- | Memoize a two-argument function by currying.
--
--   The outer table, tagged @tag@, maps each @a@ to a function memoized on
--   @b@. Each such inner table is created lazily on the first use of its @a@
--   and registered under @'mkInnerTag' tag@. Inner-table statistics are
--   therefore aggregated under that single derived tag.
memo2 :: (Eq a, Hashable a, Eq b, Hashable b) => MemoCacheTag -> (a -> b -> c) -> a -> b -> c
memo2 tag f = memo tag (memo (mkInnerTag tag) . f)