{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module      : Jikka.RestrictedPython.Convert.SplitLoops
-- Description : split a for-loop into many small for-loops based on the dependency graph of variables and assignments. / 変数と代入の依存関係グラフに基づいて、for ループを複数の小さな for ループに分割します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.RestrictedPython.Convert.SplitLoops
  ( run,
    run',
    runForLoop,
  )
where

import Data.List (partition)
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.RestrictedPython.Convert.Alpha as Alpha (run)
import Jikka.RestrictedPython.Language.Expr
import Jikka.RestrictedPython.Language.Lint
import Jikka.RestrictedPython.Language.Util
import Jikka.RestrictedPython.Language.VariableAnalysis

-- | `runForLoop` splits a for-loop into as many for-loops as possible.
-- This assumes that `doesntHaveSubscriptionInLoopCounters`, `doesntHaveAssignmentToLoopCounters`, and `doesntHaveAssignmentToLoopIterators` hold.
--
-- This function analyzes read-variables and write-variables in statements (with `analyzeStatementMax`),
-- and splits the statements into the connected components of the relation
-- "one statement writes a variable which the other one reads".
-- The relation is closed transitively: if @s1@ writes @a@, @s2@ reads @a@ and writes @b@, and @s3@ reads @b@,
-- then @s1@, @s2@ and @s3@ all stay in the same loop even though @s1@ and @s3@ share no variable directly.
--
-- Each component becomes a new @for@ statement with the same target @x@ and iterator @iter@.
-- Statements keep their original relative order inside each new loop, and the loops are emitted
-- in the order of the first statement of each component. An empty body yields no statements.
runForLoop :: Target' -> Expr' -> [Statement] -> [Statement]
runForLoop x iter body =
  let -- Two statements depend on each other when one writes a variable the other reads.
      connected :: (a, (ReadList, WriteList)) -> (a, (ReadList, WriteList)) -> Bool
      connected (_, (r, w)) (_, (r', w')) = haveWriteReadIntersection w r' || haveWriteReadIntersection w' r

      -- Grow @comp@ with every item of @pending@ reachable from it, until a fixpoint is reached.
      -- Checking only direct neighbours of one statement is not enough: dependencies chain.
      closure ::
        [((Int, Statement), (ReadList, WriteList))] ->
        [((Int, Statement), (ReadList, WriteList))] ->
        [((Int, Statement), (ReadList, WriteList))]
      closure comp pending =
        case partition (\item -> any (connected item) comp) pending of
          ([], _) -> comp
          (more, pending') -> closure (more ++ comp) pending'

      -- Peel off the whole component of the first remaining statement, one new loop at a time.
      go :: [Statement] -> [((Int, Statement), (ReadList, WriteList))] -> [Statement]
      go result [] = reverse result
      go result (item : items) =
        let members = map (fst . fst) (closure [item] items)
            -- Select the component by index (not by closure order) so the original statement order is kept.
            (same, diff) = partition (\other -> fst (fst other) `elem` members) items
         in go (For x iter (map (snd . fst) (item : same)) : result) diff

      -- Tag each statement with its position in the body and its (read, write) variable sets.
      body' :: [((Int, Statement), (ReadList, WriteList))]
      body' = zip (zip [0 ..] body) (map analyzeStatementMax body)
   in go [] body'

-- | `run'` splits for-loops into as many small for-loops as possible, by applying `runForLoop`
-- to the for-loops visited by `mapLargeStatement`.
-- An if-statement handed to the traversal is rebuilt as a single `If` with the same condition and branches it receives.
-- This assumes that `doesntHaveSubscriptionInLoopCounters`, `doesntHaveAssignmentToLoopCounters`, and `doesntHaveAssignmentToLoopIterators` hold.
-- This may introduce name conflicts (every split loop reuses the original loop target).
--
-- For example, the following
--
-- > a = 0
-- > b = 0
-- > for i in range(10):
-- >     c = b
-- >     a += i
-- >     b += c
--
-- is split to
--
-- > a = 0
-- > b = 0
-- > for i in range(10):
-- >     c = b
-- >     b += c
-- > for i in range(10):
-- >     a += i
run' :: Program -> Program
run' = mapLargeStatement rebuildIf runForLoop
  where
    -- Reassemble an if-statement from its condition and its two branches.
    rebuildIf cond thenBody elseBody = [If cond thenBody elseBody]

-- | `run` does alpha conversion, checks the assumptions of `run'`, applies `run'`, and then does alpha conversion again.
--
-- The checks are `ensureDoesntHaveSubscriptionInLoopCounters`, `ensureDoesntHaveAssignmentToLoopCounters`,
-- and `ensureDoesntHaveAssignmentToLoopIterators`; any failure is reported through `MonadError`,
-- wrapped with this module's name by `wrapError'`.
-- The second alpha conversion is there because `run'` may introduce name conflicts
-- (e.g. several split loops reuse the same loop target).
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.RestrictedPython.Convert.SplitLoops" $ do
  prog' <- Alpha.run prog
  ensureDoesntHaveSubscriptionInLoopCounters prog'
  ensureDoesntHaveAssignmentToLoopCounters prog'
  ensureDoesntHaveAssignmentToLoopIterators prog'
  Alpha.run (run' prog')