{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Interchanging scans with inner maps.
module Futhark.Pass.ExtractKernels.ISRWIM
  ( iswim,
    irwim,
    rwimPossible,
  )
where

import Control.Arrow (first)
import Control.Monad.State
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools

-- | Interchange Scan With Inner Map.  Tries to turn a @scan(map)@ into a
-- @map(scan)@.
--
-- Returns 'Nothing' when 'rwimPossible' does not recognise the scan
-- operator as a single inner map.  Otherwise returns a builder action
-- that:
--
-- 1. rearranges every input array with the permutation produced by
--    'outerTransposePerm' and binds every neutral element to a variable;
--
-- 2. emits an outer map (of the width of the inner map) whose body is a
--    single scan of width @w@, binding its results to fresh identifiers
--    suffixed with @_transposed@;
--
-- 3. rearranges each of those results into the corresponding identifier
--    of @res_pat@.
iswim ::
  (MonadBinder m, Lore m ~ SOACS) =>
  Pattern ->
  SubExp ->
  Lambda ->
  [(SubExp, VName)] ->
  Maybe (m ())
iswim res_pat w scan_fun scan_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun = Just $ do
      let (accs, arrs) = unzip scan_input
      arrs' <- transposedArrays arrs
      accs' <- mapM (letExp "acc" . BasicOp . SubExp) accs

      let map_arrs' = accs' ++ arrs'
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          -- The outer map receives the accumulators with their outer
          -- dimension removed, and the elements with outer size @w@.
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun

          -- The lambda of the former inner map becomes the scan operator.
          scan_params = lambdaParams map_fun
          scan_body = lambdaBody map_fun
          scan_rettype = lambdaReturnType map_fun
          scan_fun' = Lambda scan_params scan_body scan_rettype
          scan_input' =
            map (first Var) $
              uncurry zip $
                splitAt (length arrs') $
                  map paramName map_params
          (nes', scan_arrs) = unzip scan_input'

      scan_soac <- scanSOAC [Scan scan_fun' nes']
      let map_body =
            mkBody
              ( oneStm $
                  Let (setPatternOuterDimTo w map_pat) (defAux ()) $
                    Op $
                      Screma w scan_arrs scan_soac
              )
              $ map Var $
                patternNames map_pat
          map_fun' = Lambda map_params map_body map_rettype

      res_pat' <-
        fmap (basicPattern []) $
          mapM (newIdent' (<> "_transposed") . transposeIdentType) $
            patternValueIdents res_pat

      addStm $
        Let res_pat' (StmAux map_cs mempty ()) $
          Op $
            Screma map_w map_arrs' (mapSOAC map_fun')

      -- Rearrange every intermediate result into the identifier that
      -- the original pattern expects.
      forM_
        ( zip
            (patternValueIdents res_pat)
            (patternValueIdents res_pat')
        )
        $ \(to, from) -> do
          let perm = outerTransposePerm $ arrayRank $ identType from
          addStm $
            Let (basicPattern [] [to]) (defAux ()) $
              BasicOp $
                Rearrange perm $
                  identName from
  | otherwise = Nothing

-- | Interchange Reduce With Inner Map.  Tries to turn a @reduce(map)@ into a
-- @map(reduce)@.
--
-- Returns 'Nothing' when 'rwimPossible' does not recognise the
-- reduction operator as a single inner map.  Otherwise returns a
-- builder action that rearranges the input arrays, indexes every
-- accumulator at position 0 of its first dimension, and emits an outer
-- map bound directly to @res_pat@.  The body of that map is either a
-- reduction of width @w@ or, when the interchange applies again to the
-- inner operator, the statements produced by the recursive call.
--
-- Calls 'error' if an accumulator is a constant rather than a variable.
irwim ::
  (MonadBinder m, Lore m ~ SOACS) =>
  Pattern ->
  SubExp ->
  Commutativity ->
  Lambda ->
  [(SubExp, VName)] ->
  Maybe (m ())
irwim res_pat w comm red_fun red_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible red_fun = Just $ do
      let (accs, arrs) = unzip red_input
      arrs' <- transposedArrays arrs
      -- FIXME?  Can we reasonably assume that the accumulator is a
      -- replicate?  We also assume that it is non-empty.
      let indexAcc (Var v) = do
            v_t <- lookupType v
            letSubExp "acc" $
              BasicOp $
                Index v $
                  fullSlice v_t [DimFix $ intConst Int64 0]
          indexAcc Constant {} =
            error "irwim: array accumulator is a constant."
      accs' <- mapM indexAcc accs

      let (_red_acc_params, red_elem_params) =
            splitAt (length arrs) $ lambdaParams red_fun
          map_rettype = map rowType $ lambdaReturnType red_fun
          map_params = map (setParamOuterDimTo w) red_elem_params

          -- The lambda of the former inner map becomes the reduction
          -- operator.
          red_params = lambdaParams map_fun
          red_body = lambdaBody map_fun
          red_rettype = lambdaReturnType map_fun
          red_fun' = Lambda red_params red_body red_rettype
          red_input' = zip accs' $ map paramName map_params
          red_pat = stripPatternOuterDim map_pat

      map_body <-
        case irwim red_pat w comm red_fun' red_input' of
          Nothing -> do
            reduce_soac <- reduceSOAC [Reduce comm red_fun' $ map fst red_input']
            return $
              mkBody
                ( oneStm $
                    Let red_pat (defAux ()) $
                      Op $
                        Screma w (map snd red_input') reduce_soac
                )
                $ map Var $
                  patternNames map_pat
          -- The interchange applies once more: collect the statements
          -- of the recursive result with the map parameters in scope.
          Just m -> localScope (scopeOfLParams map_params) $ do
            map_body_bnds <- collectStms_ m
            return $ mkBody map_body_bnds $ map Var $ patternNames map_pat

      let map_fun' = Lambda map_params map_body map_rettype

      addStm $
        Let res_pat (StmAux map_cs mempty ()) $
          Op $
            Screma map_w arrs' $
              mapSOAC map_fun'
  | otherwise = Nothing

-- | Does this reduce operator contain an inner map, and if so, what
-- does that map look like?
--
-- Succeeds only when the body of the lambda is a single statement whose
-- bound names are returned verbatim, whose expression is a map, and
-- whose input arrays are exactly the parameters of the lambda, in
-- order.  On success the result holds the pattern, certificates, width
-- and lambda of that map.
rwimPossible ::
  Lambda ->
  Maybe (Pattern, Certificates, SubExp, Lambda)
rwimPossible fun
  | Body _ stms res <- lambdaBody fun,
    [bnd] <- stmsToList stms, -- Body has a single binding
    map_pat <- stmPattern bnd,
    map Var (patternNames map_pat) == res, -- Returned verbatim
    Op (Screma map_w map_arrs form) <- stmExp bnd,
    Just map_fun <- isMapSOAC form,
    map paramName (lambdaParams fun) == map_arrs =
      Just (map_pat, stmCerts bnd, map_w, map_fun)
  | otherwise = Nothing

-- | The permutation that exchanges the first two dimensions of an
-- array of the given rank and leaves every later dimension in place.
outerTransposePerm :: Int -> [Int]
outerTransposePerm rank = [1, 0] ++ [2 .. rank - 1]

-- | Bind, for every given array, a rearranged copy in which the first
-- two dimensions have been exchanged.  The new names are based on the
-- names of the originals.
transposedArrays :: MonadBinder m => [VName] -> m [VName]
transposedArrays arrs = forM arrs $ \arr -> do
  t <- lookupType arr
  letExp (baseString arr) $
    BasicOp $
      Rearrange (outerTransposePerm $ arrayRank t) arr

-- | Replace the type of the parameter with its row type.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim param =
  let t = rowType $ paramType param
   in param {paramDec = t}

-- | Apply 'setOuterDimTo' to the type of the parameter.
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w param =
  let t = setOuterDimTo w $ paramType param
   in param {paramDec = t}

-- | Apply 'setOuterDimTo' to the type of the identifier.
setIdentOuterDimTo :: SubExp -> Ident -> Ident
setIdentOuterDimTo w ident =
  let t = setOuterDimTo w $ identType ident
   in ident {identType = t}

-- | Rebuild the type as an array of size @w@ whose rows have the row
-- type of the original.
setOuterDimTo :: SubExp -> Type -> Type
setOuterDimTo w t =
  arrayOfRow (rowType t) w

-- | Apply 'setIdentOuterDimTo' to every value identifier of the
-- pattern, producing a pattern with no context part.
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo w pat =
  basicPattern [] $ map (setIdentOuterDimTo w) $ patternValueIdents pat

-- | Apply 'transposeType' to the type of the identifier.
transposeIdentType :: Ident -> Ident
transposeIdentType ident =
  ident {identType = transposeType $ identType ident}

-- | Replace the type of the identifier with its row type.
stripIdentOuterDim :: Ident -> Ident
stripIdentOuterDim ident =
  ident {identType = rowType $ identType ident}

-- | Apply 'stripIdentOuterDim' to every value identifier of the
-- pattern, producing a pattern with no context part.
stripPatternOuterDim :: Pattern -> Pattern
stripPatternOuterDim pat =
  basicPattern [] $ map stripIdentOuterDim $ patternValueIdents pat