{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Application.TermSearch.Evaluation
( runBenchmark
) where
import Control.Monad ( forM_ )
import Data.Time ( diffUTCTime
, getCurrentTime
)
import System.IO ( hFlush
, stdout
)
import System.Timeout
import qualified Data.Bifunctor as Bi
import qualified Data.Text as Text
import qualified Data.Text.IO as Text
import Data.ECTA
import Data.ECTA.Term
import Application.TermSearch.Dataset
import Application.TermSearch.TermSearch
import Application.TermSearch.Type
import Application.TermSearch.Utils
import qualified Data.Interned.Extended.HashTableBased as Interned
import Data.Interned.Extended.HashTableBased ( cache )
import qualified Data.Memoization as Memoization
import Data.Text.Extended.Pretty
-- | Reduce a node with 'reduceFully' and return the result.
--
-- When compiled with @PROFILE_CACHES@, this also prints the following before
-- returning:
--
-- * the node, edge and max-indegree counts of the reduced node;
-- * the memoization cache metrics;
-- * the interning-cache metrics for 'Node' and 'Edge'.
--
-- Without @PROFILE_CACHES@ it only flushes stdout. The reduction is then not
-- forced here and is returned as an unevaluated thunk.
printCacheStatsForReduction :: Node -> IO Node
printCacheStatsForReduction n = do
    let n' = reduceFully n
#ifdef PROFILE_CACHES
    -- Computing these counts presumably forces n', so the cache metrics
    -- printed afterwards should include the effects of this reduction.
    Text.putStrLn $ "Nodes: "        <> Text.pack (show (nodeCount   n'))
    Text.putStrLn $ "Edges: "        <> Text.pack (show (edgeCount   n'))
    Text.putStrLn $ "Max indegree: " <> Text.pack (show (maxIndegree n'))
    Memoization.printAllCacheMetrics
    Text.putStrLn =<< (pretty <$> Interned.getMetrics (cache @Node))
    Text.putStrLn =<< (pretty <$> Interned.getMetrics (cache @Edge))
    Text.putStrLn ""
#endif
    hFlush stdout
    return n'
-- | Run one synthesis benchmark and print the wall-clock time it took.
--
-- Steps:
--
-- * The benchmark's argument types and result type are turned into ECTA
--   nodes with 'typeToFta'.
-- * For every term size from 1 up to the benchmark's @size@, @synthesize@
--   enumerates and pretty-prints the candidate terms.
-- * The whole loop runs under 'timeout'.
--
-- Parameters:
--
-- * @ablation@ selects the pipeline. 'NoReduction' and 'NoOptimize' print
--   directly from the filtered node. Every other value first reduces the node
--   and then 'refold's it.
-- * @limit@ is multiplied by @10^6@ before being passed to 'timeout', which
--   takes microseconds, so it is a budget in seconds.
runBenchmark :: Benchmark -> AblationType -> Int -> IO ()
runBenchmark (Benchmark name size sol args res) ablation limit = do
    putStrLn $ "Running benchmark " ++ Text.unpack name

    let argNodes = map (Bi.bimap Symbol typeToFta) args
        resNode  = typeToFta res

    start <- getCurrentTime
    -- NOTE(review): the 'Maybe ()' result of 'timeout' is discarded. A run
    -- that hits the limit is therefore reported the same way as one that
    -- finished; only the printed time hints at a timeout.
    _ <- timeout (limit * 10 ^ (6 :: Int)) $
        forM_ [1 .. size] $ synthesize argNodes resNode
    end <- getCurrentTime
    -- 'print' on a String emits the line wrapped in double quotes. This is
    -- kept unchanged because result-parsing scripts may rely on the format.
    print $ "Time: " ++ show (diffUTCTime end start)
    hFlush stdout
  where
    -- Enumerate and print every term of size @sz@ that is built from the
    -- given arguments and filtered against the result-type node.
    synthesize :: [Argument] -> Node -> Int -> IO ()
    synthesize argNodes resNode sz = do
        let anyArg = Node (map (uncurry constArg) argNodes)
        -- The bang forces the filtered node to WHNF before printing starts.
        let !filterNode =
                filterType (relevantTermsOfSize anyArg argNodes sz) resNode
        case ablation of
            NoReduction -> printTerms filterNode
            NoOptimize  -> printTerms filterNode
            _ -> do
#ifdef PROFILE_CACHES
                reducedNode <- printCacheStatsForReduction filterNode
#else
                reducedNode <- reduceFullyAndLog filterNode
#endif
                printTerms (refold reducedNode)

    -- Shared printing step for every ablation branch. It passes along the
    -- ablation mode and the benchmark's solution term (after 'substTerm');
    -- how that term is used is up to 'prettyPrintAllTerms'.
    printTerms :: Node -> IO ()
    printTerms = prettyPrintAllTerms ablation (substTerm sol)