{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Memoization (
MemoCacheTag(..)
, resetAllCaches
#ifdef PROFILE_CACHES
, getAllCacheMetrics
, printAllCacheMetrics
#endif
, memoIO
, memo
, memo2
) where
import Data.Hashable ( Hashable )
import qualified Data.HashTable.IO as HT
import Data.Text ( Text )
import GHC.Generics ( Generic )
import System.IO.Unsafe ( unsafePerformIO )

import Data.HashTable.Extended
import Data.Text.Extended.Pretty

#ifdef PROFILE_CACHES
import Data.IORef ( IORef, newIORef, readIORef, writeIORef, modifyIORef, modifyIORef' )
import Data.List ( sort )
import Data.Memoization.Metrics ( CacheMetrics(CacheMetrics) )
import qualified Data.Text.IO as Text
#endif
#ifdef PROFILE_CACHES
-- | Profiling state shared by every memo table registered under one tag:
-- lookup/miss counters plus the tables themselves, so they can be cleared
-- together.
data MemoCache = MemoCache
  { queryCount :: !(IORef Int)     -- ^ lookups performed (bumped on every call)
  , missCount  :: !(IORef Int)     -- ^ lookups that found no stored result
  , contents   :: ![AnyHashTable]  -- ^ all tables registered under the tag
  }

-- | Fresh, zeroed counters tracking exactly one table.
mkCache :: AnyHashTable -> IO MemoCache
mkCache table = do
  queries <- newIORef 0
  misses  <- newIORef 0
  return MemoCache { queryCount = queries, missCount = misses, contents = [table] }

-- | Zero both counters, then clear every table tracked by this entry via
-- 'resetHashTable'.
resetCache :: MemoCache -> IO ()
resetCache entry = do
  mapM_ (`writeIORef` 0) [queryCount entry, missCount entry]
  mapM_ resetHashTable (contents entry)
#else
-- | Without profiling, the per-table handle carries no information.
type MemoCache = ()
#endif
-- | Record one lookup against a memo table. A no-op unless the module is
-- compiled with @PROFILE_CACHES@.
bumpQueryCount :: MemoCache -> IO ()
#ifdef PROFILE_CACHES
-- Strict update: the lazy 'modifyIORef' would accumulate a chain of (+1)
-- thunks (one per lookup) that is only forced when the metrics are read.
bumpQueryCount c = modifyIORef' (queryCount c) (+1)
#else
bumpQueryCount _ = return ()
#endif
-- | Record one lookup that found no stored result. A no-op unless the module
-- is compiled with @PROFILE_CACHES@.
bumpMissCount :: MemoCache -> IO ()
#ifdef PROFILE_CACHES
-- Strict update: the lazy 'modifyIORef' would accumulate a chain of (+1)
-- thunks (one per miss) that is only forced when the metrics are read.
bumpMissCount c = modifyIORef' (missCount c) (+1)
#else
bumpMissCount _ = return ()
#endif
-- | Names a memo table. Under @PROFILE_CACHES@, all tables created with the
-- same tag share one profiling entry (counters and reset list).
data MemoCacheTag = NameTag Text
  deriving ( Eq, Ord, Show, Generic )

-- Uses the class's default, 'Generic'-based hashing.
instance Hashable MemoCacheTag
-- | Tag used by 'memo2' for its inner (second-argument) tables: the outer
-- tag's name with @"-inner"@ appended.
mkInnerTag :: MemoCacheTag -> MemoCacheTag
mkInnerTag (NameTag t) = NameTag (t <> "-inner")
-- | A tag pretty-prints as its bare name.
instance Pretty MemoCacheTag where
  pretty (NameTag t) = t
#ifdef PROFILE_CACHES
-- | Process-wide registry mapping each tag to its profiling entry.
-- Created through 'unsafePerformIO'; the NOINLINE pragma prevents GHC from
-- inlining the definition and thereby creating more than one registry.
memoCaches :: HT.CuckooHashTable MemoCacheTag MemoCache
memoCaches = unsafePerformIO HT.new
{-# NOINLINE memoCaches #-}
#endif
-- | Register a newly created table under @tag@ and return the handle that
-- 'bumpQueryCount' / 'bumpMissCount' update.
--
-- With @PROFILE_CACHES@: the first table for a tag gets a fresh entry; later
-- tables with the same tag are prepended to the existing entry's 'contents'
-- and share its counters. Without profiling nothing is registered.
initMetrics :: MemoCacheTag -> AnyHashTable -> IO MemoCache
#ifdef PROFILE_CACHES
initMetrics tag ht = do
  fresh <- mkCache ht
  -- 'maybe' instead of '\case' so this branch does not depend on LambdaCase,
  -- which this file's LANGUAGE pragmas do not enable.
  HT.mutate memoCaches tag $
    maybe (Just fresh, fresh)
          (\c -> let c' = c { contents = ht : contents c }
                 in (Just c', c'))
#else
initMetrics _ _ = return ()
#endif
-- | With @PROFILE_CACHES@: zero every tag's counters and clear every
-- registered memo table.
--
-- Without profiling, tables are never registered (see 'initMetrics'), so this
-- does nothing — existing memo tables are NOT cleared.
resetAllCaches :: IO ()
#ifdef PROFILE_CACHES
resetAllCaches = HT.mapM_ (resetCache . snd) memoCaches
#else
resetAllCaches = return ()
#endif
#ifdef PROFILE_CACHES
-- | Snapshot the lookup/miss counters of every registered tag. The list is
-- accumulated by prepending during a fold over the registry, so its order is
-- unspecified; sort it if a stable order is needed.
getAllCacheMetrics :: IO [(MemoCacheTag, CacheMetrics)]
getAllCacheMetrics = HT.foldM collect [] memoCaches
  where
    collect acc (tag, entry) = do
      queries <- readIORef (queryCount entry)
      misses  <- readIORef (missCount entry)
      return ((tag, CacheMetrics queries misses) : acc)

-- | Print one tab-separated line per tag, in sorted order: the tag in
-- parentheses followed by its pretty-printed metrics.
printAllCacheMetrics :: IO ()
printAllCacheMetrics = getAllCacheMetrics >>= mapM_ printEntry . sort
  where
    printEntry (tag, cm) = Text.putStrLn ("(" <> pretty tag <> ")\t" <> pretty cm)
#endif
-- | Create a fresh memo table for @f@, register it under @tag@ (see
-- 'initMetrics'), and return an action that serves results from the table.
-- @f@ is applied only on a miss.
--
-- On a miss, the stored value is the unevaluated thunk @f x@. Its evaluation
-- is therefore shared by all later hits for the same key, and happens only if
-- some caller forces the result.
--
-- NOTE(review): lookup and insert are separate operations on a mutable hash
-- table. As far as I know, the hashtables package does not make these safe
-- under concurrent mutation. Confirm before calling the returned function
-- from multiple threads.
memoIO :: forall a b. (Eq a, Hashable a) => MemoCacheTag -> (a -> b) -> IO (a -> IO b)
memoIO tag f = do
  ht :: HT.CuckooHashTable a b <- HT.new
  cache <- initMetrics tag (AnyHashTable ht)
  let f' :: a -> IO b
      f' x = do
        bumpQueryCount cache
        v <- HT.lookup ht x
        case v of
          Nothing -> do
            bumpMissCount cache
            let r = f x
            HT.insert ht x r
            return r
          Just r -> return r
  return f'
-- | Pure memoization: 'memoIO' hidden behind 'unsafePerformIO'.
--
-- The table is created lazily, on the first application of the returned
-- function, once per evaluation of @memo tag f@. Bind the partially applied
-- function once (for example at top level) and reuse it; re-evaluating
-- @memo tag f@ yields a new, empty table.
memo :: (Eq a, Hashable a) => MemoCacheTag -> (a -> b) -> (a -> b)
memo tag f =
  let f' = unsafePerformIO (memoIO tag f)
  in \x -> unsafePerformIO (f' x)
-- | Memoize a two-argument function as nested tables. The outer table, tagged
-- @tag@, is keyed on the first argument. Each of its entries is an inner memo
-- table keyed on the second argument, tagged @'mkInnerTag' tag@ (so all inner
-- tables share one profiling entry).
memo2 :: (Eq a, Hashable a, Eq b, Hashable b) => MemoCacheTag -> (a -> b -> c) -> a -> b -> c
memo2 tag f = memo tag (memo (mkInnerTag tag) . f)