{- Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} module Camfort.Transformation.DerivedTypeIntro where import Data.Data import Data.List hiding (union, insert) import Data.Maybe import Data.Set hiding (foldl, map) import Data.Generics.Uniplate.Operations import Control.Monad.State.Lazy import Debug.Trace import qualified Data.Map as Data.Map import Language.Fortran import Camfort.Analysis.Annotations import Camfort.Analysis.IntermediateReps import Camfort.Analysis.Loops import Camfort.Analysis.Syntax import Camfort.Transformation.Syntax import Camfort.Analysis.Types import Camfort.Helpers import Camfort.Traverse typeStruct :: [(Filename, Program Annotation)] -> (Report, [(Filename, Program Annotation)]) typeStruct fps = mapM (\(f, ps) -> mapM typeStructPerProgram ps >>= (\ps' -> return (f, ps'))) fps -- raph data structures used to build interference graphs type Graph v a = [((v, v), a)] -- Note, this is graphs with labelled edges type WeightedEdge v a = ((v, v), (a, Int)) type WeightedGraph v a = [WeightedEdge v a] -- vertices :: WeightedGraph v a -> [v] (also works for Graph v a) vertices = concatMap (\((x, y), _) -> [x, y]) -- isVertex :: v -> WeightedGraph v a -> Bool (also works Graph v a) isVertex v wgs = elem v (vertices wgs) getVertex v [] = Nothing getVertex v (((v1, v2), d):es) = if v == v1 || v == v2 then Just d else getVertex v es -- on-interprocedural version first typeStructPerProgram :: ProgUnit Annotation -> (Report, ProgUnit Annotation) typeStructPerProgram p = descendBiM (\b@(Block a uses implicits span decs blockBody) -> let tenv = typeEnv b -- Compute graph of semantically related projection variables es = Exprs `topFrom` b prjVarsWTarget = map locsFromArrayIndex es iGraph = toInterferenceGraph prjVarsWTarget wiGraph = calculateWeights iGraph -- weighted inteference graph wgf = decomposeWeightedGraph wiGraph -- Generate definitions tDefsAndNames = evalState (mapM (mkTypeDef tenv (fst span, fst span)) wgf) 0 nwgf = zip wgf (map snd tDefsAndNames) rAnnotation = if (length tDefsAndNames > 0) then unitAnnotation { refactored = Just (fst span) } else unitAnnotation blockBody' = elimProjectionDefs blockBody iGraph decs' = foldl (DSeq unitAnnotation) decs (map fst tDefsAndNames) a' = if (length tDefsAndNames > 0) then a { refactored = Just (fst span) } else a in -- Create outgoing block (show wiGraph ++ "\n\n" ++ show wgf, Block a' uses implicits span decs' blockBody')) p -- - Graph Access Variable here is a graph with projection variables at nodes -- - and the array target that they both index as the edge label toInterferenceGraph :: [[(Variable, Access)]] -> Graph Access Variable toInterferenceGraph pvars = let rel = concatMap listToSymmRelation pvars matchingArrayTargets r ((a, x), (b, y)) | a == b = ((x, y), a) : r | otherwise = r in foldl matchingArrayTargets [] rel listToSymmRelation :: [a] -> [(a, a)] listToSymmRelation [] = [] listToSymmRelation (x:xs) = ((repeat x) `zip` xs) ++ (listToSymmRelation xs) -- heck coherence of original manual projection approach correctManualImpl ranges stmt graph = let (_, pvarmap) = runState (transformBiM collect stmt) Data.Map.empty in Data.Map.foldWithKey (\arr vixs p -> case (lookup arr ranges) of Just (l, u) -> (sort (map snd vixs) == [l..u]) && p) True pvarmap where collect :: Fortran A -> State (Data.Map.Map Variable [(Variable, Integer)]) (Fortran A) collect a@(Assg p sp e1 e2) = do indexMap <- get case (do v <- varExprToVariable e1 arr <- getVertex (VarA v) graph case e2 of (ConS _ _ val) -> case (Data.Map.lookup arr indexMap) of Just ixs -> case (lookup v ixs) of Just val' -> Nothing -- error "Repeated definition of projection" Nothing -> Just $ Data.Map.update (\ixs -> Just $ ((v, read $ val) : ixs)) arr indexMap Nothing -> Just $ Data.Map.insert arr [(v, read $ val)] indexMap) of Just indexMap' -> do put indexMap'; return a Nothing -> return a collect f = return f elimProjectionDefs :: Fortran A -> Graph Access Variable -> Fortran A elimProjectionDefs stmt graph = transformBi ef stmt where ef a@(Assg p sp e1 e2) = case (varExprToVariable e1) of Just v -> if (isVertex (VarA v) graph) then NullStmt (p { refactored = Just $ dropLine' sp }) sp else a Nothing -> a ef f = f arrayAccessToProjection :: Fortran A -> Graph Access Variable -> Fortran A arrayAccessToProjection = undefined -- ounts number of duplicate edges and makes this the "weight" calculateWeights :: (Eq (AnnotationFree a), Eq (AnnotationFree v), Ord a, Ord v) => Graph v a -> WeightedGraph v a calculateWeights xs = calcWs (sort xs) 1 where calcWs [] _ = [] calcWs [((v1, v2), a)] n = [((v1, v2), (a, n))] calcWs (e@((v1, v2), a):(e':es)) n | ((af e == af e') || (af e == (af (swap e')))) = calcWs (e':es) (n + 1) | otherwise = ((v1, v2), (a, n)) : (calcWs (e':es) 1) swap ((a, b), v) = ((b, a), v) -- inds the variables that are used to index arrays directly locsFromArrayIndex :: Data t => t -> [(Variable, Access)] locsFromArrayIndex x = concat . concat $ each (Vars `from` x) (\(Var _ _ ves) -> each ves (\(VarName _ v, ixs) -> if (not $ all isConstant ixs) then map (\x -> (v, x)) (Locs `from` ixs) else [])) findMatch v ix ((wg, n):wgns) = vertices -- replaceAccess :: [(WeightedGraph Variable Access, Variable)] -> Block Annotation -> Block Annotation -- replaceAccess wgns x = transformBi (\t@(VarName _ v, ixs) -> t) x -- -- mkTyDecl :: SrcSpan -> Variable -> Type Annotation -> Decl Annotation mkTyDecl sp v t = let ua = unitAnnotation in Decl ua sp [(Var ua sp [(VarName ua v, [])], NullExpr ua sp, Nothing)] t mkTypeDef :: TypeEnv Annotation -> SrcSpan -> WeightedGraph Access Variable -> State Int (Decl Annotation, String) mkTypeDef tenv sp wg = (inventName wg) >>= (\name -> let edgeToDecls ((vx, vy), (va, w)) = case (lookup va tenv) of Just t -> [mkTyDecl sp (accessToVarName vx) (arrayElementType t), mkTyDecl sp (accessToVarName vy) (arrayElementType t)] Nothing -> error $ "Can't find the type of " ++ show va ++ "\n" ra = unitAnnotation { refactored = Just (fst sp) } (_, (arrayVar, _)) = head wg tdecls = concatMap edgeToDecls wg typeDecl = DerivedTypeDef ra sp (SubName ra name) [] [] tdecls typeCons = BaseType ra (DerivedType ra (SubName ra name)) [] (NullExpr ra sp) (NullExpr ra sp) valDecl = Decl ra sp [(Var ra sp [(VarName ra (arrayVar ++ name), [])] , NullExpr ra sp, Nothing)] typeCons in return $ (DSeq unitAnnotation typeDecl valDecl, name)) inventName :: WeightedGraph Access Variable -> State Int String inventName graph = do n <- get put (n + 1) let vs = vertices graph return $ map mode (transpose (map accessToVarName vs)) ++ (show n) -- mode :: String -> Char mode x = let freqs = (map (\x -> (head x, length x))) . group . sort $ x sortedFreqs = sortBy (\x -> \y -> (snd x) `compare` (snd y)) freqs max = last sortedFreqs in -- mode or 'X' if mode is less than the majority if (snd max) > ((length x) `div` 2) then fst max else 'X' decomposeWeightedGraph :: forall v a . (Show v, Ord v, Ord a) => WeightedGraph v a -> [WeightedGraph v a] decomposeWeightedGraph g = map snd (concatMap (foldl binEdge []) (groupBy groupOnArrayVar (sortBy sortOnArrayVar g))) where groupOnArrayVar (_, (av, _)) (_, (av', _)) = av == av' sortOnArrayVar (_, (av, _)) (_, (av', _)) = compare av av' -- ap snd (foldl binEdge [] g) -- bins" edges into a list of graphs with a set of their vertices binEdge :: (Show v, Ord v, Ord a) => [(Set v, WeightedGraph v a)] -> WeightedEdge v a -> [(Set v, WeightedGraph v a)] binEdge bins e@((x, y), _) = let findBin v [] = ((insert x empty, []), []) findBin v ((vs, es):bs) | member v vs = ((insert v vs, es), bs) | otherwise = let (n, bs') = findBin v bs in (n, (vs, es) : bs') ((vs, es), bins') = findBin x bins ((vs', es'), bins'') = findBin y bins' in (vs `union` vs', e : (es ++ es')) : bins'' -- binEdge bins e@((x, y), _) = let r = binVertex y e (binVertex x e bins) in (show r) `trace` r -- binVertex :: Ord a => a -> WeightedEdge a -> [(Set a, WeightedGraph a)] -> [(Set a, WeightedGraph a)] -- binVertex x e ss = bin' x e ss [] Nothing -- where bin' x e [] bs' Nothing = (insert x empty, [e]) : bs' -- bin' x e [] bs' (Just s) = s : bs' -- -- bin' x e ((vs, es):bs) bs' ms | member x vs = -- case ms of -- Nothing -> bin' x e bs bs' (Just (insert x vs, e:es)) -- Just (vs', es') -> bin' x e bs bs' (Just (union vs' (insert x vs'), (e:es) ++ es')) -- | otherwise = bin' x e bs ((vs, es):bs) ms