module DDC.Core.Flow.Transform.Rates.SeriesOfVector
(seriesOfVectorModule
,seriesOfVectorFunction)
where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Transform.Rates.Constraints
import DDC.Core.Flow.Transform.Rates.Fail
import DDC.Core.Flow.Transform.Rates.Graph
import DDC.Core.Module
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate
import qualified DDC.Type.Env as Env
import Control.Applicative
import Control.Monad
import Data.List (intersect, nub)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import qualified Data.Set as Set
-- | Apply the series-of-vector transform to the let-bindings at the top of a
--   module body.
--
--   Returns the module with its body replaced, together with every failure
--   reported by 'seriesOfVectorLets', each paired with a binding name.
seriesOfVectorModule :: ModuleF -> (ModuleF, [(Name,Fail)])
seriesOfVectorModule mm
 = (mm { moduleBody = newBody }, failures)
 where
  -- The transform works on the de-annotated form of the body.
  plainBody           = deannotate (const Nothing) (moduleBody mm)
  (topLets, result)   = splitXLets plainBody

  -- Transform each group of bindings, keeping results and failures apart.
  (topLets', failss)  = unzip (map seriesOfVectorLets topLets)
  failures            = concat failss

  -- Rebuild the let-chain and put unit annotations back on it.
  newBody             = annotate () (xLets topLets' result)
-- | Apply 'seriesOfVectorFunction' to the right-hand sides of a group of
--   let-bindings.
--
--   * A single binding with a named binder is transformed, and its failures
--     are tagged with the binder's name.
--
--   * Every right-hand side of a recursive group is transformed. Failures are
--     tagged with the name of the binder they came from; failures from a
--     binder for which 'takeNameOfBind' yields no name cannot be attributed
--     and are not reported.
--
--   * Any other form of binding is returned unchanged with no failures.
seriesOfVectorLets :: LetsF -> (LetsF, [(Name,Fail)])
seriesOfVectorLets ll
 | LLet b@(BName n _) x <- ll
 , (x',errs)            <- seriesOfVectorFunction x
 = (LLet b x', tag n errs)

 | LRec bxs             <- ll
 , (bs,xs)              <- unzip bxs
 , (xs',errss)          <- unzip $ map seriesOfVectorFunction xs
 = ( LRec (bs `zip` xs')
     -- Report the failures of each member instead of discarding them.
   , concat [ tag n errs
            | (b, errs) <- bs `zip` errss
            , Just n    <- [takeNameOfBind b] ] )

 | otherwise
 = (ll, [])
 where
  -- Pair every failure with the name of the binding that produced it.
  tag n = map (\f -> (n,f))
-- | Rewrite a single function: split it into lambdas, let-bindings and a
--   result expression, schedule the bindings into loops, and rebuild the
--   function with 'construct'.
--
--   Returns the rebuilt expression along with the failures logged on the way.
seriesOfVectorFunction :: ExpF -> (ExpF, [Fail])
seriesOfVectorFunction fun
 = run $ do
        -- Peel off the lambdas, then the let-bindings of what remains.
        let (params, inner)     = takeXLamFlags_safe fun
            (bodyLets, result)  = splitXLets inner

        namedBinds <- takeLets bodyLets

        -- Types of everything bound by the lets or by the lambdas.
        let typeMap     = takeTypes
                            (concatMap valwitBindsOfLets bodyLets ++ map snd params)

        -- Warn if any name is bound more than once.
        let bindNames   = map fst namedBinds
        when (length (nub bindNames) /= length bindNames)
         $ warn FailNamesNotUnique

        (constrs, equivs) <- checkBindConstraints namedBinds

        -- Named lambda binders, passed as extra names to the graph.
        let paramNames  = catMaybes (map (takeNameOfBind . snd) params)
            graph       = graphOfBinds namedBinds paramNames

        -- Names that occur free in the result expression.
        let returned    = catMaybes
                            (map takeNameOfBound
                                 (Set.toList (freeX Env.empty result)))

        loops   <- schedule graph equivs returned
        ordered <- orderBinds namedBinds loops

        let perLoop     = zip3 ordered (map lOutputs loops) (map lInputs loops)
            getMax      = getMaxSize constrs equivs paramNames

        return $ construct getMax params perLoop equivs typeMap result
-- | Like 'takeXLamFlags', except that when 'takeXLamFlags' yields 'Nothing'
--   the expression is returned unchanged with an empty binder list.
takeXLamFlags_safe x
 = case takeXLamFlags x of
        Just (binds, body)      -> (binds, body)
        Nothing                 -> ([], x)
-- | Collect the named bindings from a list of lets.
--
--   Bindings with no binder contribute nothing. Anonymous binders, recursive
--   groups and region-introducing lets each log a failure and contribute
--   nothing.
takeLets :: [LetsF] -> LogFailures [(Name, ExpF)]
takeLets lets
 = fmap concat (mapM fromLets lets)
 where
  fromLets ll
   = case ll of
        LLet (BName n _) x      -> return [(n,x)]
        LLet (BNone _)   _      -> return []
        LLet (BAnon _)   _      -> failWith FailNoDeBruijnAllowed
        LRec _                  -> failWith FailRecursiveBindings
        LPrivate _ _ _          -> failWith FailLetRegionNotHandled
        LWithRegion _           -> failWith FailLetRegionNotHandled

  -- Log the failure and carry on with no bindings for this let.
  failWith err = warn err >> return []
-- | Build a map from binder name to type, taking only the named binders.
--   Binders without a name are skipped.
takeTypes :: [Bind Name] -> Map.Map Name TypeF
takeTypes binds
 = Map.fromList [ (n, t) | BName n t <- binds ]
-- | One scheduled loop, as assembled by 'schedule'.
data Loop
        = Loop
        { -- | Names of the bindings in this loop (from 'scheduleAll').
          lBindings     :: [Name]

          -- | Names this loop must produce (from 'scheduleOutputs').
        , lOutputs      :: [Name]

          -- | Names this loop reads from outside itself (from 'scheduleInputs').
        , lInputs       :: [Name]
        }
        deriving (Eq, Show)
-- | Schedule the nodes of a graph into loops.
--
--   The canonical name of the minimum member of each equivalence class gives
--   the order in which types are scheduled. The bindings, outputs and inputs
--   of each loop are then computed and zipped into 'Loop' records.
schedule :: Graph -> EquivClass -> [Name] -> LogFailures [Loop]
schedule graph equivs rets
 = return (zipWith3 Loop bindings outs ins)
 where
  typeOrder             = map (canonName equivs . Set.findMin) equivs
  (weights, merged)     = scheduleTypes graph equivs typeOrder
  bindings              = scheduleAll (map snd weights) graph merged
  outs                  = scheduleOutputs bindings graph rets
  ins                   = scheduleInputs  bindings graph
-- | Run 'typedTraversal' for each type in turn, merging the resulting weights
--   into the graph before moving on to the next type.
--
--   Returns the weights computed for every type (most recently processed
--   type first) and the graph after all merges.
scheduleTypes :: Graph -> EquivClass -> [Name] -> ([(Name, Map.Map Name Int)], Graph)
scheduleTypes graph types type_order
 = go [] graph type_order
 where
  go acc g []
   = (acc, g)

  go acc g (ty : rest)
   = let wts = typedTraversal g types ty
     in  go ((ty, wts) : acc) (mergeWeights g wts) rest
-- | Compute the members of each loop.
--
--   There is one loop per node in the topological order of the merged graph.
--   A node's loop is looked up in the first weight map that contains the
--   node: all names sharing that node's weight belong to the loop. The
--   members are listed in the topological order of the original graph.
--   A node found in no weight map gives an empty loop.
scheduleAll :: [Map.Map Name Int] -> Graph -> Graph -> [[Name]]
scheduleAll weights graph graph'
 = map namesOfNode (graphTopoOrder graph')
 where
  -- Each weight map paired with its inverse.
  weightPairs   = weights `zip` map invertMap weights

  originalOrder = graphTopoOrder graph

  namesOfNode n
   = inOriginalOrder (members n weightPairs)

  -- Keep only the given names, arranged as in the original graph's order.
  inOriginalOrder ns
   = filter (`elem` ns) originalOrder

  members _ []
   = []

  members n ((w, inv) : rest)
   = case Map.lookup n w >>= \i -> Map.lookup i inv of
        Just ns -> ns
        Nothing -> members n rest
-- | Compute the outputs of each loop.
--
--   A member of a loop is an output when some graph node outside the loop
--   has an edge to it, or when it is one of the returned names @rets@.
--
--   Each output is listed once, even if several outside nodes refer to it or
--   it is both referred to and returned. Without this, a consumer that maps
--   over the output list would process the same name repeatedly.
scheduleOutputs :: [[Name]] -> Graph -> [Name] -> [[Name]]
scheduleOutputs loops graph rets
 = map output loops
 where
  output ns
   = nub $ graphOuts ns ++ filter (`elem` ns) rets

  -- Members of the loop that appear in the edge list of a node outside it.
  graphOuts ns
   = concat [ ns `intersect` map fst es
            | (k, es) <- Map.toList graph
            , not (k `elem` ns) ]
-- | Compute the inputs of each loop: the names that members of the loop have
--   edges to in the graph, without duplicates and excluding the members
--   themselves.
scheduleInputs :: [[Name]] -> Graph -> [[Name]]
scheduleInputs loops graph
 = [ filter (`notElem` ns) (graphIns ns) | ns <- loops ]
 where
  graphIns ns
   = nub (concatMap (map fst . mlookup "graphIns" graph) ns)
-- | Run 'traversal' over the graph with an edge weight specialised to one
--   type, then keep only the weights of names of that type.
--
--   An edge has weight 1 when its first endpoint has the current type and
--   either the second endpoint has a different type or the edge is not
--   flagged as fusible. Every other edge has weight 0.
typedTraversal :: Graph -> EquivClass -> Name -> Map.Map Name Int
typedTraversal graph types current_type
 = restrictTypes types current_type (traversal graph weight)
 where
  weight u v
   | countsEdge u v     = 1
   | otherwise          = 0

  countsEdge (u, fusible) v
   =  canonName types u == current_type
   && (canonName types v /= current_type || not fusible)
-- | Keep only the entries whose key has the given canonical type name.
restrictTypes :: EquivClass -> Name -> Map.Map Name Int -> Map.Map Name Int
restrictTypes types current_type weights
 = Map.filterWithKey (\n _ -> canonName types n == current_type) weights
-- | Arrange the bindings into one group per loop, following the order of each
--   loop's 'lBindings'. Names with no corresponding binding are skipped.
orderBinds :: [(Name,ExpF)] -> [Loop] -> LogFailures [[(Name,ExpF)]]
orderBinds binds loops
 = return
        [ [ (k, v) | k      <- lBindings loop
                   , Just v <- [Map.lookup k bindsM] ]
        | loop <- loops ]
 where
  bindsM = Map.fromList binds
-- | Rebuild the function: convert every loop to let-bindings with
--   'convertToSeries', chain them in front of the result expression, and
--   wrap the whole thing in the original lambdas.
construct
        :: (Name -> Name)
        -> [(Bool, BindF)]
        -> [([(Name, ExpF)], [Name], [Name])]
        -> EquivClass
        -> Map.Map Name TypeF
        -> ExpF
        -> ExpF
construct getMax lams loops equivs tys xx
 = makeXLamFlags lams (xLets allLets xx)
 where
  allLets
   = concat
        [ convertToSeries getMax binds outputs inputs equivs tys
        | (binds, outputs, inputs) <- loops ]
-- | Build the let-bindings for a single loop.
--
--   The result consists of, in order: the set-up bindings for the loop's
--   outputs, one unnamed binding that runs the loop's processes, and the
--   read-back bindings for the outputs. The set-up and read-back bindings
--   come from 'setreadSeriesX'; the process body comes from folding
--   'wrapSeriesX' over the loop's bindings.
convertToSeries
        :: (Name -> Name) -> [(Name,ExpF)] -> [Name] -> [Name]
        -> EquivClass -> Map.Map Name TypeF -> [LetsF]
convertToSeries getMax binds outputs inputs equivs tys
 =  concat allocs
 ++ [ LLet (BNone tBool) (runprocs vecInputs processes) ]
 ++ concat reads
 where
  -- Wrap the process body in a run-process application over the vector
  -- inputs. With no vector inputs the body is used as is.
  runprocs :: [(Name,TypeF)] -> ExpF -> ExpF
  runprocs vecs body
   = case vecs of
      []
       -> body

      ((cn, _) : _)
       -> let -- The rate variable is named after the first input's class.
              rateName  = NameVarMod (canonName equivs cn) "k"
              rateTy    = TVar (UName rateName)

              rateFlags = [ (True,  BName rateName kRate)
                          , (False, BNone (tRateNat rateTy)) ]

              -- One series parameter per input vector.
              vecFlags  = [ (False, BName (NameVarMod n "s") (tSeries rateTy t))
                          | (n, t) <- vecs ]

              worker    = makeXLamFlags (rateFlags ++ vecFlags) body

          in  xApps (xVarOpSeries (OpSeriesRunProcess (length vecs)))
                    (  [ XType t        | (_, t) <- vecs ]
                    ++ [ XVar (UName n) | (n, _) <- vecs ]
                    ++ [ worker ])

  -- Inputs whose type is a vector, paired with their element type.
  vecInputs :: [(Name,TypeF)]
  vecInputs = concatMap vectorInput inputs

  vectorInput inp
   = let tyI = mlookup "collectKloks" tys inp
     in  case takeTyConApps tyI of
          Just (_tcVec, [tyA])
           | tyI == tVector tyA -> [(inp, tyA)]
          _                     -> []

  -- Wrap each binding around the joined processes, first binding outermost.
  processes
   = foldr wrapBind joins binds

  wrapBind (n, x) body
   = wrapSeriesX equivs outputs n (mlookup "wrap" tys n) x body

  -- Join the process of every output, or use unit when there are none.
  joins
   = case [ XVar (UName (NameVarMod n "proc")) | n <- outputs ] of
      []        -> xUnit
      procs     -> foldl1 mkJoin procs

  mkJoin p q
   = xApps (xVarOpSeries OpSeriesJoin) [p, q]

  -- Set-up and read-back bindings for the bindings that are outputs.
  (allocs, reads)
   = unzip
        [ setreadSeriesX getMax tys n (mlookup "setread" tys n) x
        | (n, x) <- binds
        , n `elem` outputs ]
-- | Produce the bindings that go before and after the process for one output.
--
--   * A reduce applied to four arguments: allocate a reference initialised
--     with the third argument before the process, and read it into the
--     original name afterwards.
--
--   * Any other vector operator (or a reduce with a different argument
--     count), when the binding's type and the type of the name chosen by
--     @getMax@ both split into exactly two parts: allocate a new vector under
--     the original name, sized by the length of the chosen vector. Nothing
--     is read back.
--
--   * Anything else: no bindings on either side.
setreadSeriesX
        :: (Name -> Name) -> Map.Map Name TypeF -> Name -> TypeF -> ExpF -> ([LetsF], [LetsF])
setreadSeriesX getMax tys name ty xx
 = case takeXApps xx of
    Just (XVar (UPrim (NameOpVector ov) _), args)
        -> forOp ov args
    _   -> ([], [])
 where
  forOp OpVectorReduce [_tA, _f, z, _vA]
   = ( [ LLet (BName refName (tRef ty)) (xNew ty z) ]
     , [ LLet (BName name ty)           (xRead ty (XVar (UName refName))) ] )

  forOp _ _
   | [_vecT, tyR]       <- takeTApps ty
   , maxName            <- getMax name
   , [_vecC, tyCR]      <- takeTApps (mlookup "setreadSeriesX" tys maxName)
   = let len = xApps (xVarOpVector OpVectorLength)
                     [XType tyCR, XVar (UName maxName)]
     in  ( [ LLet (BName name (tBot kData)) (xNewVector tyR len) ]
         , [] )

   | otherwise
   = ([], [])

  refName = NameVarMod name "ref"
-- | Convert one vector-operator binding to its series form and place it in
--   front of the rest of the process body, @wrap@.
--
--   * Reduce: bind the @proc@ name to a series reduce into the @ref@ name.
--
--   * Map: bind the @s@ name to a series map over the @s@ forms of the
--     argument vectors.
--
--   * Filter: bind the @flags@ name to a series map of the predicate, then
--     build a selector whose body packs the source series into the @s@ name.
--
--   For map and filter, when the binding is an output with a two-part type,
--   a fill of the original name from its series is inserted before @wrap@.
--
--   NOTE(review): when the expression is not one of the recognised forms the
--   original expression @xx@ is returned and @wrap@ is dropped; confirm that
--   callers only reach this case after a failure has been logged.
wrapSeriesX :: EquivClass -> [Name] -> Name -> TypeF -> ExpF -> ExpF -> ExpF
wrapSeriesX equivs outputs name ty xx wrap
 = case takeXApps xx of
    Just (XVar (UPrim (NameOpVector ov) _), args)
        -> convert ov args
    _   -> xx
 where
  convert OpVectorReduce [_tA, f, z, vA@(XVar (UName nvA))]
   = XLet (LLet (BName procName tProcess)
                (xApps (xVarOpSeries OpSeriesReduce)
                       [ klok nvA, XType ty, var refName
                       , f, z, modNameX "s" vA ]))
          wrap

  convert (OpVectorMap n) args
   | (tyArgs, f : vecs) <- splitAt (n+1) args
   , length vecs == n
   = XLet (LLet (BName seriesName (tBot kData))
                (xApps (xVarOpSeries (OpSeriesMap n))
                       (  [klok name]
                       ++ tyArgs
                       ++ [f]
                       ++ map (modNameX "s") vecs)))
          fillThenWrap

  convert OpVectorFilter [tA, p, vA@(XVar (UName nvA))]
   | TVar (UName nkT) <- klokT name
   = let tkA     = klokT nvA
         kA      = klok  nvA
         tkT     = klokT name
         srcS    = modNameX "s" vA

         -- Evaluate the predicate over the source series.
         flagsX  = xApps (xVarOpSeries (OpSeriesMap 1))
                         [kA, tA, XType tBool, p, srcS]

         -- Pack the source series using the selector.
         packX   = xApps (xVarOpSeries OpSeriesPack)
                         [kA, XType tkT, tA, var selName, srcS]

         selBody = XLAM (BName nkT kRate)
                 $ XLam (BName selName (tSel1 tkA tkT))
                 $ XLet (LLet (BName seriesName (tBot kData)) packX)
                        fillThenWrap

     in  XLet (LLet (BName flagsName (tBot kData)) flagsX)
              (xApps (xVarOpSeries (OpSeriesMkSel 1))
                     [kA, var flagsName, selBody])

  -- Unrecognised operator, or arguments of an unexpected shape.
  convert _ _
   = xx

  flagsName     = NameVarMod name "flags"
  procName      = NameVarMod name "proc"
  refName       = NameVarMod name "ref"
  seriesName    = NameVarMod name "s"
  selName       = NameVarMod name "sel"

  -- The rate type variable is named after the canonical name of the class.
  klokT n       = TVar (UName (NameVarMod (canonName equivs n) "k"))
  klok  n       = XType (klokT n)

  var n         = XVar (UName n)

  -- For an output with a two-part type, fill the vector from its series
  -- before continuing with the rest of the body.
  fillThenWrap
   | name `elem` outputs
   , [_vec, tyElem] <- takeTApps ty
   = XLet (LLet (BName procName tProcess)
                (xApps (xVarOpSeries OpSeriesFill)
                       [klok name, XType tyElem, var name, var seriesName]))
          wrap

   | otherwise
   = wrap
-- | A variable expression referring to a primitive series operator, carrying
--   the operator's type.
xVarOpSeries op
 = XVar $ UPrim (NameOpSeries op) (typeOpSeries op)
-- | A variable expression referring to a primitive vector operator, carrying
--   the operator's type.
xVarOpVector op
 = XVar $ UPrim (NameOpVector op) (typeOpVector op)
-- | Attach a modifier string to the name of a named variable expression.
--   Any other expression is returned unchanged.
modNameX :: String -> ExpF -> ExpF
modNameX s (XVar (UName n))     = XVar (UName (NameVarMod n s))
modNameX _ xx                   = xx