module Feldspar.Compiler.Plugins.ForwardPropagation (
ForwardPropagation(..)
)
where
import Feldspar.Compiler.PluginArchitecture
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List as List
import Feldspar.Compiler.Plugins.PropagationUtils
import Feldspar.Compiler.Error
import Feldspar.Compiler.Options
import Feldspar.Compiler.Imperative.CodeGeneration (simpleType)
-- | Reports an internal error of this plugin through 'handleError', tagged
-- with the plugin name.  It is used below in positions of several different
-- result types, so it is presumably non-returning (TODO confirm in
-- Feldspar.Compiler.Error).
fwdPropError msg = handleError "PluginArch/ForwardPropagation" InternalError msg
-- | Information recorded for a single write of a variable:
--
--   * the written expression,
--   * the variables that expression reads (its dependencies),
--   * whether a later statement has written one of those dependencies.
type FwdWriteInfo = (ExpressionData ForwardPropagationSemInf, [VariableData], Bool)

-- | Per-variable occurrence statistics used by forward propagation.
type VarStatFwd = VarStatistics FwdWriteInfo

-- | Occurrence record of a single variable (write and read counts).
type OccurrencesFwd = Occurrences FwdWriteInfo
-- | The forward-propagation plugin.
--
-- It runs two phases: 'ForwardPropagationCollect' gathers variable
-- statistics, then 'ForwardPropagationTransform' rewrites the procedure
-- using the statistics returned for the whole procedure.
data ForwardPropagation = ForwardPropagation

instance Plugin ForwardPropagation where
    type ExternalInfo ForwardPropagation = DebugOption
    -- The procedure is returned unchanged for the two debug options that
    -- disable simplification.
    executePlugin ForwardPropagation externalInfo procedure
        | externalInfo `elem` [NoSimplification, NoPrimitiveInstructionHandling] = procedure
        | otherwise = fst $ executeTransformationPhase ForwardPropagationTransform rootStatistics collectedProcedure
      where
        -- Collection starts with every variable treated as a read occurrence.
        (collectedProcedure, collectedUpwards) =
            executeTransformationPhase ForwardPropagationCollect Occurrence_read procedure
        rootStatistics = fst collectedUpwards

-- | The plugin type itself carries no semantic information of its own.
instance TransformationPhase ForwardPropagation where
    type From ForwardPropagation = ()
    type To ForwardPropagation = ()
    type Downwards ForwardPropagation = ()
    type Upwards ForwardPropagation = ()
-- | Semantic information attached to the program by
-- 'ForwardPropagationCollect' and consumed by 'ForwardPropagationTransform'.
data ForwardPropagationSemInf

instance SemanticInfo ForwardPropagationSemInf where
    -- Nodes that carry collected information:
    --   blocks and sequential loops hold a statistics summary,
    --   variables hold the place of their occurrence,
    --   array element references hold the variable naming the array.
    type BlockInfo ForwardPropagationSemInf = VarStatFwd
    type SequentialLoopInfo ForwardPropagationSemInf = VarStatFwd
    type VariableInfo ForwardPropagationSemInf = Occurrence_place
    type ArrayElemReferenceInfo ForwardPropagationSemInf = Maybe VariableData
    -- Every other node carries nothing.
    type ProcedureInfo ForwardPropagationSemInf = ()
    type ProgramInfo ForwardPropagationSemInf = ()
    type EmptyInfo ForwardPropagationSemInf = ()
    type PrimitiveInfo ForwardPropagationSemInf = ()
    type SequenceInfo ForwardPropagationSemInf = ()
    type BranchInfo ForwardPropagationSemInf = ()
    type ParallelLoopInfo ForwardPropagationSemInf = ()
    type FormalParameterInfo ForwardPropagationSemInf = ()
    type LocalDeclarationInfo ForwardPropagationSemInf = ()
    type ExpressionInfo ForwardPropagationSemInf = ()
    type ConstantInfo ForwardPropagationSemInf = ()
    type FunctionCallInfo ForwardPropagationSemInf = ()
    type LeftValueInfo ForwardPropagationSemInf = ()
    type InstructionInfo ForwardPropagationSemInf = ()
    type AssignmentInfo ForwardPropagationSemInf = ()
    type ProcedureCallInfo ForwardPropagationSemInf = ()
    type ActualParameterInfo ForwardPropagationSemInf = ()
    type IntConstantInfo ForwardPropagationSemInf = ()
    type FloatConstantInfo ForwardPropagationSemInf = ()
    type BoolConstantInfo ForwardPropagationSemInf = ()
    type ArrayConstantInfo ForwardPropagationSemInf = ()
-- | Merging two Collect-phase upwards results merges their statistics. The
-- variable component is always reset to 'Nothing'. Lazy patterns keep the
-- arguments unevaluated until the statistics are actually needed.
instance Combine (VarStatFwd, Maybe VariableData) where
    combine ~(leftStats, _) ~(rightStats, _) = (combine leftStats rightStats, Nothing)
-- | First phase of forward propagation.
--
-- Builds a 'VarStatFwd' (per-variable write and read counts) bottom-up. The
-- following count as a single write that records the written expression and
-- the variables that expression reads:
--
--   * an assignment to a variable,
--   * an initialised local declaration (except array constants),
--   * a @copy*@ call into a variable.
--
-- Each block stores the statistics of its declared variables in its
-- 'blockSemInf'. A sequential loop stores the statistics of its condition
-- calculation in its 'sequentialLoopSemInf'.
data ForwardPropagationCollect = ForwardPropagationCollect

instance TransformationPhase ForwardPropagationCollect where
    type From ForwardPropagationCollect = ()
    type To ForwardPropagationCollect = ForwardPropagationSemInf
    -- Place of the variable occurrences below the current node
    -- (Occurrence_declare / _read / _write / _notopt).
    type Downwards ForwardPropagationCollect = Occurrence_place
    -- Statistics of the subtree, plus the variable named by a variable or
    -- array-element left value ('Nothing' for every other node).
    type Upwards ForwardPropagationCollect = (VarStatFwd, Maybe VariableData)

    -- These nodes recompute the occurrence place from the original node with
    -- 'occurrenceDownwards'; the inherited place @d@ is not used.
    downwardsBranchProgramInProgram self d orig = occurrenceDownwards orig
    downwardsSequentialLoopProgramInProgram self d orig = occurrenceDownwards orig
    downwardsParallelLoopProgramInProgram self d orig = occurrenceDownwards orig
    downwardsFormalParameter self d orig = occurrenceDownwards orig
    downwardsLocalDeclaration self d orig = occurrenceDownwards orig
    downwardsAssignmentInstructionInInstruction self d orig = occurrenceDownwards orig
    downwardsActualParameter self d orig = occurrenceDownwards orig
    downwardsInputActualParameterInActualParameter self d orig = occurrenceDownwards orig
    downwardsExpression self d orig = occurrenceDownwards orig

    -- Store in the block the statistics selected for its declared variables.
    -- The declarations and then the instructions are checked as one sequence
    -- by 'checkFwdDeclaration'.
    transformBlock self d origBlock u = Block
        { blockDeclarations = recursivelyTransformedBlockDeclarations u
        , blockInstructions = recursivelyTransformedBlockInstructions u
        , blockSemInf = selectFromVarStatistics (declaredVars origBlock) belowStatistics
        }
      where
        belowStatistics = checkFwdDeclaration
            (map fst $ upwardsInfoFromBlockDeclarations u)
            (fst $ upwardsInfoFromBlockInstructions u)

    -- Annotate each variable with the place it occurs in.
    transformVariable self d origVar = origVar
        { variableSemInf = d
        }

    -- A single occurrence of a variable. A write is recorded as @One Nothing@
    -- (value unknown). Assignments, initialised declarations and copy calls
    -- below overwrite that entry with the written expression.
    upwardsVariable self d origVar newVar = case d of
        Occurrence_declare -> (Map.singleton (variableData origVar) $ Occurrences Zero Zero, Just $ variableData origVar)
        Occurrence_read -> (Map.singleton (variableData origVar) $ Occurrences Zero (One ()), Just $ variableData origVar)
        Occurrence_write -> (Map.singleton (variableData origVar) $ Occurrences (One Nothing) Zero, Just $ variableData origVar)
        Occurrence_notopt -> (Map.singleton (variableData origVar) $ Occurrences Multiple Multiple, Just $ variableData origVar)

    -- The statements of a sequence are checked in program order.
    upwardsSequenceProgramInProgram self d origSeq u transSeq =
        (checkFwdSequence $ map fst $ upwardsInfoFromSequenceProgramList u, Nothing)

    -- Everything below the block, merged, with the block's own declared
    -- variables removed.
    upwardsBlock self d origBlock u newBlock =
        (deleteFromVarStatistics (declaredVars origBlock) belowStatistics, Nothing)
      where
        belowStatistics = List.foldl' combine (fst $ upwardsInfoFromBlockInstructions u)
            $ map fst $ upwardsInfoFromBlockDeclarations u

    -- Condition variable, iteration count and body are merged and passed
    -- through 'multipleVarStatistics' (presumably because the body can run
    -- more than once).
    upwardsParallelLoopProgramInProgram self d origParLoop u transParLoop =
        ( multipleVarStatistics $ List.foldl' combine
            (fst $ upwardsInfoFromParallelLoopConditionVariable u)
            [fst $ upwardsInfoFromNumberOfIterations u, fst $ upwardsInfoFromParallelLoopCore u]
        , Nothing
        )

    -- Assigning to a variable records the (transformed) right-hand side as
    -- its single write, with the variables read by the right-hand side as
    -- dependencies. Assigning to an array element just merges both sides.
    upwardsAssignmentInstructionInInstruction self d origAssign u transAssig =
        case leftValueData $ assignmentLhs origAssign of
            VariableLeftValue vlv -> (Map.insert var occ rhsStat, Nothing)
              where
                var = variableData vlv
                occ = Occurrences (One $ Just (assRs, Map.keys rhsStat, False)) Zero
                assRs = case transAssig of
                    AssignmentInstruction newAssign -> expressionData $ assignmentRhs newAssign
                    _ -> fwdPropError $ "Internal error: ForwardPropagation/1!"
            ArrayElemReferenceLeftValue aer ->
                (combine (fst $ upwardsInfoFromAssignmentLhs u) rhsStat, Nothing)
      where
        rhsStat = fst $ upwardsInfoFromAssignmentRhs u

    -- An initialised declaration counts as a single write of its initial
    -- value, except when the value is an array constant.
    upwardsLocalDeclaration self d origDecl u newDecl = case localInitValue newDecl of
        Nothing -> defaultCase
        Just initValue -> case expressionData initValue of
            ConstantExpression (Constant (ArrayConstant _) ()) -> defaultCase
            initExp -> case upwardsInfoFromLocalInitValue u of
                Nothing -> defaultCase
                Just (initStat, _) -> (Map.insert var (initOcc initExp initStat) initStat, Nothing)
      where
        var = variableData $ localVariable origDecl
        initOcc initExp initStat = Occurrences (One $ Just (initExp, Map.keys initStat, False)) Zero
        defaultCase = (fst $ upwardsInfoFromLocalVariable u, Nothing)

    -- A @copy*@ call with parameters (input, size, output variable) counts as
    -- a single write of the input expression to the output variable. The
    -- reads made by the size argument must be kept too: dropping them made
    -- a variable used only there look unread, so 'checkFwdSequence' could not
    -- invalidate a stale propagated value for it. Any other call just merges
    -- the statistics of all its parameters.
    upwardsProcedureCallInstructionInInstruction self d origProcCall u transProcCall
        | List.isPrefixOf "copy" $ nameOfProcedureToCall origProcCall
        , [InputActualParameter inArr, InputActualParameter _, OutputActualParameter outArr]
            <- map actualParameterData actParams
        , [(inArrStat, _), (arrSizeStat, _), _] <- ul
        , VariableLeftValue vlv <- leftValueData outArr
        = ( Map.insert (variableData vlv) (copyOcc inArr inArrStat) $ combine inArrStat arrSizeStat
          , Nothing
          )
        | otherwise = defaultTr
      where
        ul = upwardsInfoFromActualParametersOfProcedureToCall u
        actParams = case transProcCall of
            ProcedureCallInstruction pc -> actualParametersOfProcedureToCall pc
            _ -> fwdPropError $ "Internal error: ForwardPropagation/2!"
        -- Merge the statistics of all actual parameters.
        defaultTr = case ul of
            [] -> defaultValue
            firstUp : restUp -> List.foldl' combine firstUp restUp
        -- The output variable receives the input expression; its dependencies
        -- are the variables read by the input parameter.
        copyOcc inArr inArrStat = Occurrences (One $ Just (expressionData inArr, Map.keys inArrStat, False)) Zero

    -- The condition calculation's block statistics are moved onto the loop
    -- node; the condition block itself keeps an empty map.
    transformSequentialLoopProgramInProgram self d origSeqLoop u = SequentialLoopProgram $ origSeqLoop
        { sequentialLoopCondition = recursivelyTransformedSequentialLoopCondition u
        , conditionCalculation = (recursivelyTransformedConditionCalculation u)
            { blockSemInf = Map.empty
            }
        , sequentialLoopCore = recursivelyTransformedSequentialLoopCore u
        , sequentialLoopSemInf = blockSemInf $ recursivelyTransformedConditionCalculation u
        }

    -- Condition statistics (without the condition variable) and body are
    -- merged and passed through 'multipleVarStatistics'.
    upwardsSequentialLoopProgramInProgram self d origSeqLoop u newSeqLoop =
        ( multipleVarStatistics $ combine (deleteFromVarStatistics condVars condStat)
            (fst $ upwardsInfoFromSequentialLoopCore u)
        , Nothing
        )
      where
        condStat = fst $ upwardsInfoFromSequentialLoopCondition u
        -- Presumably the condition names exactly one variable; 'take 1' keeps
        -- the first key like the old 'head' did, without failing on a
        -- condition that reads no variable.
        condVars = take 1 $ Map.keys condStat

    -- Remember which variable names the array (taken from the array-name
    -- subtree).
    transformArrayElemReferenceLeftValueInLeftValue self d origArrRef u = ArrayElemReferenceLeftValue $ ArrayElemReference
        { arrayName = recursivelyTransformedArrayName u
        , arrayIndex = recursivelyTransformedArrayIndex u
        , arrayElemReferenceSemInf = snd $ upwardsInfoFromArrayName u
        }

    upwardsArrayElemReferenceLeftValueInLeftValue self d origArrayRef u transArrayRefe =
        (combine (fst $ upwardsInfoFromArrayName u) (fst $ upwardsInfoFromArrayIndex u), snd $ upwardsInfoFromArrayName u)

    upwardsVariableLeftValueInLeftValue self d origVar transVar = upwardsVariable self d origVar $ transformVariable self d origVar

    transformVariableLeftValueInLeftValue self d origVar = VariableLeftValue $ transformVariable self d origVar
-- | Merges the statistics of consecutive statements, in program order.
--
-- Before a statement's statistics are merged in, every recorded single
-- write from the earlier statements is re-examined against it:
--
--   * The write is invalidated (its value forgotten) when its dependencies
--     were already overwritten and the statement reads the variable.
--   * The write is also invalidated when the statement writes a dependency
--     and reads the variable, unless the variable has a 'simpleType' and is
--     not read 'Multiple' times. Otherwise the write is only marked as
--     "dependencies written", and the statement's own dependencies for that
--     variable are appended.
--   * When the statement rewrites the same expression to the variable, the
--     earlier write count becomes 'Zero'.
checkFwdSequence :: [VarStatFwd] -> VarStatFwd
checkFwdSequence [] = defaultValue
checkFwdSequence stats = List.foldl step Map.empty stats
  where
    step :: VarStatFwd -> VarStatFwd -> VarStatFwd
    step before current = combine current (Map.mapWithKey (revise current) before)

    revise :: VarStatFwd -> VariableData -> OccurrencesFwd -> OccurrencesFwd
    revise current v occ = case writeVar occ of
        One (Just (valueExp, deps, depsWritten))
            | depsWritten && current `hasRead` v -> invalidated
            | any (hasWrite current) deps ->
                if (current `hasRead` v) && not ((simpleType $ variableDataType v) && readVar occ /= Multiple)
                    then invalidated
                    else Occurrences (One (Just (valueExp, deps ++ newDeps current v, True))) (readVar occ)
            | otherwise -> case current `getWrite` v of
                Just (e, _, _) | e == valueExp -> Occurrences Zero (readVar occ)
                _ -> occ
          where
            invalidated = Occurrences (One Nothing) (readVar occ)
        _ -> occ

    -- Dependencies recorded by the current statement's own write of the
    -- variable (empty when it does not write it).
    newDeps current v = maybe [] (\(_, vs, _) -> vs) (current `getWrite` v)
-- | Checks a block's declaration statistics followed by its instruction
-- statistics as one sequence. Without declarations, the instruction
-- statistics are returned unchanged.
checkFwdDeclaration :: [VarStatFwd] -> VarStatFwd -> VarStatFwd
checkFwdDeclaration declStat blockStat
    | null declStat = blockStat
    | otherwise = checkFwdSequence (declStat ++ [blockStat])
-- | Substitution table: variable and the expression that replaces its reads.
type VarWrite t = [(VariableData, ExpressionData t)]

-- | Selects the variables whose single recorded write can be propagated,
-- in ascending key order.
--
-- A variable qualifies when exactly one write with a known expression was
-- recorded and either:
--
--   * it is not read 'Multiple' times and the expression is not an array
--     constant, or
--   * the expression is simple (a non-array constant, a variable, or an
--     array element whose index is itself simple).
--
-- Uses 'Map.foldrWithKey'; the previously used @Map.foldWithKey@ was a
-- deprecated synonym that newer versions of @containers@ no longer provide.
toVarWrite :: VarStatFwd -> VarWrite ForwardPropagationSemInf
toVarWrite vs = Map.foldrWithKey getExp [] vs
  where
    getExp :: VariableData -> OccurrencesFwd -> VarWrite ForwardPropagationSemInf -> VarWrite ForwardPropagationSemInf
    getExp name (Occurrences (One (Just (valueExp, _, _))) reads) vw
        | reads /= Multiple && notConstArray valueExp = (name, valueExp) : vw
        | simpleExpr valueExp = (name, valueExp) : vw
        | otherwise = vw
    getExp _ _ vw = vw

    -- False only for array constants.
    notConstArray e = case e of
        ConstantExpression (Constant c _) -> simplConst c
        _ -> True

    -- Expressions cheap enough to duplicate at every read.
    simpleExpr e = case e of
        ConstantExpression (Constant c _) -> simplConst c
        LeftValueExpression l -> case leftValueData l of
            VariableLeftValue _ -> True
            ArrayElemReferenceLeftValue a -> simpleExpr $ expressionData $ arrayIndex a
        _ -> False

    simplConst (ArrayConstant _) = False
    simplConst _ = True
-- | Second phase of forward propagation.
--
-- Takes the program annotated by 'ForwardPropagationCollect'. The statistics
-- of the enclosing blocks and sequential loops are accumulated downwards and
-- turned into a substitution table with 'toVarWrite'. Reads of propagated
-- variables are replaced by their recorded expressions, and primitive
-- assignments / @copy*@ calls whose written variables are all propagated
-- become empty programs. All semantic information is reset to '()'.
data ForwardPropagationTransform = ForwardPropagationTransform

instance TransformationPhase ForwardPropagationTransform where
    type From ForwardPropagationTransform = ForwardPropagationSemInf
    type To ForwardPropagationTransform = ()
    -- Statistics of all enclosing blocks / sequential loops, merged.
    type Downwards ForwardPropagationTransform = VarStatFwd
    -- Variables written (place Occurrence_write) in the subtree.
    type Upwards ForwardPropagationTransform = Set.Set VariableData
    downwardsBlock self d origBlock = combine d $ blockSemInf origBlock
    downwardsSequentialLoopProgramInProgram self d origSeqLoop = combine d $ sequentialLoopSemInf origSeqLoop
    -- A read of a propagated variable is replaced by its recorded expression,
    -- which is itself walked with the same downwards statistics.
    -- NOTE(review): if a recorded expression reads its own variable and that
    -- variable is still in the table, this walk would recurse on it again.
    transformLeftValueExpressionInExpression self d origLV u = case leftValueData origLV of
        VariableLeftValue origVar -> case List.find (\(vn,e) -> (vn == variableData origVar)) varwrite of
            Nothing -> defaultTr
            Just repl -> expressionData $ fst $ walkExpression self d $ Expression (snd repl) ()
        ArrayElemReferenceLeftValue origArr -> defaultTr
      where
        varwrite = toVarWrite d
        defaultTr = LeftValueExpression $ LeftValue
            { leftValueData = recursivelyTransformedLeftValueData u
            , leftValueSemInf = ()
            }
    -- A variable used as a left value is replaced only when its recorded
    -- expression is itself a left value.
    transformVariableLeftValueInLeftValue self d origVar = case List.find (\(vn,e) -> (vn == var)) varwrite of
        Nothing -> defaultTr
        Just repl -> case repl of
            (_,LeftValueExpression lv) -> leftValueData $ fst $ walkLeftValue self d lv
            _ -> defaultTr
      where
        var = variableData origVar
        varwrite = toVarWrite d
        defaultTr = VariableLeftValue $ origVar
            { variableSemInf = ()
            }
    -- The array is looked up by the variable stored in the reference's
    -- semantic info. When that variable's recorded expression is an array
    -- element reference, the result is built as follows:
    --   * the index is the recorded element's index (walked),
    --   * the name is the original name walked with statistics in which the
    --     variable's write is temporarily replaced by
    --     (arrayName of the recorded element)[arrayIndex of the original]
    --     (see 'swapArrayIndex').
    transformArrayElemReferenceLeftValueInLeftValue self d origArrayRef u = case List.find (\(vn,e) -> (vn == var)) varwrite of
        Nothing -> defaultTr
        Just repl -> case repl of
            (_,LeftValueExpression lv) -> case leftValueData lv of
                VariableLeftValue vlv -> defaultTr
                ArrayElemReferenceLeftValue aer -> ArrayElemReferenceLeftValue $ ArrayElemReference
                    { arrayName = fst $ walkLeftValue self (swapArrayIndex d var aer origArrayRef) $ arrayName origArrayRef
                    , arrayIndex = fst $ walkExpression self d $ arrayIndex aer
                    , arrayElemReferenceSemInf = ()
                    }
            _ -> defaultTr
      where
        -- Rewrites the entry of @var@ in the statistics so that its single
        -- write is @arrayName rep [arrayIndex orig]@ with no dependencies.
        swapArrayIndex :: VarStatFwd -> VariableData -> ArrayElemReference ForwardPropagationSemInf -> ArrayElemReference ForwardPropagationSemInf -> VarStatFwd
        swapArrayIndex d var rep orig = Map.adjust (swapArrayIndex2 var rep orig) var d
        swapArrayIndex2 var rep orig x = x
            { writeVar = One $ Just
                ( LeftValueExpression $ LeftValue
                    { leftValueData = ArrayElemReferenceLeftValue $ ArrayElemReference
                        { arrayName = arrayName rep
                        , arrayIndex = arrayIndex orig
                        , arrayElemReferenceSemInf = Just var
                        }
                    , leftValueSemInf = ()
                    }
                , []
                , False
                )
            }
        -- The Collect phase stores the array's variable here; 'Nothing' is
        -- reported as an internal error.
        var = getJust $ arrayElemReferenceSemInf origArrayRef
        getJust (Just a) = a
        getJust _ = fwdPropError $ "Internal error: ForwardPropagation/3!"
        varwrite = toVarWrite d
        defaultTr = ArrayElemReferenceLeftValue $ ArrayElemReference
            { arrayName = recursivelyTransformedArrayName u
            , arrayIndex = recursivelyTransformedArrayIndex u
            , arrayElemReferenceSemInf = convert $ arrayElemReferenceSemInf origArrayRef
            }
    -- Only write occurrences are reported upwards.
    upwardsVariable self d origVar newVar = case variableSemInf origVar of
        Occurrence_declare -> Set.empty
        Occurrence_read -> Set.empty
        Occurrence_write -> Set.singleton (variableData origVar)
        Occurrence_notopt -> Set.empty
    -- Written variables of the instruction part, minus the block's declared
    -- variables.
    upwardsBlock self d origBlock u transformedBlock = foldl (\s e -> Set.delete e s) (upwardsInfoFromBlockInstructions u) (declaredVars origBlock)
    -- 'delUnusedDecl' receives the names of all propagated variables in scope
    -- of this block, the original block and its transformed parts
    -- (presumably it drops the declarations of propagated variables — TODO
    -- confirm in PropagationUtils).
    transformBlock self d origBlock u = delUnusedDecl (map fst $ toVarWrite $ combine d $ blockSemInf origBlock) origBlock (recursivelyTransformedBlockDeclarations u) (recursivelyTransformedBlockInstructions u)
    -- An assignment or @copy*@ call whose written variables are all being
    -- propagated is replaced by an empty program.
    transformPrimitiveProgramInProgram self d originalPrimitive u
        | canDelete && deletablePrimitive = EmptyProgram $ Empty ()
        | otherwise = PrimitiveProgram $ Primitive
            { primitiveInstruction = recursivelyTransformedPrimitiveInstruction u
            , primitiveSemInf = ()
            }
      where
        canDelete = Set.isSubsetOf (upwardsInfoFromPrimitiveInstruction u) (Set.fromList $ map fst $ toVarWrite d)
        deletablePrimitive = case instructionData $ primitiveInstruction originalPrimitive of
            ProcedureCallInstruction pc -> List.isPrefixOf "copy" $ nameOfProcedureToCall pc
            AssignmentInstruction ass -> True
    upwardsVariableLeftValueInLeftValue self d origVar transVar = upwardsVariable self d origVar $ transformVariable self d origVar
    -- Left values reuse the expression-level substitution; a result that is
    -- no longer a left value is reported as an internal error.
    transformLeftValue self d origLV u = case transformLeftValueExpressionInExpression self d origLV u of
        LeftValueExpression lv -> lv
        _ -> fwdPropError $ "Internal error: ForwardPropagation/4!"