{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}
module Language.GLSL.Optimizer.Deinline where

import           Control.Applicative                       (ZipList (..))
import           Control.Arrow                             ((&&&))
import qualified Data.List                                 as List
import           Data.Maybe                                (listToMaybe)
import           Debug.Trace                               (trace)
import           Language.GLSL.ConstExpr                   (ConstExprs,
                                                            collectConstExprs)
import qualified Language.GLSL.Optimizer.FunctionGenerator as FunctionGenerator
import qualified Language.GLSL.StructuralEquality          as StructuralEquality
import           Language.GLSL.Types


-- | Tuning parameters for the de-inlining pass.
--
-- Field order matters: it fixes the argument order of the positional
-- 'Config' constructor.
data Config = Config
  { maxLookahead :: Int
    -- ^ How many statements ahead of the current window to scan for a
    --   structurally equal copy of that window.
    --
    --   A larger value can uncover more de-inlining opportunities, but it
    --   makes every unsuccessful search much more expensive. It has little
    --   effect when a match tends to be found early.
  , minRepeats   :: Int
    -- ^ How often a stretch of code has to occur before extracting it into
    --   a function is considered worthwhile.
  , maxRepeats   :: Int
    -- ^ Upper bound on the number of repeats collected before the matched
    --   code is maximised. Collecting stops once this many are found.
    --   Usually far fewer turn up, but heavily repeated code would otherwise
    --   slow the search down considerably.
  , windowSize   :: Int
    -- ^ Length, in statements, of the sliding window that is compared
    --   against the rest of the code.
  }


-- | Default tuning: a 10-statement window, a look-ahead of 200 statements,
-- at least 3 occurrences required, and at most 10 repeats collected.
defaultConfig :: Config
defaultConfig = Config
  { windowSize   = 10
  , maxLookahead = 200
  , minRepeats   = 3
  , maxRepeats   = 10
  }


-- | Entry point of the de-inlining pass. Every top-level declaration is
-- processed with 'diTopDecl'; the program's version is carried over as-is.
pass :: Annot a => Config -> GLSL a -> GLSL a
pass config (GLSL version decls) =
  GLSL version [ diTopDecl config decl | decl <- decls ]

-- | Run the statement-level de-inliner ('diStmts') on the body of a
-- procedure declaration. Any other kind of top-level declaration is passed
-- through untouched.
diTopDecl :: Annot a => Config -> TopDecl a -> TopDecl a
diTopDecl config decl =
  case decl of
    ProcDecl fn params body -> ProcDecl fn params (diStmts config body)
    _                       -> decl


-- | Look for a single repeated run of statements in 'ss' (using 'findBody')
-- and, when one is found, remove its occurrences (using 'deleteBody').
-- When 'findBody' finds nothing, 'ss' is returned unchanged. At most one
-- body is handled per call.
--
-- A debug 'trace' reports the length of the body that was found.
--
-- NOTE(review): the occurrences are deleted, but nothing here inserts a call
-- to the generated function. '_newProc' is built only for the (currently
-- commented-out) debug output. As written, this alters program semantics;
-- confirm this is intentional work in progress.
diStmts :: Annot a => Config -> [StmtAnnot a] -> [StmtAnnot a]
diStmts config ss = maybe ss extract (findBody config ce ss)
  where
    -- Constant expressions of the whole statement list; used as context for
    -- structural equality.
    ce = collectConstExprs ss

    extract body =
      let _newProc = pp ppTopDecl (FunctionGenerator.makeFunction body)
          message  = "found one! length = " <> show (length body)
            -- <> "\n" <> ppl ppStmtAnnot body <> "\n\n"
            -- <> newProc
      in trace message (deleteBody ce body ss)


-- | Remove all non-overlapping occurrences of 'body' from a statement list,
-- scanning from left to right.
--
-- An occurrence is a run of exactly @length body@ consecutive statements
-- that is structurally equal to 'body' under 'StructuralEquality.eqStmtAnnots'
-- in the constant-expression context 'ce'. Occurrences may start anywhere,
-- including at the very first statement. A trailing run shorter than 'body'
-- never matches, even if it equals a prefix of 'body'.
--
-- An empty 'body' removes nothing, and the list is returned unchanged.
deleteBody :: ConstExprs -> [StmtAnnot a] -> [StmtAnnot a] -> [StmtAnnot a]
deleteBody _  []   stmts = stmts
deleteBody ce body stmts = go [] stmts
  where
    bodyLen = length body

    -- 'acc' holds the kept statements in reverse order.
    go acc [] = reverse acc
    go acc rest@(s:ss)
      | matchesHere = go acc (drop bodyLen rest)
      | otherwise   = go (s:acc) ss
      where
        -- Take exactly as many statements as the body has. 'zip' would
        -- otherwise truncate silently, so a short tail that equals only a
        -- prefix of 'body' would count as a full match.
        candidate   = take bodyLen rest
        matchesHere =
          length candidate == bodyLen
            && StructuralEquality.eqStmtAnnots ce (zip body candidate)


-- | Scan a statement list for a stretch of code that repeats often enough to
-- be worth extracting into a function.
--
-- At each step the first statement is dropped. A window of 'windowSize'
-- statements is then taken from the front of the remainder and compared
-- against later positions. The first window that qualifies is extended
-- ("maximised") for as long as all of its repeats stay pointwise equal, and
-- the extended statements from the window's position are returned.
--
-- Returns 'Nothing' once fewer than 'windowSize' statements remain.
findBody :: Config -> ConstExprs -> [StmtAnnot a] -> Maybe [StmtAnnot a]
findBody _ _ [] = Nothing
findBody Config{..} _ (_:ss)
  -- Not enough statements left to fill a window. The check is bounded by
  -- 'windowSize' rather than using 'length ss', which would make each
  -- recursive step O(n) and the whole scan quadratic.
  | length (take windowSize ss) < windowSize = Nothing
findBody config@Config{..} ce (_:ss)
  -- Nothing similar within the look-ahead range: move on.
  | not hasNearbyRepeat = findBody config ce ss
  -- A match exists, but there are too few repeats to be worth extracting.
  -- 'minRepeats' includes the occurrence in the window itself, which is
  -- not part of 'allRepeats'.
  | not enoughRepeats = findBody config ce ss
  -- Several repeats: extract the maximised body, unless maximising
  -- degenerated to nothing (possible only if pointwise 'allEqual' disagrees
  -- with the sequence-level comparison). In that case keep searching rather
  -- than returning a useless empty body.
  | Just body <- listToMaybe maximised, not (null body) = Just body
  | otherwise = findBody config ce ss
  where
    -- Peephole window of statements at the current position.
    window    = take windowSize ss
    windowLen = length window

    -- Possible start positions for repeats: every suffix of 'ss' except
    -- 'ss' itself. 'List.tails' yields 'ss' first, and that suffix trivially
    -- matches the window it contains.
    candidates = drop 1 (List.tails ss)

    -- A candidate is similar only if it covers the whole window and is
    -- structurally equal to it. 'zip' truncates to the shorter list, so
    -- without the length check a short suffix near the end of the code
    -- would match on just a prefix of the window. Empty pairings are
    -- rejected explicitly; the empty list is trivially equal to itself.
    isSimilar l =
      not (null l)
        && length l == windowLen
        && StructuralEquality.eqStmtAnnots ce l

    -- Is there a copy of the window somewhere in the look-ahead range?
    hasNearbyRepeat =
      any isSimilar . map (zip window) . take maxLookahead $ candidates

    -- Matching suffixes anywhere in the remaining code, capped at
    -- 'maxRepeats' because heavily repeated code would otherwise slow the
    -- search down considerably.
    allRepeats =
      take maxRepeats
      . map fst
      . filter (isSimilar . snd)
      . map (id &&& zip window)
      $ candidates

    enoughRepeats =
      length (take (minRepeats - 1) allRepeats) >= minRepeats - 1

    -- Extend the match as far as every occurrence (the window's own plus
    -- the repeats) stays pointwise structurally equal.
    maximised =
      transpose
      . takeWhile (allEqual ce)
      . transpose
      $ ss : allRepeats


-- | Transpose a list of rows, truncating each resulting column list to the
-- length of the shortest row. Unlike 'Data.List.transpose', ragged rows do
-- not produce shorter trailing columns: the result is always rectangular.
--
-- With no rows at all, the result is @[]@. The bare
-- @getZipList . traverse ZipList@ formulation would instead produce an
-- infinite list of empty rows, because 'pure' for 'ZipList' repeats its
-- argument forever.
transpose :: [[a]] -> [[a]]
transpose []   = []
transpose rows = getZipList (traverse ZipList rows)

-- | True when every statement in the list is structurally equal (per
-- 'StructuralEquality.eqStmtAnnot' under 'ce') to the first one. An empty
-- list counts as all-equal.
allEqual :: ConstExprs -> [StmtAnnot a] -> Bool
allEqual ce stmts =
  case stmts of
    []            -> True
    pivot : rest -> all (StructuralEquality.eqStmtAnnot ce pivot) rest