{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.MemoUtils
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.MemoUtils
  ( -- * Hashtable-based memoization
    htmemo,
    htmemo2,
    htmemo3,
    htmup,
    htmemoFix,
  )
where

import Data.Function (fix)
import qualified Data.HashTable.IO as H
import Data.Hashable (Hashable)
import System.IO.Unsafe (unsafePerformIO)

-- | Hash table used as the memo cache by every memoizer in this module:
-- the basic (linear-probing, 'IO'-based) table from "Data.HashTable.IO".
type HashTable = H.BasicHashTable

-- | Function memoizer with mutable hash table.
--
-- @'htmemo' f@ allocates a fresh hash table and returns a function that,
-- for an argument @x@, first looks @x@ up in that table:
--
-- * on a hit, the stored value is returned;
-- * on a miss, the unevaluated thunk @f x@ is stored under @x@ and returned.
--
-- Later calls with an equal key get the same stored thunk, so @f x@ is
-- shared rather than recomputed. The table is captured by the returned
-- closure, so each application of 'htmemo' gets its own table.
--
-- 'unsafePerformIO' is used to present the mutable table behind a pure
-- interface.
--
-- NOTE(review): there is no locking around the table here. Calling the
-- memoized function from several threads at once is presumably unsafe;
-- confirm against the @hashtables@ documentation before relying on it
-- concurrently.
htmemo :: (Eq k, Hashable k) => (k -> a) -> k -> a
htmemo f = unsafePerformIO $ do
  -- The annotation only pins the table implementation to 'HashTable';
  -- key and value types are left to inference.
  cache <- H.new :: IO (HashTable k v)
  return $ \x -> unsafePerformIO $ do
    tryV <- H.lookup cache x
    case tryV of
      Just v -> return v
      Nothing -> do
        -- 'v' is bound lazily, so 'f x' is not evaluated inside this IO
        -- action. It is forced only when the caller demands the result,
        -- after the insert below has completed.
        let v = f x
        H.insert cache x v
        return v

-- | Lift a memoizer to work with one more argument.
--
-- @'htmup' mem f@ memoizes, with 'htmemo', the function @k -> mem (f k)@.
-- For a given key @k@, the result of @mem (f k)@ (typically an inner
-- memoized function) is stored in the outer table and reused on later
-- calls with an equal key.
htmup :: (Eq k, Hashable k) => (b -> c) -> (k -> b) -> (k -> c)
htmup mem f = htmemo (mem . f)

-- | Function memoizer with mutable hash table. Works on binary functions.
--
-- Defined as @'htmup' 'htmemo'@. The first argument selects a cached,
-- per-@k1@ function, which is itself memoized on the second argument.
htmemo2 ::
  (Eq k1, Hashable k1, Eq k2, Hashable k2) =>
  (k1 -> k2 -> a) ->
  (k1 -> k2 -> a)
htmemo2 = htmup htmemo

-- | Function memoizer with mutable hash table. Works on ternary functions.
--
-- Defined as @'htmup' 'htmemo2'@. The first argument selects a cached,
-- per-@k1@ binary function, which is memoized on the remaining two
-- arguments.
htmemo3 ::
  (Eq k1, Hashable k1, Eq k2, Hashable k2, Eq k3, Hashable k3) =>
  (k1 -> k2 -> k3 -> a) ->
  (k1 -> k2 -> k3 -> a)
htmemo3 = htmup htmemo2

-- | Memoizing recursion. Use like 'fix'.
--
-- @'htmemoFix' h@ ties the knot @g = 'htmemo' (h g)@. Every recursive call
-- that @h@ makes through its first argument therefore goes through the
-- same memo table. For example:
--
-- > fib :: Int -> Integer
-- > fib = htmemoFix $ \rec n ->
-- >   if n < 2 then toInteger n else rec (n - 1) + rec (n - 2)
htmemoFix :: (Eq k, Hashable k) => ((k -> a) -> (k -> a)) -> k -> a
htmemoFix h = fix (htmemo . h)