{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

module Application.TermSearch.Evaluation
    ( runBenchmark
    ) where

import           Control.Monad                  ( forM_ )
import           Data.Time                      ( diffUTCTime
                                                , getCurrentTime
                                                )
import           System.IO                      ( hFlush
                                                , stdout
                                                )
import           System.Timeout

import qualified Data.Bifunctor                as Bi
import qualified Data.Text                     as Text
import qualified Data.Text.IO                  as Text

import           Data.ECTA
import           Data.ECTA.Term

import           Application.TermSearch.Dataset
import           Application.TermSearch.TermSearch
import           Application.TermSearch.Type
import           Application.TermSearch.Utils

import qualified Data.Interned.Extended.HashTableBased as Interned
import           Data.Interned.Extended.HashTableBased ( cache )
import qualified Data.Memoization                      as Memoization
import           Data.Text.Extended.Pretty

-- | Fully reduce a node with 'reduceFully' and hand back the reduced node.
--
-- When built with the @PROFILE_CACHES@ CPP flag, this also prints the
-- following before returning:
--
-- * the reduced node's node count, edge count and maximum indegree;
-- * all memoization cache metrics;
-- * the interning-cache metrics for 'Node' and 'Edge', pretty-printed;
-- * a trailing blank line.
--
-- Without that flag nothing here demands the reduction; the reduced node is
-- returned unevaluated.
--
-- In both configurations stdout is flushed before returning.
printCacheStatsForReduction :: Node -> IO Node
printCacheStatsForReduction node = do
    let reduced = reduceFully node
#ifdef PROFILE_CACHES
    Text.putStrLn ("Nodes: "        <> Text.pack (show (nodeCount   reduced)))
    Text.putStrLn ("Edges: "        <> Text.pack (show (edgeCount   reduced)))
    Text.putStrLn ("Max indegree: " <> Text.pack (show (maxIndegree reduced)))
    Memoization.printAllCacheMetrics
    Interned.getMetrics (cache @Node) >>= Text.putStrLn . pretty
    Interned.getMetrics (cache @Edge) >>= Text.putStrLn . pretty
    Text.putStrLn ""
#endif
    hFlush stdout
    pure reduced

-- | Run a single term-search benchmark and print how long it took.
--
-- Setup:
--
-- * Each argument's type is translated into an ECTA node with 'typeToFta',
--   and its name is wrapped in 'Symbol'.
-- * The result type is translated the same way.
--
-- Synthesis is then attempted for each size from 1 up to the benchmark's
-- size bound, in increasing order, all under one 'timeout'. The timeout
-- budget is @limit * 10^6@; @limit@ is presumably in seconds, since
-- 'timeout' takes microseconds (TODO confirm with callers).
--
-- Whether the timeout fired is not reported; the elapsed wall-clock time is
-- printed either way.
runBenchmark :: Benchmark -> AblationType -> Int -> IO ()
runBenchmark (Benchmark name size sol args res) ablation limit = do
    putStrLn ("Running benchmark " ++ Text.unpack name)

    let argNodes     = Bi.bimap Symbol typeToFta <$> args
        resNode      = typeToFta res
        budgetMicros = limit * 10 ^ (6 :: Int)

    started  <- getCurrentTime
    _        <- timeout budgetMicros (forM_ [1 .. size] (synthesize argNodes resNode))
    finished <- getCurrentTime
    -- 'print' on a String emits it as a quoted Haskell string literal.
    print ("Time: " ++ show (finished `diffUTCTime` started))
    hFlush stdout
  where
    -- For one size @sz@:
    --
    -- 1. Combine the argument nodes into a choice node of 'constArg' edges.
    -- 2. Build candidates with 'relevantTermsOfSize' and keep those that
    --    'filterType' accepts against @resNode@.
    -- 3. Unless the ablation is 'NoReduction' or 'NoOptimize', reduce the
    --    node (via 'reduceStep') and 'refold' it.
    -- 4. Pass the resulting node to 'prettyPrintAllTerms', together with the
    --    ablation and @substTerm sol@.
    synthesize :: [Argument] -> Node -> Int -> IO ()
    synthesize argNodes resNode sz = do
      let anyArg = Node (map (uncurry constArg) argNodes)
      -- The bang forces the filtered node (to WHNF) before branching.
      let !filterNode = filterType (relevantTermsOfSize anyArg argNodes sz) resNode
      toPrint <- case ablation of
          NoReduction -> pure filterNode
          NoOptimize  -> pure filterNode
          _           -> refold <$> reduceStep filterNode
      prettyPrintAllTerms ablation (substTerm sol) toPrint

    -- Reduction used by 'synthesize'. It also prints cache statistics when
    -- built with PROFILE_CACHES.
    reduceStep :: Node -> IO Node
#ifdef PROFILE_CACHES
    reduceStep = printCacheStatsForReduction
#else
    reduceStep = reduceFullyAndLog
#endif