{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}
module Language.GLSL.Optimizer.Deinline where

import           Control.Applicative                       (ZipList (..))
import           Control.Arrow                             ((&&&))
import qualified Data.List                                 as List
import           Data.Maybe                                (listToMaybe)
import           Debug.Trace                               (trace)
import           Language.GLSL.AST
import           Language.GLSL.ConstExpr                   (ConstExprs,
                                                            collectConstExprs)
import qualified Language.GLSL.Optimizer.FunctionGenerator as FunctionGenerator
import           Language.GLSL.PrettyPrint                 (pp, ppTopDecl)
import qualified Language.GLSL.StructuralEquality          as StructuralEquality


-- | Tuning knobs for the de-inlining search.
data Config = Config
  { -- | Maximum number of statements to look ahead for equality.
    --
    --   Increasing this potentially finds more de-inlining opportunities but
    --   also drastically increases the cost of not finding any. This number
    --   does not matter if we always find an opportunity quickly.
    maxLookahead :: Int

    -- | Minimum number of times a piece of code needs to appear for it to be
    --   worth extracting into a function.
  , minRepeats   :: Int

    -- | Maximum number of initial repeats to use for maximization. If we find
    --   enough, we're happy and stop looking. Most of the time we'll find
    --   fewer than 10, but sometimes a bit of code is repeated a lot which
    --   would slow down the algorithm significantly.
  , maxRepeats   :: Int

    -- | Number of statements in the sliding window.
  , windowSize   :: Int
  }


-- | Settings used when the caller has no specific tuning in mind.
defaultConfig :: Config
defaultConfig = Config
  { windowSize   = 10
  , maxRepeats   = 10
  , minRepeats   = 3
  , maxLookahead = 200
  }


-- | Run de-inlining over each top-level declaration of a program. The
-- version is carried over untouched.
pass :: Annot a => Config -> GLSL a -> GLSL a
pass config (GLSL version decls) =
  GLSL version [diTopDecl config decl | decl <- decls]

-- | De-inline the body of a procedure declaration. Any other kind of
-- top-level declaration is returned as-is.
diTopDecl :: Annot a => Config -> TopDecl a -> TopDecl a
diTopDecl config decl = case decl of
  ProcDecl name params stmts -> ProcDecl name params (diStmts config stmts)
  _                          -> decl


-- | Look for one repeated run of statements (via 'findBody') and, if one is
-- found, remove its occurrences (via 'deleteBody'). Both use the constant
-- expressions collected from the same statement list.
--
-- When nothing is found the statements are returned unchanged.
diStmts :: Annot a => Config -> [StmtAnnot a] -> [StmtAnnot a]
diStmts config ss = maybe ss extract (findBody config constExprs ss)
  where
    constExprs = Just (collectConstExprs ss)

    extract body =
      let -- The function that would replace the occurrences; currently only
          -- built for debugging and not emitted.
          _newProc = pp ppTopDecl (FunctionGenerator.makeFunction body)
          message  = "found one! length = " <> show (length body)
            -- <> "\n" <> ppl ppStmtAnnot body <> "\n\n"
            -- <> newProc
      in trace message (deleteBody constExprs body ss)


-- | Remove every non-overlapping occurrence of 'body' from a statement list.
--
-- The list is scanned left to right. At each position, if the next
-- @length body@ statements are structurally equal to 'body' (according to
-- 'StructuralEquality.eqStmtAnnots' under the given constant expressions),
-- they are dropped and scanning resumes right after them. Otherwise the
-- current statement is kept and scanning moves on by one.
--
-- Every position is tried. That includes the head of the list and the
-- position directly after a removed occurrence, so back-to-back copies are
-- all removed. A trailing fragment shorter than 'body' never counts as a
-- match. An empty 'body' leaves the list unchanged.
deleteBody :: Maybe ConstExprs -> [StmtAnnot a] -> [StmtAnnot a] -> [StmtAnnot a]
deleteBody _ [] = id
deleteBody ce body = go []
  where
    n = length body

    -- 'acc' holds the kept statements in reverse order.
    go acc [] = reverse acc
    go acc stmts@(s : rest)
      -- 'n' >= 1 here, so dropping always makes progress.
      | matchesHere = go acc (drop n stmts)
      | otherwise   = go (s : acc) rest
      where
        candidate   = take n stmts
        -- 'zip' would silently truncate, so require a full-length candidate.
        matchesHere =
          length candidate == n
            && StructuralEquality.eqStmtAnnots ce (zip body candidate)


-- | Find a run of statements that is repeated often enough to be worth
-- extracting into a function.
--
-- A window of 'windowSize' statements slides over the list one position at a
-- time, starting at the first statement. At each position the search does
-- three things:
--
-- 1. Look for a structurally equal copy of the window that starts within the
--    next 'maxLookahead' positions. This is a cheap check before the full
--    search.
-- 2. Collect up to 'maxRepeats' copies of the window from the rest of the
--    list. The window itself is not one of them. Counting the window, the
--    code must appear at least 'minRepeats' times.
-- 3. Extend the window statement by statement for as long as the window and
--    all collected copies still agree ('allEqual'), and return that run.
--
-- Only complete copies count. A trailing fragment shorter than the window is
-- never a repeat; it would otherwise also cap the maximised run at its own
-- length.
--
-- NOTE(review): copies may overlap the window or each other, and the
-- maximised run may extend past the start of the next copy, whereas
-- 'deleteBody' removes non-overlapping occurrences only.
findBody :: Config -> Maybe ConstExprs -> [StmtAnnot a] -> Maybe [StmtAnnot a]
findBody _ _ [] = Nothing
findBody config@Config{..} ce stmts@(_ : rest)
  -- Too few statements left to fill a window, and later positions have
  -- even fewer, so stop.
  | length window < windowSize = Nothing
  | otherwise =
      case firstRepeat of
        -- No copy nearby: slide the window one statement further.
        Nothing -> next
        -- A copy exists, but there are too few to pay for a function.
        -- 'allRepeats' excludes the window, hence the "- 1".
        Just _ | length (take (minRepeats - 1) allRepeats) < minRepeats - 1 ->
          next
        -- Enough copies: extract the longest run they all share.
        Just _ ->
          case listToMaybe maximised of
            Just body@(_ : _) -> Just body
            _                 -> next
  where
    next = findBody config ce rest

    -- The statements currently under the sliding window.
    window = take windowSize stmts

    -- Every later start position. Taking the tails of 'rest' rather than of
    -- 'stmts' keeps the window from trivially matching itself.
    candidates = List.tails rest

    -- A candidate matches only when the comparison covers the whole window.
    -- 'zip' truncates near the end of the list (down to the empty list, which
    -- is trivially "equal"), so check the length explicitly.
    isSimilar l =
      not (null l)
        && length l == length window
        && StructuralEquality.eqStmtAnnots ce l

    -- Try to find a copy of the window somewhere in the lookahead range.
    firstRepeat =
      List.find isSimilar
      . map (zip window)
      . take maxLookahead
      $ candidates

    -- All later copies of the window, capped at 'maxRepeats' so heavily
    -- repeated code does not slow the search down.
    allRepeats =
      take maxRepeats
      . map fst
      . filter (isSimilar . snd)
      . map (id &&& zip window)
      $ candidates

    -- Longest common prefix of the window position and all its copies:
    -- compare column by column and keep columns while they all agree.
    maximised =
      transpose
      . takeWhile (allEqual ce)
      . transpose
      $ stmts : allRepeats


-- | Turn a list of rows into a list of columns. The result is truncated to
-- the shortest row (unlike 'Data.List.transpose', which keeps ragged
-- columns).
--
-- An empty list of rows has no columns. Without the special case,
-- @traverse ZipList []@ would be @pure []@, i.e. an infinite list of empty
-- columns, and any consumer forcing its length would diverge.
transpose :: [[a]] -> [[a]]
transpose []   = []
transpose rows = getZipList (traverse ZipList rows)

-- | True when every statement in the list is structurally equal to the first
-- one, compared with 'StructuralEquality.eqStmtAnnot' under the given
-- constant expressions. An empty list is trivially all-equal.
allEqual :: Maybe ConstExprs -> [StmtAnnot a] -> Bool
allEqual ce stmts = case stmts of
  []             -> True
  first : others ->
    not (any (not . StructuralEquality.eqStmtAnnot ce first) others)