{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-} -- Scheduler
{-# LANGUAGE TypeFamilies #-}
{-|

Definition of 'Scheduler' as a way to control application of rewrite rules.

The 'BackoffScheduler' is a scheduler which implements exponential rule backoff
and is used by default in 'Data.Equality.Saturation.equalitySaturation'

-}
module Data.Equality.Saturation.Scheduler
    ( Scheduler(..), BackoffScheduler
    ) where

import qualified Data.IntMap.Strict as IM
import Data.Equality.Matching

-- | A 'Scheduler' determines whether a certain rewrite rule is banned from
-- being used based on statistics it defines and collects on applied rewrite
-- rules.
class Scheduler s where
    -- | The statistics this scheduler keeps for each rewrite rule. Values of
    -- this type are stored in the stats map handled by 'updateStats' and
    -- inspected by 'isBanned'.
    type Stat s

    -- | Scheduler: update stats
    updateStats :: Int                -- ^ Iteration we're in
                -> Int                -- ^ Index of rewrite rule we're updating
                -> Maybe (Stat s)     -- ^ Current stat for this rewrite rule (we already got it so no point in doing a lookup again)
                -> IM.IntMap (Stat s) -- ^ The current stats map
                -> [Match]            -- ^ The list of matches resulting from matching this rewrite rule
                -> IM.IntMap (Stat s) -- ^ The updated map with new stats

    -- | Decide whether a rewrite rule is banned based on its stats and the current iteration
    isBanned :: Int -- ^ Iteration we're in
             -> Stat s -- ^ Stats for the rewrite rule
             -> Bool -- ^ 'True' if the rule is banned, i.e. it should /not/ be applied

-- | A 'Scheduler' that implements exponential rule backoff.
--
-- For each rewrite, there exists a configurable initial match limit. If a rewrite
-- search yields more than this limit, then we ban this rule for a number of
-- iterations, double its limit, and double the time it will be banned next time.
--
-- This seems effective at preventing explosive rules like associativity from
-- taking an unfair amount of resources.
--
-- Originally in [egg](https://docs.rs/egg/0.6.0/egg/struct.BackoffScheduler.html)
data BackoffScheduler

-- | Backoff scheduling: the statistics kept per rewrite rule are 'BoSchStat'
-- values, stored in the stats map under the rewrite rule's index.
instance Scheduler BackoffScheduler where
    type Stat BackoffScheduler = BoSchStat

    -- Ban rewrite rule @rw@ (or extend/renew its ban) when the size of this
    -- iteration's matches exceeds the rule's current threshold; otherwise the
    -- stats map is returned untouched.
    updateStats i rw currentStat stats matches
        | total_len > threshold = IM.alter updateBans rw stats
        | otherwise             = stats
      where
        -- TODO: Overall difficult, and buggy at the moment.
        -- NOTE(review): this sums the number of entries in each match's
        -- substitution, not the number of matches -- confirm that this is
        -- the intended measure.
        total_len :: Int
        total_len = sum (map (length . matchSubst) matches)

        -- Match limit and ban length (in iterations) for a rule that has
        -- never been banned.
        defaultMatchLimit, defaultBanLength :: Int
        defaultMatchLimit = 1000
        defaultBanLength  = 10

        -- How many times this rule has been banned so far, according to the
        -- stat handed in by the caller (0 when it has no stat yet).
        bannedN :: Int
        bannedN = maybe 0 timesBanned currentStat

        -- Both the match limit and the ban length double with every
        -- previous ban.
        backoff :: Int
        backoff = 2 ^ bannedN

        threshold :: Int
        threshold = defaultMatchLimit * backoff

        ban_length :: Int
        ban_length = defaultBanLength * backoff

        -- Record a ban lasting until iteration @i + ban_length@, creating
        -- the entry on the first ban and bumping the ban counter otherwise.
        updateBans :: Maybe BoSchStat -> Maybe BoSchStat
        updateBans = \case
            Nothing        -> Just (BSS (i + ban_length) 1)
            Just (BSS _ n) -> Just (BSS (i + ban_length) (n + 1))
    {-# SCC updateStats #-}

    -- A rule is banned while the current iteration is strictly below the
    -- iteration recorded in 'bannedUntil'.
    isBanned i s = i < bannedUntil s


-- | Statistics kept by the 'BackoffScheduler' for a single rewrite rule.
--
-- Both fields are strict and unpacked, so a 'BSS' value never holds thunks.
data BoSchStat = BSS
    { bannedUntil :: {-# UNPACK #-} !Int
      -- ^ Iteration bound of the current ban: the rule counts as banned in
      -- every iteration strictly below this value
    , timesBanned :: {-# UNPACK #-} !Int
      -- ^ Number of times the rule has been banned so far
    } deriving Show