module Language.Fixpoint.Graph.Partition (
CPart (..)
, partition, partition', partitionN
, MCInfo (..)
, mcInfo
, dumpPartitions
) where
import GHC.Conc (getNumProcessors)
import Control.Monad (forM_)
import Language.Fixpoint.Misc
import Language.Fixpoint.Utils.Files
import Language.Fixpoint.Types.Config
import qualified Language.Fixpoint.Types as F
import Language.Fixpoint.Graph.Types
import Language.Fixpoint.Graph.Deps
import qualified Data.HashMap.Strict as M
import Data.Maybe (fromMaybe)
import Data.Hashable
import Text.PrettyPrint.HughesPJ
import Data.List (sortBy)
-- | One piece of a partitioned constraint system. It carries the same two
-- maps that are copied to and from the @F.ws@ / @F.cm@ fields of an
-- 'F.GInfo' (see 'finfoToCpart' and 'cpartToFinfo').
data CPart c a = CPart
  { pws :: !(M.HashMap F.KVar (F.WfC a))
    -- ^ the 'F.WfC' entries of this part, keyed by 'F.KVar'
  , pcm :: !(M.HashMap Integer (c a))
    -- ^ the constraints of this part, keyed by their 'Integer' id
  }
-- | Parts are combined field by field: the empty part has two empty maps,
-- and appending two parts unions their maps (on a key clash the entry of
-- the left-hand part is kept, as with the 'M.HashMap' monoid).
--
-- NOTE(review): compilers where 'Monoid' has a @Semigroup@ superclass also
-- need a @Semigroup (CPart c a)@ instance -- confirm the targeted GHC.
instance Monoid (CPart c a) where
  mempty = CPart { pws = M.empty, pcm = M.empty }
  mappend (CPart lws lcm) (CPart rws rcm) =
    CPart { pws = M.union lws rws
          , pcm = M.union lcm rcm
          }
-- | Settings that drive 'partitionN'.
data MCInfo = MCInfo
  { mcCores       :: !Int
    -- ^ upper bound on the number of parts 'partitionN' aims for
  , mcMinPartSize :: !Int
    -- ^ size (per 'cpartSize') at or above which a part needs no more merging
  , mcMaxPartSize :: !Int
    -- ^ combined size at or above which two parts are not merged further
  } deriving (Show)
-- | Build the 'MCInfo' for a run from its 'Config'.
--
-- The core count is @cores c@ when that is set, and otherwise the number of
-- processors reported by 'getNumProcessors'. The two size thresholds are
-- copied from @minPartSize c@ and @maxPartSize c@.
mcInfo :: Config -> IO MCInfo
mcInfo c = fromProcessors <$> getNumProcessors
  where
    fromProcessors detected =
      MCInfo
        { mcCores       = fromMaybe detected (cores c)
        , mcMinPartSize = minPartSize c
        , mcMaxPartSize = maxPartSize c
        }
-- | Split @fi@ into its partitions (via @partition' Nothing@), write each
-- one to disk with 'dumpPartitions', and return the empty result.
partition :: (F.Fixpoint a, F.Fixpoint (c a), F.TaggedC c a)
          => Config
          -> F.GInfo c a
          -> IO (F.Result (Integer, a))
partition cfg fi = mempty <$ dumpPartitions cfg (partition' Nothing fi)
-- | Partition @fi@ along the components computed by 'decompose'.
--
-- * With 'Nothing', each component becomes its own 'F.GInfo' (built by
--   'mkPartition').
-- * With @'Just' mi@, each component first becomes a 'CPart' (built by
--   'mkPartition''), and the resulting parts are handed to 'partitionN'.
--
-- In both cases 'applyNonNull' is given a singleton fallback holding the
-- whole of @fi@ alongside the per-component builder.
partition' :: (F.TaggedC c a)
           => Maybe MCInfo -> F.GInfo c a -> [F.GInfo c a]
partition' Nothing fi =
  applyNonNull [fi] (partitionByConstraints mkPartition fi) (decompose fi)
partition' (Just mi) fi =
  partitionN mi fi $
    applyNonNull [finfoToCpart fi]
                 (partitionByConstraints mkPartition' fi)
                 (decompose fi)
-- | Coarsen a list of parts by repeatedly merging the two smallest ones.
--
-- If the whole of @fi@ is no larger than @mcMinPartSize mi@ (as measured by
-- 'cpartSize'), @fi@ is returned unsplit. Otherwise the parts are kept in
-- ascending order of size and the two smallest are merged until
--
-- * at most one part is left, or
-- * there are no more than @mcCores mi@ parts, and either the smallest part
--   has reached @mcMinPartSize mi@ or merging the two smallest would reach
--   @mcMaxPartSize mi@.
--
-- Each surviving part is turned back into an 'F.GInfo' with 'cpartToFinfo'.
partitionN :: MCInfo         -- ^ core count and size thresholds
           -> F.GInfo c a    -- ^ the original, unpartitioned constraints
           -> [CPart c a]    -- ^ the initial parts
           -> [F.GInfo c a]  -- ^ the coarsened partitions
partitionN mi fi cp
  | cpartSize (finfoToCpart fi) <= minThresh = [fi]
  | otherwise = map (cpartToFinfo fi) (toNParts sortedParts)
  where
    -- Merge the two leading (smallest) parts until the stop condition holds.
    -- Every step shortens the list by one, so this terminates.
    toNParts ps@(a : b : rest)
      | not (isDone ps a b) = toNParts (insertSorted (a `mappend` b) rest)
    toNParts ps = ps

    -- Stop condition for a list @ps@ whose two smallest parts are @a@, @b@.
    isDone ps a b =
      length ps <= prts
        && (cpartSize a >= minThresh
            || cpartSize a + cpartSize b >= maxThresh)

    -- Ascending by size, so the smallest parts sit at the head of the list.
    -- This is the order 'insertSorted' maintains and 'isDone' relies on.
    sortedParts = sortBy bySize cp
    bySize lhs rhs = compare (cpartSize lhs) (cpartSize rhs)

    -- Insert into an ascending list, in front of the first part that is not
    -- smaller than the one being inserted.
    insertSorted a [] = [a]
    insertSorted a (x : xs)
      | cpartSize a > cpartSize x = x : insertSorted a xs
      | otherwise                 = a : x : xs

    prts      = mcCores mi
    minThresh = mcMinPartSize mi
    maxThresh = mcMaxPartSize mi
-- | Size of a part: the number of entries in its 'pcm' map plus the number
-- of entries in its 'pws' map.
cpartSize :: CPart c a -> Int
cpartSize (CPart ws cm) = M.size cm + M.size ws
-- | Turn a part back into an 'F.GInfo': a copy of @fi@ whose @F.cm@ and
-- @F.ws@ fields are replaced by the maps held in the part. All other fields
-- of @fi@ are kept.
cpartToFinfo :: F.GInfo c a -> CPart c a -> F.GInfo c a
cpartToFinfo fi p =
  fi { F.cm = pcm p
     , F.ws = pws p
     }
-- | View a whole 'F.GInfo' as a single part by taking its @F.ws@ and @F.cm@
-- maps.
finfoToCpart :: F.GInfo c a -> CPart c a
finfoToCpart fi = CPart (F.ws fi) (F.cm fi)
-- | Write every partition to its own file. The partitions are numbered from
-- 0, the file name for number @i@ is @queryFile (Part i) cfg@, and the
-- contents are the rendered output of 'F.toFixpoint'.
dumpPartitions :: (F.Fixpoint (c a), F.Fixpoint a)
               => Config -> [F.GInfo c a] -> IO ()
dumpPartitions cfg fis =
  sequence_ $
    zipWith
      (\i fi -> writeFile (queryFile (Part i) cfg) (render (F.toFixpoint cfg fi)))
      [0 ..]
      fis
-- | Shape of the functions that build one partition of type @b@ (see
-- 'mkPartition' and 'mkPartition''). The arguments are, in order: the
-- original 'F.GInfo', the constraints grouped by partition index, the
-- 'F.WfC' entries grouped by partition index, and the index of the
-- partition to build.
type PartitionCtor c a b
  =  F.GInfo c a
  -> M.HashMap Int [(Integer, c a)]
  -> M.HashMap Int [(F.KVar, F.WfC a)]
  -> Int
  -> b
-- | Build one partition per component in @kvss@.
--
-- The components are numbered from 1. Every @KVar@ and @Cstr@ vertex is
-- mapped to the number of the component it occurs in. The @F.ws@ and @F.cm@
-- entries of @fi@ are then grouped by the number their key maps to (looked
-- up with 'groupFun'), and @f@ is called once per component number with
-- those groupings.
partitionByConstraints :: PartitionCtor c a b  -- ^ builds a single partition
                       -> F.GInfo c a          -- ^ constraints to split up
                       -> KVComps              -- ^ the components
                       -> ListNE b             -- ^ one partition per component
partitionByConstraints f fi kvss = map (f fi cstrsByPart wfsByPart) partIds
  where
    -- components paired with their 1-based number
    numbered    = zip [1 ..] kvss
    partIds     = map fst numbered

    -- vertex -> component number, separately for KVars and constraint ids
    kvarPart    = M.fromList [ (k, j) | (j, vs) <- numbered, KVar k <- vs ]
    cstrPart    = M.fromList [ (c, j) | (j, vs) <- numbered, Cstr c <- vs ]

    -- entries of @fi@ grouped by the component number of their key
    wfsByPart   = groupMap (groupFun kvarPart . fst) (M.toList (F.ws fi))
    cstrsByPart = groupMap (groupFun cstrPart . fst) (M.toList (F.cm fi))
-- | Build partition number @j@ as an 'F.GInfo': a copy of @fi@ whose @F.cm@
-- and @F.ws@ maps hold only the entries grouped under @j@ (empty maps when
-- @j@ has no entries).
mkPartition :: PartitionCtor c a (F.GInfo c a)
mkPartition fi icM iwM j =
  fi { F.cm = M.fromList (fromMaybe [] (M.lookup j icM))
     , F.ws = M.fromList (fromMaybe [] (M.lookup j iwM))
     }
-- | Build partition number @j@ as a 'CPart' holding the entries grouped
-- under @j@ (empty maps when @j@ has no entries). The 'F.GInfo' argument is
-- ignored; it is only there to fit the 'PartitionCtor' shape.
mkPartition' :: PartitionCtor c a (CPart c a)
mkPartition' _ icM iwM j =
  CPart { pws = M.fromList (fromMaybe [] (M.lookup j iwM))
        , pcm = M.fromList (fromMaybe [] (M.lookup j icM))
        }
-- | Look up the group number recorded for key @k@ in @m@ via 'safeLookup',
-- passing along a message that names the key.
-- NOTE(review): presumably 'safeLookup' aborts with that message when the
-- key is missing -- confirm against its definition.
groupFun :: (Show k, Eq k, Hashable k) => M.HashMap k Int -> k -> Int
groupFun m k = safeLookup msg k m
  where
    msg = "groupFun: " ++ show k