{-# LANGUAGE LambdaCase #-}

module Jikka.RestrictedPython.Language.VariableAnalysis where

import Data.List (delete, intersect, nub)
import Jikka.RestrictedPython.Language.Expr
import Jikka.RestrictedPython.Language.Util

-- | A list of variable names that are read.
--
-- Wrapping the plain @['VarName']@ keeps read lists from being confused with
-- 'WriteList's at the type level.
newtype ReadList = ReadList [VarName]
  deriving (Eq, Ord, Show, Read)

-- | A list of variable names that are written.
--
-- Wrapping the plain @['VarName']@ keeps write lists from being confused with
-- 'ReadList's at the type level.
newtype WriteList = WriteList [VarName]
  deriving (Eq, Ord, Show, Read)

-- | Check whether some variable occurs in both the given write list and the
-- given read list.
--
-- Returns 'True' as soon as one written variable is found among the read
-- variables, and 'False' when the lists are disjoint (in particular when
-- either of them is empty).
haveWriteReadIntersection :: WriteList -> ReadList -> Bool
haveWriteReadIntersection (WriteList w) (ReadList r) = any (`elem` r) w

-- | Tag the variables that 'freeVars' reports for an expression as reads.
analyzeExpr :: Expr' -> ReadList
analyzeExpr = ReadList . freeVars

-- | Tag the variables that 'freeVarsTarget' reports for an assignment target
-- as reads.
--
-- NOTE(review): presumably these are the variables used inside the target
-- (e.g. in subscripts) rather than the assigned names themselves — confirm
-- against the definition of 'freeVarsTarget'.
analyzeTargetRead :: Target' -> ReadList
analyzeTargetRead = ReadList . freeVarsTarget

-- | Tag the variables that 'targetVars' reports for an assignment target as
-- writes.
analyzeTargetWrite :: Target' -> WriteList
analyzeTargetWrite = WriteList . targetVars

-- | Collect the variables read and written by a single statement.
--
-- The 'Bool' selects the approximation:
--
-- * 'True' ("max"): variables that are /possibly/ read or written. Both
--   branches of an @if@ and the body of a @for@ contribute.
-- * 'False' ("min"): variables that are /always/ read or written. Only what
--   both branches of an @if@ agree on contributes, and the body of a @for@
--   contributes nothing.
--
-- The returned lists contain no duplicates.
analyzeStatementGeneric :: Bool -> Statement -> (ReadList, WriteList)
analyzeStatementGeneric isMax = \case
  Return e -> readOnly e
  AugAssign x _ e -> assignment x e
  AnnAssign x _ e -> assignment x e
  For x iter body ->
    let loopVars = targetVars x
        ReadList iterReads = analyzeExpr iter
        (ReadList bodyReads, WriteList bodyWrites) = analyzeStatementsGeneric isMax body
        -- The loop target is bound by the loop itself, so it is not reported
        -- as a read or write of the statement as a whole.
        withoutLoopVars = filter (`notElem` loopVars)
     in -- In "min" mode only the iterated expression is reported; presumably
        -- because the body may run zero times.
        if isMax
          then
            ( ReadList (nub $ iterReads ++ withoutLoopVars bodyReads),
              WriteList (nub $ withoutLoopVars bodyWrites)
            )
          else (ReadList iterReads, WriteList [])
  If e body1 body2 ->
    let ReadList condReads = analyzeExpr e
        (ReadList r1, WriteList w1) = analyzeStatementsGeneric isMax body1
        (ReadList r2, WriteList w2) = analyzeStatementsGeneric isMax body2
        -- "max" keeps what either branch touches, "min" only what both do.
        merge = if isMax then (++) else intersect
     in (ReadList (nub $ condReads ++ merge r1 r2), WriteList (nub $ merge w1 w2))
  Assert e -> readOnly e
  Append _ _ x e ->
    let -- The appended-to expression only counts as a target when
        -- 'exprToTarget' can convert it; otherwise it adds nothing here.
        target = exprToTarget x
        w = maybe (WriteList []) analyzeTargetWrite target
        ReadList r = maybe (ReadList []) analyzeTargetRead target
        ReadList r' = analyzeExpr e
     in (ReadList (nub $ r ++ r'), w)
  Expr' e -> readOnly e
  where
    -- A statement that only evaluates an expression writes nothing.
    readOnly :: Expr' -> (ReadList, WriteList)
    readOnly e = (analyzeExpr e, WriteList [])

    -- Shared by @AugAssign@ and @AnnAssign@: reads come from the target and
    -- the right-hand side, writes come from the target.
    -- NOTE(review): for @AugAssign@ the assigned variable itself is counted
    -- as read only if 'freeVarsTarget' includes it — confirm.
    assignment :: Target' -> Expr' -> (ReadList, WriteList)
    assignment x e =
      let ReadList r = analyzeTargetRead x
          ReadList r' = analyzeExpr e
       in (ReadList (nub $ r ++ r'), analyzeTargetWrite x)

-- | Collect the variables read and written by a sequence of statements.
--
-- Each statement is analysed with 'analyzeStatementGeneric' using the same
-- 'Bool' mode, and the per-statement results are concatenated and
-- de-duplicated. An empty statement list yields two empty lists.
analyzeStatementsGeneric :: Bool -> [Statement] -> (ReadList, WriteList)
analyzeStatementsGeneric isMax stmts = (ReadList (nub allReads), WriteList (nub allWrites))
  where
    -- Results of later statements are placed first, so that variables appear
    -- in the same order as they would from prepending one statement at a time.
    perStmt = reverse (map (analyzeStatementGeneric isMax) stmts)
    allReads = concat [r | (ReadList r, _) <- perStmt]
    allWrites = concat [w | (_, WriteList w) <- perStmt]

-- | 'analyzeStatementMax' returns the lists of variables which are possibly
-- read or written in the given statement.
--
-- It is 'analyzeStatementGeneric' with the mode fixed to 'True'.
analyzeStatementMax :: Statement -> (ReadList, WriteList)
analyzeStatementMax = analyzeStatementGeneric True

-- | 'analyzeStatementsMax' returns the lists of variables which are possibly
-- read or written in the given statements.
--
-- It is 'analyzeStatementsGeneric' with the mode fixed to 'True'.
analyzeStatementsMax :: [Statement] -> (ReadList, WriteList)
analyzeStatementsMax = analyzeStatementsGeneric True

-- | 'analyzeStatementMin' returns the lists of variables which are always
-- read or written in the given statement.
--
-- It is 'analyzeStatementGeneric' with the mode fixed to 'False'.
analyzeStatementMin :: Statement -> (ReadList, WriteList)
analyzeStatementMin = analyzeStatementGeneric False

-- | 'analyzeStatementsMin' returns the lists of variables which are always
-- read or written in the given statements.
--
-- It is 'analyzeStatementsGeneric' with the mode fixed to 'False'.
analyzeStatementsMin :: [Statement] -> (ReadList, WriteList)
analyzeStatementsMin = analyzeStatementsGeneric False