{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fmax-pmcheck-iterations=2500000#-}
module Futhark.Internalise (internaliseProg) where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.List
import Data.Loc
import Data.Char (chr)
import Data.Maybe
import Language.Futhark as E hiding (TypeArg)
import Language.Futhark.Semantic (Imports)
import Futhark.Representation.SOACS as I hiding (stmPattern)
import Futhark.Transform.Rename as I
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Representation.AST.Attributes.Aliases
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Util (splitAt3)
import Futhark.Internalise.Monad as I
import Futhark.Internalise.AccurateSizes
import Futhark.Internalise.TypesValues
import Futhark.Internalise.Bindings
import Futhark.Internalise.Lambdas
import Futhark.Internalise.Defunctorise as Defunctorise
import Futhark.Internalise.Defunctionalise as Defunctionalise
import Futhark.Internalise.Monomorphise as Monomorphise
-- | Translate a type-checked source program into a core-language
-- 'I.Prog'.  The source is first run through defunctorisation,
-- monomorphisation and defunctionalisation (in that order); the
-- resulting value bindings are then internalised, and a successful
-- result is passed through 'I.renameProg'.  The 'Bool' is handed to
-- 'runInternaliseM' unchanged.
internaliseProg :: MonadFreshNames m =>
                   Bool -> Imports -> m (Either String I.Prog)
internaliseProg always_safe prog = do
  defunctorised <- Defunctorise.transformProg prog
  monomorphic <- Monomorphise.transformProg defunctorised
  defunctionalised <- Defunctionalise.transformProg monomorphic
  fundefs <- runInternaliseM always_safe $ internaliseValBinds defunctionalised
  traverse I.renameProg $ fmap I.Prog fundefs
-- | Internalise every value binding, in order, for its effect on the
-- internalisation state.
internaliseValBinds :: [E.ValBind] -> InternaliseM ()
internaliseValBinds vbs = forM_ vbs internaliseValBind
-- | Choose the core-language name for a function.
--
-- A binding without parameters gets its printed name with an @f@
-- suffix.  Otherwise the printed name is used directly, unless
-- 'lookupFunction'' already knows a function by that name, in which
-- case a fresh name is derived from the base name.
internaliseFunName :: VName -> [E.Pattern] -> InternaliseM Name
internaliseFunName ofname [] = return $ nameFromString $ pretty ofname ++ "f"
internaliseFunName ofname _ =
  lookupFunction' ofname >>= maybe (return plain_name) (const fresh_name)
  where plain_name = nameFromString $ pretty ofname
        fresh_name = nameFromString . pretty <$> newNameFromString (baseString ofname)
-- | Internalise one value binding into a core 'I.FunDef', register it
-- via 'bindFunction', and, if the binding is flagged as an entry point,
-- also generate an entry-point wrapper for it.
--
-- The parameters of the generated function are, in order: one int32
-- parameter per dimension constant (from parameter and return types),
-- the variables free in the body together with the size variables of
-- their types, the shape parameters, and the ordinary parameters.
internaliseValBind :: E.ValBind -> InternaliseM ()
internaliseValBind fb@(E.ValBind entry fname retdecl (Info rettype) tparams params body _ loc) = do
  info <- bindingParams tparams params $ \pcm shapeparams params' -> do
    (rettype_bad, rcm) <- internaliseReturnType rettype
    let rettype' = zeroExts rettype_bad

    -- Each dimension constant (from the parameters, pcm, and from the
    -- return type, rcm) becomes an explicit int32 parameter.
    let mkConstParam name = Param name $ I.Prim int32
        constparams = map (mkConstParam . snd) $ pcm<>rcm
        constnames = map I.paramName constparams
        constscope = M.fromList $ zip constnames $ repeat $
                     FParamInfo $ I.Prim $ IntType Int32

        shapenames = map I.paramName shapeparams
        normal_params = map I.paramName constparams ++ shapenames ++
                        map I.paramName (concat params')
        normal_param_names = S.fromList normal_params

    fname' <- internaliseFunName fname params
    -- The body result is checked against the (generalised) return type;
    -- the error message quotes the declared type when there is one.
    body' <- localScope constscope $ do
      msg <- case retdecl of
               Just dt -> ErrorMsg .
                          ("Function return value does not match shape of type ":) <$>
                          typeExpForError rcm dt
               Nothing -> return $ ErrorMsg ["Function return value does not match shape of declared return type."]
      internaliseBody body >>=
        ensureResultExtShape asserting msg loc (map I.fromDecl rettype')

    -- Whatever is free in the body but is not one of the parameters
    -- above is turned into an additional parameter.
    let free_in_fun = freeInBody body' `S.difference` normal_param_names

    used_free_params <- forM (S.toList free_in_fun) $ \v -> do
      v_t <- lookupType v
      return $ Param v $ toDecl v_t Nonunique

    -- The size variables in the types of those free variables are
    -- passed as int32 parameters too; nub removes duplicates.
    let free_shape_params = map (`Param` I.Prim int32) $
                            concatMap (I.shapeVars . I.arrayShape . I.paramType) used_free_params
        free_params = nub $ free_shape_params ++ used_free_params
        all_params = constparams ++ free_params ++ shapeparams ++ concat params'

    addFunction $ I.FunDef Nothing fname' rettype' all_params body'

    -- This tuple is handed to 'bindFunction' below.
    return (fname',
            pcm<>rcm,
            map I.paramName free_params,
            shapenames,
            map declTypeOf $ concat params',
            all_params,
            applyRetType rettype' all_params)

  bindFunction fname info

  when entry $ generateEntryPoint fb

  where
    -- Generalise the return types against themselves (see
    -- generaliseExtTypes); NOTE(review): exact effect depends on that helper.
    zeroExts ts = generaliseExtTypes ts ts
-- | Generate the externally visible wrapper for an entry point.  The
-- wrapper is named after the binding's base name, takes the parameters
-- with their shape annotations removed, calls the internalised
-- function, and returns the shape context followed by the result
-- values.
generateEntryPoint :: E.ValBind -> InternaliseM ()
generateEntryPoint (E.ValBind _ ofname retdecl (Info rettype) _ params _ _ loc) =
  -- Shape annotations are stripped from the parameters; the first
  -- continuation argument (dimension constants in internaliseValBind)
  -- is ignored here.
  bindingParams [] (map E.patternNoShapeAnnotations params) $
  \_ shapeparams params' -> do
    (entry_rettype, _) <- internaliseEntryReturnType $ E.vacuousShapeAnnotations rettype
    let entry' = entryPoint (zip params params') (retdecl, rettype, entry_rettype)
        args = map (I.Var . I.paramName) $ concat params'

    -- Call the real function and compute the existential shape
    -- context of its results from their actual dimensions.
    entry_body <- insertStmsM $ do
      vals <- fst <$> funcall "entry_result" (E.qualName ofname) args loc
      ctx <- extractShapeContext (concat entry_rettype) <$>
             mapM (fmap I.arrayDims . subExpType) vals
      resultBodyM (ctx ++ vals)

    addFunction $
      I.FunDef (Just entry') (baseName ofname)
      (concat entry_rettype)
      (shapeparams ++ concat params') entry_body
-- | Describe the external interface of an entry point as a pair of
-- 'EntryPointType' lists: one for the parameters and one for the
-- result.  A tuple result is described component by component, each
-- paired with the matching component of a tuple return-type
-- annotation when one was written.
entryPoint :: [(E.Pattern,[I.FParam])]
           -> (Maybe (E.TypeExp VName), E.StructType, [[I.TypeBase ExtShape Uniqueness]])
           -> EntryPoint
entryPoint params (retdecl, eret, crets) = (param_epts, result_epts)
  where param_epts = concatMap (entryPointType . describeParam) params

        result_epts
          | Just ts <- isTupleRecord eret =
              concatMap entryPointType $ zip3 component_decls ts crets
          | otherwise =
              entryPointType (retdecl, eret, concat crets)

        describeParam (p_pat, ps) =
          (ascribedType p_pat,
           E.patternStructType p_pat,
           staticShapes $ map I.paramDeclType ps)

        -- Only an ascription at the top of the pattern (possibly under
        -- parentheses) counts as the declared type.
        ascribedType (E.PatternAscription _ tdecl _) = Just $ declaredType tdecl
        ascribedType (E.PatternParens p _) = ascribedType p
        ascribedType _ = Nothing

        component_decls =
          case retdecl of
            Just (TETuple tes _) -> map Just tes
            _ -> repeat Nothing
-- | Classify one entry-point value.  Unsigned primitives and arrays of
-- unsigned primitives are 'I.TypeUnsigned'; other primitives and
-- primitive arrays are 'I.TypeDirect'.  Everything else is opaque,
-- spanning as many core values as the given list has elements.  The
-- opaque name comes from the declared type expression if there is one,
-- and otherwise from the printed (shape- and uniqueness-free) type.
entryPointType :: (Maybe (E.TypeExp VName),
                   E.StructType,
                   [I.TypeBase ExtShape Uniqueness])
               -> [EntryPointType]
entryPointType (te, t, ts) =
  case t of
    E.Prim E.Unsigned{} -> [I.TypeUnsigned]
    E.Array _ _ (ArrayPrimElem Unsigned{}) _ -> [I.TypeUnsigned]
    E.Prim{} -> [I.TypeDirect]
    E.Array _ _ ArrayPrimElem{} _ -> [I.TypeDirect]
    _ -> [I.TypeOpaque opaque_desc $ length ts]
  where opaque_desc =
          maybe (pretty $ removeShapeAnnotations t `E.setUniqueness` Nonunique)
                opaqueName te

        -- Size applications are skipped; nested array type expressions
        -- collapse into an "arr_<elem>_<rank>d" name.
        opaqueName (TEApply inner TypeArgExpDim{} _) = opaqueName inner
        opaqueName arr@TEArray{} =
          let (rank, elem_te) = peelDims arr
          in "arr_" ++ opaqueName elem_te ++ "_" ++ show rank ++ "d"
        opaqueName other = pretty other

        peelDims (TEArray inner _ _) =
          let (d, e) = peelDims inner
          in (d + 1, e)
        peelDims other = (0 :: Int, other)
-- | Internalise an identifier of primitive type, which keeps its name
-- unchanged.  Any other type is an internal error, reported with the
-- identifier, its type, and its source location.
internaliseIdent :: E.Ident -> InternaliseM I.VName
internaliseIdent (E.Ident name (Info tp) loc) =
  case tp of
    E.Prim{} -> return name
    -- The identifier is quoted on both sides in the message.
    _ -> fail $ "Futhark.Internalise.internaliseIdent: asked to internalise non-prim-typed ident '"
         ++ pretty name ++ "' of type " ++ pretty tp ++
         " at " ++ locStr loc ++ "."
-- | Internalise an expression into a self-contained 'Body'. The
-- statements it produces are collected into the body, and its
-- subexpressions form the result.
internaliseBody :: E.Exp -> InternaliseM Body
internaliseBody e = insertStmsM (resultBody <$> internaliseExp "res" e)
-- | Internalise an expression and pass its result subexpressions to a
-- continuation, which builds a body and an extra value.  The
-- statements from internalising the expression are placed in front of
-- the continuation body's own statements.
internaliseBodyStms :: E.Exp -> ([SubExp] -> InternaliseM (Body, a))
                    -> InternaliseM (Body, a)
internaliseBodyStms e m = do
  ((Body _ inner_stms res, x), outer_stms) <-
    collectStms $ internaliseExp "res" e >>= m
  body <- mkBodyM (outer_stms <> inner_stms) res
  return (body, x)
-- | Internalise a source expression into the core subexpressions that
-- make up its value; tuples and records yield several.  The 'String'
-- is used as the base name for the bindings that are created.
internaliseExp :: String -> E.Exp -> InternaliseM [I.SubExp]
internaliseExp name_hint (E.Parens inner _) =
  internaliseExp name_hint inner
internaliseExp name_hint (E.QualParens _ inner _) =
  internaliseExp name_hint inner
internaliseExp _ (E.Var (E.QualName _ name) (Info t) loc) =
  -- Resolution order: an active substitution, then a known constant,
  -- and finally the identifier itself.
  asks (M.lookup name . envSubsts) >>= maybe viaConstant return
  where viaConstant = lookupConstant loc name >>= maybe viaIdent return
        viaIdent = (\v -> [I.Var v]) <$>
                   internaliseIdent (E.Ident name (Info $ removeShapeAnnotations t) loc)
internaliseExp desc (E.Index e idxs _ loc) = do
  arrs <- internaliseExpToVars "indexed" e
  -- The bounds check uses the dimensions of the first internal array.
  outer_dims <- case arrs of
    [] -> return []
    arr : _ -> I.arrayDims <$> lookupType arr
  (slice, certs) <- internaliseSlice loc outer_dims idxs
  let indexArr arr = do
        arr_t <- lookupType arr
        return $ I.BasicOp $ I.Index arr $ fullSlice arr_t slice
  certifying certs $ mapM indexArr arrs >>= letSubExps desc
internaliseExp desc (E.TupLit es _) =
  -- A tuple is the concatenation of its components' values.
  fmap concat $ forM es $ internaliseExp desc
internaliseExp desc (E.RecordLit orig_fields _) = do
  field_maps <- mapM internaliseField orig_fields
  -- M.unions is left-biased, so reversing makes a later field win over
  -- an earlier field with the same name.
  return $ concatMap snd $ sortFields $ M.unions $ reverse field_maps
  where internaliseField (E.RecordFieldExplicit name e _) =
          M.singleton name <$> internaliseExp desc e
        internaliseField (E.RecordFieldImplicit name t loc) =
          internaliseField $ E.RecordFieldExplicit (baseName name)
          (E.Var (E.qualName name) (vacuousShapeAnnotations <$> t) loc) loc
internaliseExp desc (E.ArrayLit es (Info arr_t) loc)
  -- Fast path: a regular, multidimensional literal (every nested
  -- literal has the same shape) is flattened into one-dimensional
  -- literals and then reshaped to the full shape.  This produces fewer
  -- statements than internalising each row separately.
  | Just ((eshape,e'):es') <- mapM isArrayLiteral es,
    not $ null eshape,
    all ((eshape==) . fst) es',
    Just basetype <- E.peelArray (length eshape) arr_t = do
      let flat_lit = E.ArrayLit (e' ++ concatMap snd es') (Info basetype) loc
          new_shape = length es:eshape
      flat_arrs <- internaliseExpToVars "flat_literal" flat_lit
      forM flat_arrs $ \flat_arr -> do
        flat_arr_t <- lookupType flat_arr
        let new_shape' = reshapeOuter (map (DimNew . constant) new_shape)
                         1 $ arrayShape flat_arr_t
        letSubExp desc $ I.BasicOp $ I.Reshape new_shape' flat_arr

  -- General path: internalise each element, then build one ArrayLit
  -- per internal component, checking each element's shape against the
  -- first element's.
  | otherwise = do
      es' <- mapM (internaliseExp "arr_elem") es
      case es' of
        [] -> do
          -- Empty literal: row types come from the source type, with
          -- all dimensions set to zero.
          rowtypes <- internaliseType (rowtype `setAliases` ())
          let arraylit rt = I.BasicOp $ I.ArrayLit [] rt
          letSubExps desc $ map (arraylit . zeroDim . fromDecl) rowtypes
        e' : _ -> do
          rowtypes <- mapM subExpType e'
          let arraylit ks rt = do
                ks' <- mapM (ensureShape asserting "shape of element differs from shape of first element"
                             loc rt "elem_reshaped") ks
                return $ I.BasicOp $ I.ArrayLit ks' rt
          letSubExps desc =<< zipWithM arraylit (transpose es') rowtypes

  where rowtype = E.stripArray 1 arr_t

        zeroDim t = t `I.setArrayShape`
                    I.Shape (replicate (I.arrayRank t) (constant (0::Int32)))

        -- Returns the shape of a nested array literal together with its
        -- leaf elements in row-major order, or Nothing if the nesting
        -- is irregular.  A non-literal counts as a scalar leaf.
        isArrayLiteral :: E.Exp -> Maybe ([Int],[E.Exp])
        isArrayLiteral (E.ArrayLit inner_es _ _) = do
          (eshape,e):inner_es' <- mapM isArrayLiteral inner_es
          guard $ all ((eshape==) . fst) inner_es'
          return (length inner_es:eshape, e ++ concatMap snd inner_es')
        isArrayLiteral e =
          Just ([], [e])
-- Ranges become an Iota whose element count is computed at run time.
-- An invalid range (zero step, step pointing away from the end, or
-- bounds in the wrong order) yields zero elements rather than an error.
internaliseExp desc (E.Range start maybe_second end _ _) = do
  start' <- internaliseExp1 "range_start" start
  end' <- internaliseExp1 "range_end" $ case end of
    DownToExclusive e -> e
    ToInclusive e -> e
    UpToExclusive e -> e

  -- Signed and unsigned ranges differ only in the comparison operators.
  (it, le_op, lt_op) <-
    case E.typeOf start of
      E.Prim (E.Signed it) -> return (it, CmpSle it, CmpSlt it)
      E.Prim (E.Unsigned it) -> return (it, CmpUle it, CmpUlt it)
      start_t -> fail $ "Start value in range has type " ++ pretty start_t

  let one = intConst it 1
      negone = intConst it (-1)
      default_step = case end of
                       DownToExclusive{} -> negone
                       ToInclusive{} -> one
                       UpToExclusive{} -> one

  -- An explicit second element defines the step as (second - start);
  -- the step is zero exactly when second == start.
  (step, step_zero) <- case maybe_second of
    Just second -> do
      second' <- internaliseExp1 "range_second" second
      subtracted_step <- letSubExp "subtracted_step" $ I.BasicOp $ I.BinOp (I.Sub it) second' start'
      step_zero <- letSubExp "step_zero" $ I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) start' second'
      return (subtracted_step, step_zero)
    Nothing ->
      return (default_step, constant False)

  -- NOTE(review): SSignum is used even for unsigned element types;
  -- confirm this is intended for large unsigned steps.
  step_sign <- letSubExp "s_sign" $ BasicOp $ I.UnOp (I.SSignum it) step
  step_sign_i32 <- asIntS Int32 step_sign

  bounds_invalid_downwards <- letSubExp "bounds_invalid_downwards" $
                              I.BasicOp $ I.CmpOp le_op start' end'
  bounds_invalid_upwards <- letSubExp "bounds_invalid_upwards" $
                            I.BasicOp $ I.CmpOp lt_op end' start'

  -- Per end kind: the (int32) distance to cover, whether the step
  -- points the wrong way, and whether the bounds are out of order.
  (distance, step_wrong_dir, bounds_invalid) <- case end of
    DownToExclusive{} -> do
      step_wrong_dir <- letSubExp "step_wrong_dir" $
                        I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign one
      distance <- letSubExp "distance" $
                  I.BasicOp $ I.BinOp (Sub it) start' end'
      distance_i32 <- asIntZ Int32 distance
      return (distance_i32, step_wrong_dir, bounds_invalid_downwards)
    UpToExclusive{} -> do
      step_wrong_dir <- letSubExp "step_wrong_dir" $
                        I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign negone
      distance <- letSubExp "distance" $ I.BasicOp $ I.BinOp (Sub it) end' start'
      distance_i32 <- asIntZ Int32 distance
      return (distance_i32, step_wrong_dir, bounds_invalid_upwards)
    ToInclusive{} -> do
      -- The direction of an inclusive range follows the sign of the step.
      -- NOTE(review): when downwards, the bounds test uses le_op
      -- (start <= end), so start == end yields an empty range, whereas
      -- upwards start == end yields one element; confirm this asymmetry
      -- is intended.
      downwards <- letSubExp "downwards" $
                   I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign negone
      distance_downwards_exclusive <-
        letSubExp "distance_downwards_exclusive" $
        I.BasicOp $ I.BinOp (Sub it) start' end'
      distance_upwards_exclusive <-
        letSubExp "distance_upwards_exclusive" $
        I.BasicOp $ I.BinOp (Sub it) end' start'

      bounds_invalid <- letSubExp "bounds_invalid" $
                        I.If downwards
                        (resultBody [bounds_invalid_downwards])
                        (resultBody [bounds_invalid_upwards]) $
                        ifCommon [I.Prim I.Bool]
      distance_exclusive <- letSubExp "distance_exclusive" $
                            I.If downwards
                            (resultBody [distance_downwards_exclusive])
                            (resultBody [distance_upwards_exclusive]) $
                            ifCommon [I.Prim $ IntType it]
      distance_exclusive_i32 <- asIntZ Int32 distance_exclusive
      -- Inclusive: one more than the exclusive distance.
      distance <- letSubExp "distance" $
                  I.BasicOp $ I.BinOp (Add Int32)
                  distance_exclusive_i32 (intConst Int32 1)
      return (distance, constant False, bounds_invalid)

  step_invalid <- letSubExp "step_invalid" $
                  I.BasicOp $ I.BinOp I.LogOr step_wrong_dir step_zero

  invalid <- letSubExp "range_invalid" $
             I.BasicOp $ I.BinOp I.LogOr step_invalid bounds_invalid

  -- pos_step = step * signum(step), i.e. the step's magnitude.
  step_i32 <- asIntS Int32 step
  pos_step <- letSubExp "pos_step" $
              I.BasicOp $ I.BinOp (Mul Int32) step_i32 step_sign_i32

  -- Element count: 0 if invalid, else distance / |step| rounded up.
  num_elems <- letSubExp "num_elems" =<<
               eIf (eSubExp invalid)
               (eBody [eSubExp $ intConst Int32 0])
               (eBody [eDivRoundingUp Int32 (eSubExp distance) (eSubExp pos_step)])
  pure <$> letSubExp desc (I.BasicOp $ I.Iota num_elems start' step it)
internaliseExp desc (E.Ascript e (TypeDecl dt (Info et)) loc) = do
  ses <- internaliseExp desc e
  (ret_ts, cm) <- internaliseReturnType et
  forM_ cm $ uncurry $ internaliseDimConstant loc
  dt' <- typeExpForError cm dt
  -- Each component is coerced to the ascribed shape, with an error
  -- message that shows both the actual core shape and the source type.
  let checkOne (se, ret_t) = do
        dims <- arrayDims <$> subExpType se
        let parts = ["Value of (core language) shape ("] ++
                    intersperse ", " (map ErrorInt32 dims) ++
                    [") cannot match shape of type `"] ++ dt' ++ ["`."]
        ensureExtShape asserting (ErrorMsg parts) loc (I.fromDecl ret_t) desc se
  mapM checkOne $ zip ses ret_ts
internaliseExp desc (E.Negate e _) = do
  e' <- internaliseExp1 "negate_arg" e
  et <- subExpType e'
  -- Negation is expressed as subtraction from zero of the same type.
  let negated = case et of
        I.Prim (I.IntType t) -> Just $ I.BinOp (I.Sub t) (I.intConst t 0) e'
        I.Prim (I.FloatType t) -> Just $ I.BinOp (I.FSub t) (I.floatConst t 0) e'
        _ -> Nothing
  case negated of
    Just op -> letTupExp' desc $ I.BasicOp op
    Nothing -> fail "Futhark.Internalise.internaliseExp: non-numeric type in Negate"
internaliseExp desc e@E.Apply{} = do
  (qfname, args, _) <- findFuncall e
  let loc = srclocOf e
      fname = nameFromString $ pretty $ baseName $ qualLeaf qfname
      internaliseArgs = mapM (internaliseExp "arg") args
  -- Overloaded functions take priority, then built-in functions, and
  -- finally an ordinary call.
  case (isOverloadedFunction qfname args loc, M.lookup fname I.builtInFunctions) of
    (Just internalise, _) ->
      internalise desc
    (Nothing, Just (rettype, _)) -> do
      arg_ses <- internaliseArgs
      let observed = [ (se, I.Observe) | se <- concat arg_ses ]
      letTupExp' desc $ I.Apply fname observed [I.Prim rettype] (Safe, loc, [])
    (Nothing, Nothing) -> do
      arg_ses <- concat <$> internaliseArgs
      fst <$> funcall desc qfname arg_ses loc
internaliseExp desc (E.LetPat tparams pat e body loc) =
  internalisePat desc tparams pat e body loc $ internaliseExp desc
internaliseExp desc (E.LetFun ofname (tparams, params, retdecl, Info rettype, body) letbody loc) =
  -- A local function is internalised like a non-entry value binding
  -- before its scope is internalised.
  internaliseValBind (E.ValBind False ofname retdecl (Info rettype) tparams params body Nothing loc)
  >> internaliseExp desc letbody
internaliseExp desc (E.DoLoop tparams mergepat mergeexp form loopbody loc) = do
  -- The initial values are first bound through the merge pattern, as
  -- a let-binding would do, so that they are matched against it.
  ses <- internaliseExp "loop_init" mergeexp
  t <- I.staticShapes <$> mapM I.subExpType ses
  stmPattern tparams mergepat t $ \cm mergepat_names match -> do
    mapM_ (uncurry (internaliseDimConstant loc)) cm
    ses' <- match (srclocOf mergepat) ses
    forM_ (zip mergepat_names ses') $ \(v,se) ->
      letBindNames_ [v] $ I.BasicOp $ I.SubExp se
    let mergeinit = map I.Var mergepat_names

    -- handleForm returns the loop body together with the core loop
    -- form, shape parameters, merge parameters, initial merge values,
    -- and statements that must precede the loop.
    (loopbody', (form', shapepat, mergepat', mergeinit', pre_stms)) <-
      handleForm mergeinit form

    addStms pre_stms

    mergeinit_ts' <- mapM subExpType mergeinit'

    let ctxinit = argShapes
                  (map I.paramName shapepat)
                  (map I.paramType mergepat')
                  mergeinit_ts'
        ctxmerge = zip shapepat ctxinit
        valmerge = zip mergepat' mergeinit'
        merge = ctxmerge ++ valmerge
        -- While loops carry an extra leading condition parameter,
        -- whose result is dropped here.
        dropCond = case form of
                     E.While{} -> drop 1
                     _ -> id

    -- Check the loop body result against the merge parameter types.
    let merge_names = map (I.paramName . fst) merge
        merge_ts = existentialiseExtTypes merge_names $
                   staticShapes $ map (I.paramType . fst) merge
    loopbody'' <- localScope (scopeOfFParams $ map fst merge) $
      ensureResultExtShapeNoCtx asserting
      "shape of loop result does not match shapes in loop parameters"
      loc merge_ts loopbody'

    loop_res <- letTupExp desc $ I.DoLoop ctxmerge valmerge form' loopbody''
    return $ map I.Var $ dropCond loop_res

  where
    -- Shared by the for-style forms: internalise the body, prefixing its
    -- result with the shape arguments for the merge parameters.
    forLoop nested_mergepat shapepat mergeinit form' =
      inScopeOf form' $ internaliseBodyStms loopbody $ \ses -> do
        sets <- mapM subExpType ses
        let mergepat' = concat nested_mergepat
            shapeargs = argShapes
                        (map I.paramName shapepat)
                        (map I.paramType mergepat')
                        sets
        return (resultBody $ shapeargs ++ ses,
                (form',
                 shapepat,
                 mergepat',
                 mergeinit,
                 mempty))

    -- for x in arr: iterate over the rows of the internalised arrays,
    -- with a fresh int32 index variable.
    handleForm mergeinit (E.ForIn x arr) = do
      arr' <- internaliseExpToVars "for_in_arr" arr
      arr_ts <- mapM lookupType arr'
      let w = arraysSize 0 arr_ts

      i <- newVName "i"

      bindingParams tparams [mergepat] $ \mergecm shapepat nested_mergepat ->
        bindingLambdaParams [] [x] (map rowType arr_ts) $ \x_cm x_params -> do
          mapM_ (uncurry (internaliseDimConstant loc)) x_cm
          mapM_ (uncurry (internaliseDimConstant loc)) mergecm
          let loopvars = zip x_params arr'
          forLoop nested_mergepat shapepat mergeinit $ I.ForLoop i Int32 w loopvars

    -- for i < n: the index type follows the integer type of the bound.
    handleForm mergeinit (E.For i num_iterations) = do
      num_iterations' <- internaliseExp1 "upper_bound" num_iterations
      i' <- internaliseIdent i
      num_iterations_t <- I.subExpType num_iterations'
      it <- case num_iterations_t of
              I.Prim (IntType it) -> return it
              _ -> fail "internaliseExp DoLoop: invalid type"

      bindingParams tparams [mergepat] $ \mergecm shapepat nested_mergepat -> do
        mapM_ (uncurry (internaliseDimConstant loc)) mergecm
        forLoop nested_mergepat shapepat mergeinit $ I.ForLoop i' it num_iterations' []

    -- while cond: the condition is internalised twice.  It is computed
    -- once from the initial merge values, in statements returned to be
    -- placed before the loop, and once at the end of each iteration from
    -- the iteration's results.  A boolean merge parameter 'loop_while'
    -- carries it.
    handleForm mergeinit (E.While cond) =
      bindingParams tparams [mergepat] $ \mergecm shapepat nested_mergepat -> do
        mergeinit_ts <- mapM subExpType mergeinit
        mapM_ (uncurry (internaliseDimConstant loc)) mergecm
        let mergepat' = concat nested_mergepat
        let shapeinit = argShapes
                        (map I.paramName shapepat)
                        (map I.paramType mergepat')
                        mergeinit_ts

        -- Initial condition: bind the merge parameter names to the
        -- initial values, using a shape coercion for non-primitive
        -- parameters, then evaluate cond.
        (loop_initial_cond, init_loop_cond_bnds) <- collectStms $ do
          forM_ (zip shapepat shapeinit) $ \(p, se) ->
            letBindNames_ [paramName p] $ BasicOp $ SubExp se
          forM_ (zip (concat nested_mergepat) mergeinit) $ \(p, se) ->
            unless (se == I.Var (paramName p)) $
            letBindNames_ [paramName p] $ BasicOp $
              case se of
                I.Var v | not $ primType $ paramType p ->
                  Reshape (map DimCoercion $ arrayDims $ paramType p) v
                _ -> SubExp se
          internaliseExp1 "loop_cond" cond

        internaliseBodyStms loopbody $ \ses -> do
          sets <- mapM subExpType ses
          loop_while <- newParam "loop_while" $ I.Prim I.Bool
          let shapeargs = argShapes
                          (map I.paramName shapepat)
                          (map I.paramType mergepat')
                          sets

          -- End-of-iteration condition: same rebinding as above, but
          -- from the iteration results.  The body is renamed before it is
          -- bound, so names are not duplicated.
          loop_end_cond_body <- renameBody <=< insertStmsM $ do
            forM_ (zip shapepat shapeargs) $ \(p, se) ->
              unless (se == I.Var (paramName p)) $
              letBindNames_ [paramName p] $ BasicOp $ SubExp se
            forM_ (zip (concat nested_mergepat) ses) $ \(p, se) ->
              unless (se == I.Var (paramName p)) $
              letBindNames_ [paramName p] $ BasicOp $
                case se of
                  I.Var v | not $ primType $ paramType p ->
                    Reshape (map DimCoercion $ arrayDims $ paramType p) v
                  _ -> SubExp se
            resultBody <$> internaliseExp "loop_cond" cond
          loop_end_cond <- bodyBind loop_end_cond_body

          return (resultBody $ shapeargs++loop_end_cond++ses,
                  (I.WhileLoop $ I.paramName loop_while,
                   shapepat,
                   loop_while : mergepat',
                   loop_initial_cond : mergeinit,
                   init_loop_cond_bnds))
internaliseExp desc (E.LetWith name src idxs ve body loc) =
  -- `let name = src with [idxs] = ve in body` is desugared into a
  -- let-binding of an in-place update expression.
  internaliseExp desc $ E.LetPat [] pat update body loc
  where pat = E.Id (E.identName name) (E.vacuousShapeAnnotations <$> E.identType name) loc
        src_t = E.fromStruct . E.vacuousShapeAnnotations <$> E.identType src
        update = E.Update (E.Var (E.qualName $ E.identName src) src_t loc) idxs ve loc
internaliseExp desc (E.Update src slice ve loc) = do
  ves <- internaliseExp "lw_val" ve
  srcs <- internaliseExpToVars "src" src
  -- The slice is bounds-checked against the first internal array's
  -- dimensions.
  dims <- case srcs of
            [] -> return []
            v:_ -> I.arrayDims <$> lookupType v
  (idxs', cs) <- internaliseSlice loc dims slice

  -- For each (source array, value) pair, coerce the value to the shape
  -- of the slice and write it in place.
  let comb sname ve' = do
        sname_t <- lookupType sname
        let full_slice = fullSlice sname_t idxs'
            rowtype = sname_t `setArrayDims` sliceDims full_slice
        ve'' <- ensureShape asserting "shape of value does not match shape of source array"
                loc rowtype "lw_val_correct_shape" ve'
        letInPlace desc sname full_slice $ BasicOp $ SubExp ve''
  certifying cs $ map I.Var <$> zipWithM comb srcs ves
internaliseExp desc (E.RecordUpdate src fields ve _ _) = do
  src' <- internaliseExp desc src
  ve' <- internaliseExp desc ve
  replace (E.typeOf src `setAliases` ()) fields ve' src'
  -- replace walks the field path.  At each record level it locates
  -- the flat span of the named field among the sorted fields and
  -- recurses into that span.  When the path is exhausted, or the type
  -- is not a record containing the field, the new value replaces the
  -- current span entirely.
  where replace (E.Record m) (f:fs) ve' src'
          | Just t <- M.lookup f m = do
              i <- fmap sum $ mapM (internalisedTypeSize . snd) $
                   takeWhile ((/=f) . fst) $ sortFields m
              k <- internalisedTypeSize t
              let (bef, to_update, aft) = splitAt3 i k src'
              src'' <- replace t fs ve' to_update
              return $ bef ++ src'' ++ aft
        replace _ _ ve' _ = return ve'
internaliseExp desc (E.Unzip e _ _) = internaliseExp desc e
internaliseExp desc (E.Unsafe e _) =
  -- Bounds checks are turned off while internalising the operand.
  local disableBoundsChecks $ internaliseExp desc e
  where disableBoundsChecks env = env { envDoBoundsChecks = False }
internaliseExp desc (E.Assert e1 e2 (Info check) loc) = do
  cond <- internaliseExp1 "assert_cond" e1
  cert <- assertingOne $ letExp "assert_c" $
          I.BasicOp $ I.Assert cond (ErrorMsg [ErrorString check]) (loc, mempty)
  -- Every result is re-bound to a fresh name under the certificate, so
  -- each returned value comes from a certified statement.
  certifying cert $ internaliseExp desc e2 >>= mapM rebind
  where rebind se = do
          v <- newVName "assert_res"
          letBindNames_ [v] $ I.BasicOp $ I.SubExp se
          return $ I.Var v
internaliseExp _ (E.Zip _ e es _ loc) = do
  e' <- internaliseExpToVars "zip_arg" $ TupLit (e:es) loc
  case e' of
    e_key:es_unchecked -> do
      -- The outer size of the first array is the reference width w.
      -- Every other array whose outer dimension is not syntactically w
      -- gets a run-time equality assertion and is then shape-coerced to
      -- outer size w; inner dimensions are left unchanged.
      w <- arraySize 0 <$> lookupType e_key
      let reshapeToOuter e_unchecked' = do
            unchecked_t <- lookupType e_unchecked'
            case I.arrayDims unchecked_t of
              outer:inner | w /= outer -> do
                cmp <- letSubExp "zip_cmp" $ I.BasicOp $
                       I.CmpOp (I.CmpEq I.int32) w outer
                c <- assertingOne $
                     letExp "zip_assert" $ I.BasicOp $
                     I.Assert cmp "arrays differ in length" (loc, mempty)
                certifying c $ letExp (postfix e_unchecked' "_zip_res") $
                  shapeCoerce (w:inner) e_unchecked'
              _ -> return e_unchecked'
      es' <- mapM reshapeToOuter es_unchecked
      return $ map I.Var $ e_key : es'
    [] -> return []

  where postfix i s = baseString i ++ s
internaliseExp desc (E.Map lam arr _ _) = do
  arrs <- internaliseExpToVars "map_arr" arr
  map_lam <- internaliseMapLambda internaliseLambda lam $ map I.Var arrs
  w <- arraysSize 0 <$> mapM lookupType arrs
  -- A map is a Screma whose form is 'I.mapSOAC'.
  letTupExp' desc $ I.Op $ I.Screma w (I.mapSOAC map_lam) arrs
internaliseExp desc (E.Reduce comm lam ne arr loc) =
  internaliseScanOrReduce desc "reduce" mkReduce (lam, ne, arr, loc)
  where mkReduce w red_lam nes arrs =
          I.Screma w <$> I.reduceSOAC comm red_lam nes <*> pure arrs
internaliseExp desc (E.GenReduce hist op ne buckets img loc) = do
  ne' <- internaliseExp "gen_reduce_ne" ne
  hist' <- internaliseExpToVars "gen_reduce_hist" hist
  buckets' <- letExp "gen_reduce_buckets" . BasicOp . SubExp =<<
              internaliseExp1 "gen_reduce_buckets" buckets
  img' <- internaliseExpToVars "gen_reduce_img" img

  -- Coerce each neutral element to the row shape of its destination
  -- array.
  ne_shp <- forM (zip ne' hist') $ \(n, h) -> do
    rowtype <- I.stripArray 1 <$> lookupType h
    ensureShape asserting
      "Row shape of destination array does not match shape of neutral element"
      loc rowtype "gen_reduce_ne_right_shape" n
  ne_ts <- mapM I.subExpType ne_shp
  his_ts <- mapM lookupType hist'
  op' <- internaliseFoldLambda internaliseLambda op ne_ts his_ts

  -- An identity lambda over (bucket index, image row values) whose
  -- result is checked against int32 followed by the neutral-element
  -- types.
  bucket_param <- newParam "bucket_p" $ I.Prim int32
  img_params <- mapM (newParam "img_p" . rowType) =<< mapM lookupType img'
  let params = bucket_param : img_params
      rettype = I.Prim int32 : ne_ts
      body = mkBody mempty $ map (I.Var . paramName) params
  body' <- localScope (scopeOfLParams params) $
           ensureResultShape asserting
           "Row shape of value array does not match row shape of gen_reduce target"
           (srclocOf img) rettype body

  -- Outer sizes of the histogram and image arrays.
  w_hist <- arraysSize 0 <$> mapM lookupType hist'
  w_img <- arraysSize 0 <$> mapM lookupType img'

  -- Assert that the bucket array has the same length as the image
  -- arrays, then coerce its outer dimension to w_img.
  b_shape <- arrayShape <$> lookupType buckets'
  let b_w = shapeSize 0 b_shape
  cmp <- letSubExp "bucket_cmp" $ I.BasicOp $ I.CmpOp (I.CmpEq I.int32) b_w w_img
  c <- assertingOne $
    letExp "bucket_cert" $ I.BasicOp $
    I.Assert cmp "length of index and value array does not match" (loc, mempty)
  buckets'' <- certifying c $ letExp (baseString buckets') $
    I.BasicOp $ I.Reshape (reshapeOuter [DimCoercion w_img] 1 b_shape) buckets'

  letTupExp' desc $ I.Op $
    I.GenReduce w_img [GenReduceOp w_hist hist' ne_shp op'] (I.Lambda params body' rettype) $ buckets'' : img'
internaliseExp desc (E.Scan lam ne arr loc) =
  internaliseScanOrReduce desc "scan" mkScan (lam, ne, arr, loc)
  where mkScan w scan_lam nes arrs =
          I.Screma w <$> I.scanSOAC scan_lam nes <*> pure arrs
internaliseExp _ (E.Filter lam arr _) = do
  -- A filter is a one-way partition; both result groups are returned.
  arrs <- internaliseExpToVars "filter_input" arr
  part_lam <- internalisePartitionLambda internaliseLambda 1 lam $ map I.Var arrs
  (first_ses, second_ses) <- partitionWithSOACS 1 part_lam arrs
  return $ first_ses ++ second_ses
internaliseExp _ (E.Partition k lam arr _) = do
  arrs <- internaliseExpToVars "partition_input" arr
  part_lam <- internalisePartitionLambda internaliseLambda k lam $ map I.Var arrs
  (first_ses, second_ses) <- partitionWithSOACS k part_lam arrs
  return $ first_ses ++ second_ses
internaliseExp desc (E.Stream (E.MapLike o) lam arr _) = do
  arrs <- internaliseExpToVars "stream_input" arr
  lam' <- internaliseStreamMapLambda internaliseLambda lam $ map I.Var arrs
  w <- arraysSize 0 <$> mapM lookupType arrs
  -- A map-like stream is a Parallel stream with an empty reduction
  -- lambda and no accumulators.
  let form = I.Parallel o Commutative (I.Lambda [] (mkBody mempty []) []) []
  letTupExp' desc $ I.Op $ I.Stream w form lam' arrs

internaliseExp desc (E.Stream (E.RedLike o comm lam0) lam arr _) = do
  arrs <- internaliseExpToVars "stream_input" arr
  rowts <- mapM (fmap I.rowType . lookupType) arrs
  (lam_params, lam_body) <-
    internaliseStreamLambda internaliseLambda lam rowts
  let (chunk_param, _, lam_val_params) =
        partitionChunkedFoldParameters 0 lam_params

  -- Initial accumulator values: bind the chunk size to 0 and the value
  -- parameters to Scratch arrays of their declared types, then run
  -- (a renamed copy of) the chunk function body.
  letBindNames_ [I.paramName chunk_param] $
    I.BasicOp $ I.SubExp $ constant (0::Int32)
  forM_ lam_val_params $ \p ->
    letBindNames_ [I.paramName p] $
    I.BasicOp $ I.Scratch (I.elemType $ I.paramType p) $
    I.arrayDims $ I.paramType p
  accs <- bodyBind =<< renameBody lam_body

  acctps <- mapM I.subExpType accs
  outsz <- arraysSize 0 <$> mapM lookupType arrs
  let acc_arr_tps = [ I.arrayOf t (I.Shape [outsz]) NoUniqueness | t <- acctps ]
  lam0' <- internaliseFoldLambda internaliseLambda lam0 acctps acc_arr_tps
  let lam0_acc_params = fst $ splitAt (length accs) $ I.lambdaParams lam0'
  -- Fresh names for the accumulator parameters of the stream lambda.
  acc_params <- forM lam0_acc_params $ \p -> do
    name <- newVName $ baseString $ I.paramName p
    return p { I.paramName = name }

  -- New stream-lambda body: run the chunk function, then combine the
  -- incoming accumulators with its results using lam0.  Accumulators
  -- that lam0 consumes are copied first.
  body_with_lam0 <-
    ensureResultShape asserting "shape of result does not match shape of initial value"
    (srclocOf lam0) acctps <=< insertStmsM $ do
      lam_res <- bodyBind lam_body

      let consumed = consumedByLambda $ Alias.analyseLambda lam0'
          copyIfConsumed p (I.Var v)
            | I.paramName p `S.member` consumed =
                letSubExp "acc_copy" $ I.BasicOp $ I.Copy v
          copyIfConsumed _ x = return x

      accs' <- zipWithM copyIfConsumed (I.lambdaParams lam0') accs
      lam_res' <- ensureArgShapes asserting
                  "shape of chunk function result does not match shape of initial value"
                  (srclocOf lam) [] (map I.typeOf $ I.lambdaParams lam0') lam_res
      new_lam_res <- eLambda lam0' $ map eSubExp $ accs' ++ lam_res'
      return $ resultBody new_lam_res

  -- Parameter order: chunk size, accumulators, then value parameters.
  let form = I.Parallel o comm lam0' accs
      lam' = I.Lambda { lambdaParams = chunk_param : acc_params ++ lam_val_params
                      , lambdaBody = body_with_lam0
                      , lambdaReturnType = acctps }
  w <- arraysSize 0 <$> mapM lookupType arrs
  letTupExp' desc $ I.Op $ I.Stream w form lam' arrs
internaliseExp _ (E.VConstr0 c (Info t) loc) =
  -- An enum constructor becomes its Int8 index among the sorted
  -- constructor names.
  case t of
    Enum cs
      | Just i <- elemIndex c $ sort cs ->
          return [I.Constant $ I.IntValue $ intValue I.Int8 i]
      | otherwise ->
          fail $ "internaliseExp: invalid constructor: #" ++ nameToString c ++
          "\nfor enum at " ++ locStr loc ++ ": " ++ pretty t
    _ -> fail $ "internaliseExp: nonsensical type for enum at "
         ++ locStr loc ++ ": " ++ pretty t
internaliseExp desc (E.Match e cs _ loc) =
  case cs of
    [CasePat pCase eCase locCase] ->
      -- A lone case must still bind its pattern against the scrutinee.
      -- Otherwise any names the pattern introduces would be unbound in
      -- the case body, and the scrutinee would never be evaluated.
      -- This mirrors how the final case of a multi-case match is
      -- handled below.
      internalisePat desc [] pCase e eCase locCase (internaliseExp desc)
    (c:cs') -> do
        -- Build a chain of ifs: each case tests its condition, and the
        -- last case is the unconditional fallback.
        bFalse <- bFalseM
        letTupExp' desc =<< generateCaseIf desc e c bFalse
      where bFalseM = do
              eLast' <- internalisePat desc [] pLast e eLast locLast internaliseBody
              foldM (\bf c' -> eBody $ return $ generateCaseIf desc e c' bf) eLast' (reverse $ init cs')
            CasePat pLast eLast locLast = last cs'
    [] -> fail $ "internaliseExp: match with no cases at: " ++ locStr loc
internaliseExp _ (E.Literal v _) =
  return [I.Constant $ internalisePrimValue v]
internaliseExp _ (E.IntLit v (Info t) _) =
  -- An integer literal may have been inferred as a float type.
  case t of
    E.Prim (E.Signed it) -> intLit it
    E.Prim (E.Unsigned it) -> intLit it
    E.Prim (E.FloatType ft) -> return [I.Constant $ I.FloatValue $ floatValue ft v]
    _ -> fail $ "internaliseExp: nonsensical type for integer literal: " ++ pretty t
  where intLit it = return [I.Constant $ I.IntValue $ intValue it v]
internaliseExp _ (E.FloatLit v (Info t) _) =
  case t of
    E.Prim (E.FloatType ft) -> return [I.Constant $ I.FloatValue $ floatValue ft v]
    _ -> fail $ "internaliseExp: nonsensical type for float literal: " ++ pretty t
internaliseExp desc (E.If ce te fe _ _) = do
  let cond = BasicOp . SubExp <$> internaliseExp1 "cond" ce
  letTupExp' desc =<< eIf cond (internaliseBody te) (internaliseBody fe)
internaliseExp desc (E.BinOp op _ (xe,_) (ye,_) _ loc)
  | Just internalise <- isOverloadedFunction op [xe, ye] loc =
      internalise desc
internaliseExp desc (E.BinOp op (Info t) (xarg, Info xt) (yarg, Info yt) _ loc) =
  -- A non-overloaded binary operator is rewritten as a curried
  -- two-argument application.
  internaliseExp desc $ E.Apply partial yarg (Info $ E.diet yt) (Info t) loc
  where partial = E.Apply (E.Var op (Info t) loc) xarg (Info $ E.diet xt)
                  (Info $ foldFunType [E.fromStruct yt] t) loc
internaliseExp desc (E.Project k e (Info rt) _) = do
  -- The field occupies n flat values, starting after the values of
  -- all fields that sort before it.
  n <- internalisedTypeSize $ rt `setAliases` ()
  let preceding = case E.typeOf e `setAliases` () of
        Record fs -> map snd $ takeWhile ((/=k) . fst) $ sortFields fs
        t -> [t]
  offset <- sum <$> mapM internalisedTypeSize preceding
  take n . drop offset <$> internaliseExp desc e
-- Lambdas and sections are internal errors here; presumably the
-- earlier defunctionalisation pass (run by internaliseProg) has already
-- removed them.
internaliseExp _ e@E.Lambda{} =
  fail $ "internaliseExp: Unexpected lambda at " <> locStr (srclocOf e)
internaliseExp _ e@E.OpSection{} =
  fail $ "internaliseExp: Unexpected operator section at " <> locStr (srclocOf e)
internaliseExp _ e@E.OpSectionLeft{} =
  fail $ "internaliseExp: Unexpected left operator section at " <> locStr (srclocOf e)
internaliseExp _ e@E.OpSectionRight{} =
  fail $ "internaliseExp: Unexpected right operator section at " <> locStr (srclocOf e)
internaliseExp _ e@E.ProjectSection{} =
  fail $ "internaliseExp: Unexpected projection section at " <> locStr (srclocOf e)
internaliseExp _ e@E.IndexSection{} =
  fail $ "internaliseExp: Unexpected index section at " <> locStr (srclocOf e)
-- | Short-circuiting conjunction of two source boolean expressions,
-- written as @if l then r else false@.
andExp :: E.Exp -> E.Exp -> E.Exp
andExp l r = E.If l r false_lit (Info (E.Prim E.Bool)) noLoc
  where false_lit = E.Literal (E.BoolValue False) noLoc
-- | Build the source expression @l == r@, of type @bool@, used for
-- pattern-match conditions.  The operator is a synthetic @==@ name with
-- tag -1; presumably internaliseExp's BinOp case resolves it through
-- 'isOverloadedFunction'.
eqExp :: E.Exp -> E.Exp -> E.Exp
eqExp l r = E.BinOp eq (Info $ vacuousShapeAnnotations ft)
            (l, sType l) (r, sType r) (Info (E.Prim E.Bool)) noLoc
  where sType e = Info $ toStruct $ vacuousShapeAnnotations $ E.typeOf e
        arrow = Arrow S.empty Nothing
        -- A backquoted function without a fixity declaration is infixl 9,
        -- so the curried operator type must be parenthesised explicitly
        -- to obtain l -> (r -> bool) rather than (l -> r) -> bool.
        ft = E.typeOf l `arrow` (E.typeOf r `arrow` E.Prim E.Bool)
        eq = qualName $ VName "==" (-1)
-- | Build the boolean source expression that holds when the scrutinee
-- matches the pattern: the conjunction (via 'andExp', starting from
-- @true@) of all conditions 'generateCond'' produces for it.
generateCond :: E.Pattern -> E.Exp -> E.Exp
generateCond p e = foldr andExp true_lit conds
  where true_lit = E.Literal (E.BoolValue True) noLoc
        conds = [ cond e | (Just cond, _) <- generateCond' p ]
-- | For each leaf of a pattern, return an optional condition builder
-- (only literal patterns produce one) and a type.  Tuple and record
-- patterns wrap the conditions of their sub-patterns in field
-- projections.  Within a record pattern every entry carries the
-- reconstructed record type.
generateCond' :: E.Pattern -> [(Maybe (E.Exp -> E.Exp), CompType)]
-- Tuples are treated as records with fields "1", "2", ...
generateCond' (E.TuplePattern ps loc) = generateCond' (E.RecordPattern fs loc)
  where fs = zipWith (\i p' -> (nameFromString (show i), p')) ([1..] :: [Integer]) ps
generateCond' (E.RecordPattern fs _) = concatMap instCond holes
  where holes = map (\(n, p') -> (generateCond' p', n)) fs
        -- The type of each field is taken from the first entry of its
        -- sub-pattern; a field whose sub-pattern yields no entries is
        -- left out of the reconstructed record type.
        field ([],_) = Nothing
        field ((_, t):_, f) = Just (f, t)
        t' = Record $ M.fromList $ mapMaybe field holes
        projectHole _ (Nothing, _) = (Nothing, t')
        projectHole f (Just condHole, t) =
          (Just (\e' -> condHole $ Project f e' (Info t) noLoc), t')
        instCond (condHoles, f) = map (projectHole f) condHoles
generateCond' (E.PatternParens p' _) = generateCond' p'
generateCond' (E.Id _ (Info t) _) =
  [(Nothing, removeShapeAnnotations t)]
generateCond' (E.Wildcard (Info t) _) =
  [(Nothing, removeShapeAnnotations t)]
generateCond' (E.PatternAscription p' _ _) = generateCond' p'
-- A literal pattern compares the scrutinee against the literal with '=='.
generateCond' (E.PatternLit ePat (Info t) _) =
  [(Just (eqExp ePat), removeShapeAnnotations t)]
-- | Produce @if <pattern matches e> then <case body> else bFail@ for a
-- single match case.  The case body is internalised (binding the
-- pattern) before the condition expression is built.
generateCaseIf :: String -> E.Exp -> Case -> I.Body -> InternaliseM I.Exp
generateCaseIf desc e (CasePat p eCase loc) bFail = do
  eCase' <- internalisePat desc [] p e eCase loc internaliseBody
  let cond = BasicOp . SubExp <$> internaliseExp1 "cond" (generateCond p e)
  eIf cond (return eCase') (return bFail)
-- | Internalise @e@, bind its values to the names of pattern @p@
-- (registering any dimension constants first), and then run the
-- continuation on @body@ in that scope.
internalisePat :: String -> [TypeParamBase VName] -> E.Pattern -> E.Exp
               -> E.Exp -> SrcLoc -> (E.Exp -> InternaliseM a) -> InternaliseM a
internalisePat desc tparams p e body loc m = do
  ses <- internaliseExp desc e
  ts <- I.staticShapes <$> mapM I.subExpType ses
  stmPattern tparams p ts $ \cm pat_names match -> do
    forM_ cm $ uncurry $ internaliseDimConstant loc
    matched <- match loc ses
    zipWithM_ bindName pat_names matched
    m body
  where bindName v se = letBindNames_ [v] $ I.BasicOp $ I.SubExp se
-- | Internalise a source slice against the given dimensions.  Returns
-- the core dimension indexes and the certificate of a single
-- assertion that the conjunction of all per-index checks holds.
internaliseSlice :: SrcLoc
                 -> [SubExp]
                 -> [E.DimIndex]
                 -> InternaliseM ([I.DimIndex SubExp], Certificates)
internaliseSlice loc dims idxs = do
  (idxs', oks, parts) <- unzip3 <$> zipWithM internaliseDimIndex dims idxs
  certs <- assertingOne $ do
    all_ok <- letSubExp "index_ok" =<< foldBinOp I.LogAnd (constant True) oks
    letExp "index_certs" $ I.BasicOp $ I.Assert all_ok (boundsMsg parts) (loc, mempty)
  return (idxs', certs)
  -- Message shape: "Index [i, j] out of bounds for array of shape [n][m]."
  where boundsMsg parts =
          ErrorMsg $ ["Index ["] ++ intercalate [", "] parts ++
          ["] out of bounds for array of shape ["] ++
          intersperse "][" (map ErrorInt32 $ take (length idxs) dims) ++ ["]."]
-- | Internalise a single index or slice applied to a dimension of size @w@.
--
-- Returns:
--
-- * the core 'I.DimIndex';
-- * a boolean that is true when the index is acceptable (a fixed index
--   must satisfy @0 <= i < w@; a slice must be in bounds or empty);
-- * the parts used to render this index in an out-of-bounds message.
internaliseDimIndex :: SubExp -> E.DimIndex
                    -> InternaliseM (I.DimIndex SubExp, SubExp, [ErrorMsgPart SubExp])
internaliseDimIndex w (E.DimFix i) = do
  (i', _) <- internaliseDimExp "i" i
  let lowerBound = I.BasicOp $
                   I.CmpOp (I.CmpSle I.Int32) (I.constant (0 :: I.Int32)) i'
      upperBound = I.BasicOp $
                   I.CmpOp (I.CmpSlt I.Int32) i' w
  ok <- letSubExp "bounds_check" =<< eBinOp I.LogAnd (pure lowerBound) (pure upperBound)
  return (I.DimFix i', ok, [ErrorInt32 i'])
internaliseDimIndex w (E.DimSlice i j s) = do
  -- The stride defaults to 1.  A stride of sign -1 makes the slice run
  -- backwards, which changes the default start (w-1) and end (-1).
  -- NOTE(review): a zero stride is not rejected here, and it makes the
  -- 'SQuot' in divRounding below divide by zero.  Confirm whether this is
  -- caught elsewhere.
  s' <- maybe (return one) (fmap fst . internaliseDimExp "s") s
  s_sign <- letSubExp "s_sign" $ BasicOp $ I.UnOp (I.SSignum Int32) s'
  backwards <- letSubExp "backwards" $ I.BasicOp $ I.CmpOp (I.CmpEq int32) s_sign negone
  w_minus_1 <- letSubExp "w_minus_1" $ BasicOp $ I.BinOp (Sub Int32) w one
  let i_def = letSubExp "i_def" $ I.If backwards
              (resultBody [w_minus_1])
              (resultBody [zero]) $ ifCommon [I.Prim int32]
      j_def = letSubExp "j_def" $ I.If backwards
              (resultBody [negone])
              (resultBody [w]) $ ifCommon [I.Prim int32]
  i' <- maybe i_def (fmap fst . internaliseDimExp "i") i
  j' <- maybe j_def (fmap fst . internaliseDimExp "j") j
  j_m_i <- letSubExp "j_m_i" $ BasicOp $ I.BinOp (Sub Int32) j' i'
  -- Number of elements: n = (j - i + (s - signum s)) `quot` s.
  let divRounding x y =
        eBinOp (SQuot Int32) (eBinOp (Add Int32) x (eBinOp (Sub Int32) y (eSignum $ toExp s'))) y
  n <- letSubExp "n" =<< divRounding (toExp j_m_i) (toExp s')
  empty_slice <- letSubExp "empty_slice" $ I.BasicOp $ I.CmpOp (CmpEq int32) n zero
  -- Index of the last element read: i + (n-1)*s.
  m <- letSubExp "m" $ I.BasicOp $ I.BinOp (Sub Int32) n one
  m_t_s <- letSubExp "m_t_s" $ I.BasicOp $ I.BinOp (Mul Int32) m s'
  i_p_m_t_s <- letSubExp "i_p_m_t_s" $ I.BasicOp $ I.BinOp (Add Int32) i' m_t_s
  zero_leq_i_p_m_t_s <- letSubExp "zero_leq_i_p_m_t_s" $
                        I.BasicOp $ I.CmpOp (I.CmpSle Int32) zero i_p_m_t_s
  i_p_m_t_s_leq_w <- letSubExp "i_p_m_t_s_leq_w" $
                     I.BasicOp $ I.CmpOp (I.CmpSle Int32) i_p_m_t_s w
  i_p_m_t_s_lth_w <- letSubExp "i_p_m_t_s_lth_w" $
                     I.BasicOp $ I.CmpOp (I.CmpSlt Int32) i_p_m_t_s w
  -- Forwards: 0 <= i <= j and the last element read lies in [0, w).
  zero_lte_i <- letSubExp "zero_lte_i" $ I.BasicOp $ I.CmpOp (I.CmpSle Int32) zero i'
  i_lte_j <- letSubExp "i_lte_j" $ I.BasicOp $ I.CmpOp (I.CmpSle Int32) i' j'
  forwards_ok <- letSubExp "forwards_ok" =<<
                 foldBinOp I.LogAnd (constant True)
                 [zero_lte_i, i_lte_j, zero_leq_i_p_m_t_s, i_p_m_t_s_lth_w]
  -- Backwards: -1 <= j <= i.  The first element read is i itself, so it
  -- must be below w.  The last element read, i+(n-1)*s, must not be
  -- negative.
  negone_lte_j <- letSubExp "negone_lte_j" $ I.BasicOp $ I.CmpOp (I.CmpSle Int32) negone j'
  j_lte_i <- letSubExp "j_lte_i" $ I.BasicOp $ I.CmpOp (I.CmpSle Int32) j' i'
  i_lth_w <- letSubExp "i_lth_w" $ I.BasicOp $ I.CmpOp (I.CmpSlt Int32) i' w
  backwards_ok <- letSubExp "backwards_ok" =<<
                  foldBinOp I.LogAnd (constant True)
                  [negone_lte_j, j_lte_i, i_lth_w, zero_leq_i_p_m_t_s, i_p_m_t_s_leq_w]
  slice_ok <- letSubExp "slice_ok" $ I.If backwards
              (resultBody [backwards_ok])
              (resultBody [forwards_ok]) $
              ifCommon [I.Prim I.Bool]
  -- Empty slices are always acceptable.
  ok_or_empty <- letSubExp "ok_or_empty" $
                 I.BasicOp $ I.BinOp I.LogOr empty_slice slice_ok
  -- Render the slice as written: omitted start/end positions are shown
  -- as empty, and the colon is always present.
  let showIfGiven v = maybe "" (const $ ErrorInt32 v)
      parts = case (i, j, s) of
        (_, _, Just{}) ->
          [showIfGiven i' i, ":", showIfGiven j' j, ":", ErrorInt32 s']
        (_, Just{}, Nothing) ->
          [showIfGiven i' i, ":", ErrorInt32 j']
        (_, Nothing, Nothing) ->
          [showIfGiven i' i, ":"]
  return (I.DimSlice i' n s', ok_or_empty, parts)
  where zero = constant (0::Int32)
        negone = constant (-1::Int32)
        one = constant (1::Int32)
-- | Shared internalisation for scan- and reduce-like SOACs.
--
-- The steps are:
--
-- * internalise the array argument @arr@ and the neutral element @ne@;
-- * coerce each neutral element to the row shape of its array with
--   'ensureShape', asserting on mismatch;
-- * internalise the operator @lam@;
-- * let @f@ build the SOAC from the outer width, the lambda, the neutral
--   elements and the arrays.
--
-- The results are bound under the name hint @desc@.
internaliseScanOrReduce :: String -> String
                        -> (SubExp -> I.Lambda -> [SubExp] -> [VName] -> InternaliseM (SOAC SOACS))
                        -> (E.Exp, E.Exp, E.Exp, SrcLoc)
                        -> InternaliseM [SubExp]
internaliseScanOrReduce desc what f (lam, ne, arr, loc) = do
  arrs <- internaliseExpToVars (what++"_arr") arr
  nes <- internaliseExp (what++"_ne") ne
  -- Look the array types up once.  They supply the row types for the
  -- neutral elements, the lambda's input types and the outer width.
  arrts <- mapM lookupType arrs
  nes' <- forM (zip nes arrts) $ \(ne', arrt) ->
    ensureShape asserting
      "Row shape of input array does not match shape of neutral element"
      loc (I.stripArray 1 arrt) (what++"_ne_right_shape") ne'
  nests <- mapM I.subExpType nes'
  lam' <- internaliseFoldLambda internaliseLambda lam nests arrts
  let w = arraysSize 0 arrts
  letTupExp' desc . I.Op =<< f w lam' nes' arrs
-- | Internalise an expression that must produce exactly one core value.
-- Fails if it produces zero or several values.
internaliseExp1 :: String -> E.Exp -> InternaliseM I.SubExp
internaliseExp1 desc e =
  internaliseExp desc e >>= \ses ->
    case ses of
      [se] -> return se
      _    -> fail "Internalise.internaliseExp1: was passed not just a single subexpression"
-- | Internalise an integer-typed expression for use as an index or size.
--
-- The value is converted to 'Int32': signed inputs via 'asIntS' and
-- unsigned inputs via 'asIntZ'.  The original integer type is returned
-- alongside the value.  Non-integer expressions are rejected.
internaliseDimExp :: String -> E.Exp -> InternaliseM (I.SubExp, IntType)
internaliseDimExp s e = do
  e' <- internaliseExp1 s e
  case E.typeOf e of
    E.Prim (Signed it)   -> withType it <$> asIntS Int32 e'
    E.Prim (Unsigned it) -> withType it <$> asIntZ Int32 e'
    _                    -> fail "internaliseDimExp: bad type"
  where withType it se = (se, it)
-- | Internalise an expression so that every result is a variable.
-- Constant results are first bound to fresh names under the hint @desc@.
internaliseExpToVars :: String -> E.Exp -> InternaliseM [I.VName]
internaliseExpToVars desc e = do
  ses <- internaliseExp desc e
  forM ses $ \se -> case se of
    I.Var v -> return v
    _       -> letExp desc $ I.BasicOp $ I.SubExp se
-- | Internalise @e@ to variables and apply the basic operation built by
-- @op@ to each of them.  Every result is bound under the name hint @s@.
internaliseOperation :: String
                     -> E.Exp
                     -> (I.VName -> InternaliseM I.BasicOp)
                     -> InternaliseM [I.SubExp]
internaliseOperation s e op = do
  vs <- internaliseExpToVars s e
  exps <- forM vs $ \v -> I.BasicOp <$> op v
  letSubExps s exps
-- | Internalise a source-language binary operator applied to two
-- already-internalised scalar operands.
--
-- The core operation is chosen from the operator and the left operand's
-- primitive type @t1@.  @t2@ is used only in the error message for
-- unsupported combinations.
internaliseBinOp :: String
                 -> E.BinOp
                 -> I.SubExp -> I.SubExp
                 -> E.PrimType
                 -> E.PrimType
                 -> InternaliseM [I.SubExp]
internaliseBinOp desc op x y t1 t2 =
  case (op, t1) of
    -- Arithmetic.
    (E.Plus,   E.Signed t)    -> bin $ I.Add t
    (E.Plus,   E.Unsigned t)  -> bin $ I.Add t
    (E.Plus,   E.FloatType t) -> bin $ I.FAdd t
    (E.Minus,  E.Signed t)    -> bin $ I.Sub t
    (E.Minus,  E.Unsigned t)  -> bin $ I.Sub t
    (E.Minus,  E.FloatType t) -> bin $ I.FSub t
    (E.Times,  E.Signed t)    -> bin $ I.Mul t
    (E.Times,  E.Unsigned t)  -> bin $ I.Mul t
    (E.Times,  E.FloatType t) -> bin $ I.FMul t
    (E.Divide, E.Signed t)    -> bin $ I.SDiv t
    (E.Divide, E.Unsigned t)  -> bin $ I.UDiv t
    (E.Divide, E.FloatType t) -> bin $ I.FDiv t
    (E.Pow,    E.FloatType t) -> bin $ I.FPow t
    (E.Pow,    E.Signed t)    -> bin $ I.Pow t
    (E.Pow,    E.Unsigned t)  -> bin $ I.Pow t
    (E.Mod,    E.Signed t)    -> bin $ I.SMod t
    (E.Mod,    E.Unsigned t)  -> bin $ I.UMod t
    (E.Quot,   E.Signed t)    -> bin $ I.SQuot t
    (E.Quot,   E.Unsigned t)  -> bin $ I.UDiv t
    (E.Rem,    E.Signed t)    -> bin $ I.SRem t
    (E.Rem,    E.Unsigned t)  -> bin $ I.UMod t
    -- Shifts and bitwise operations.
    (E.ShiftR, E.Signed t)    -> bin $ I.AShr t
    (E.ShiftR, E.Unsigned t)  -> bin $ I.LShr t
    (E.ShiftL, E.Signed t)    -> bin $ I.Shl t
    (E.ShiftL, E.Unsigned t)  -> bin $ I.Shl t
    (E.Band,   E.Signed t)    -> bin $ I.And t
    (E.Band,   E.Unsigned t)  -> bin $ I.And t
    (E.Xor,    E.Signed t)    -> bin $ I.Xor t
    (E.Xor,    E.Unsigned t)  -> bin $ I.Xor t
    (E.Bor,    E.Signed t)    -> bin $ I.Or t
    (E.Bor,    E.Unsigned t)  -> bin $ I.Or t
    -- Equality on any primitive type; inequality negates it.
    (E.Equal, _) ->
      simpleCmpOp desc (eqAt t1) x y
    (E.NotEqual, _) -> do
      eq <- letSubExp (desc++"true") $ I.BasicOp $ I.CmpOp (eqAt t1) x y
      fmap pure $ letSubExp desc $ I.BasicOp $ I.UnOp I.Not eq
    -- Orderings.  '>' and '>=' are '<' and '<=' with swapped operands.
    (E.Less,    E.Signed t)    -> cmp $ I.CmpSlt t
    (E.Less,    E.Unsigned t)  -> cmp $ I.CmpUlt t
    (E.Leq,     E.Signed t)    -> cmp $ I.CmpSle t
    (E.Leq,     E.Unsigned t)  -> cmp $ I.CmpUle t
    (E.Greater, E.Signed t)    -> cmpSwapped $ I.CmpSlt t
    (E.Greater, E.Unsigned t)  -> cmpSwapped $ I.CmpUlt t
    (E.Geq,     E.Signed t)    -> cmpSwapped $ I.CmpSle t
    (E.Geq,     E.Unsigned t)  -> cmpSwapped $ I.CmpUle t
    (E.Less,    E.FloatType t) -> cmp $ I.FCmpLt t
    (E.Leq,     E.FloatType t) -> cmp $ I.FCmpLe t
    (E.Greater, E.FloatType t) -> cmpSwapped $ I.FCmpLt t
    (E.Geq,     E.FloatType t) -> cmpSwapped $ I.FCmpLe t
    (E.Less,    E.Bool)        -> cmp I.CmpLlt
    (E.Leq,     E.Bool)        -> cmp I.CmpLle
    (E.Greater, E.Bool)        -> cmpSwapped I.CmpLlt
    (E.Geq,     E.Bool)        -> cmpSwapped I.CmpLle
    _ ->
      fail $ "Invalid binary operator " ++ pretty op ++
      " with operand types " ++ pretty t1 ++ ", " ++ pretty t2
  where bin bop = simpleBinOp desc bop x y
        cmp c = simpleCmpOp desc c x y
        cmpSwapped c = simpleCmpOp desc c y x
        eqAt t = I.CmpEq $ internalisePrimType t
-- | Bind the result of applying the core binary operator @bop@ to @x@ and
-- @y@ under the name hint @desc@.
simpleBinOp :: String
            -> I.BinOp
            -> I.SubExp -> I.SubExp
            -> InternaliseM [I.SubExp]
simpleBinOp desc bop x y =
  letTupExp' desc (I.BasicOp (I.BinOp bop x y))
-- | Bind the result of applying the core comparison @op@ to @x@ and @y@
-- under the name hint @desc@.
simpleCmpOp :: String
            -> I.CmpOp
            -> I.SubExp -> I.SubExp
            -> InternaliseM [I.SubExp]
simpleCmpOp desc op x y =
  letTupExp' desc (I.BasicOp (I.CmpOp op x y))
-- | Decompose a possibly partially applied function expression.
--
-- Returns:
--
-- * the name of the function being called;
-- * the arguments supplied so far, in application order;
-- * the types of the parameters that remain, according to the type of
--   the outermost application.
--
-- Anything other than a variable or an application is rejected.
findFuncall :: E.Exp -> InternaliseM (E.QualName VName, [E.Exp], [E.StructType])
findFuncall e =
  case e of
    E.Var fname (Info t) _ ->
      return (fname, [], remainingOf t)
    E.Apply f arg _ (Info t) _ -> do
      (fname, args, _) <- findFuncall f
      return (fname, args ++ [arg], remainingOf t)
    _ ->
      fail $ "Invalid function expression in application: " ++ pretty e
  where remainingOf t = map E.toStruct $ fst $ unfoldFunType t
-- | Internalise a functional argument into core lambda parameters, body and
-- return types.  @rowtypes@ is passed on to 'bindingLambdaParams';
-- presumably these are the row types of the arrays the function is
-- applied to.
--
-- How each form is handled:
--
-- * Parenthesised expressions are unwrapped.
-- * Explicit lambdas are translated directly.
-- * Operator sections are reported as unexpected.
-- * Any other expression is treated as a partially applied function and
--   eta-expanded: one fresh parameter is made per remaining parameter type
--   reported by 'findFuncall', and the resulting lambda is internalised.
internaliseLambda :: InternaliseLambda
internaliseLambda (E.Parens e _) rowtypes =
  internaliseLambda e rowtypes
internaliseLambda (E.Lambda tparams params body _ (Info (_, rettype)) loc) rowtypes =
  bindingLambdaParams tparams params rowtypes $ \pcm params' -> do
    (rettype', rcm) <- internaliseReturnType rettype
    body' <- internaliseBody body
    mapM_ (uncurry (internaliseDimConstant loc)) $ pcm<>rcm
    return (params', body', map I.fromDecl rettype')
internaliseLambda E.OpSection{} _ = fail "internaliseLambda: unexpected OpSection"
internaliseLambda E.OpSectionLeft{} _ = fail "internaliseLambda: unexpected OpSectionLeft"
internaliseLambda E.OpSectionRight{} _ = fail "internaliseLambda: unexpected OpSectionRight"
internaliseLambda e rowtypes = do
  (_, _, remaining_params_ts) <- findFuncall e
  (params, param_args) <- unzip <$> mapM freshParam remaining_params_ts
  let rettype = E.typeOf e
      applyTo f arg = E.Apply f arg (Info E.Observe)
                        (Info $ E.vacuousShapeAnnotations rettype) loc
      body = foldl applyTo e param_args
      rettype' = E.vacuousShapeAnnotations $ rettype `E.setAliases` ()
  internaliseLambda (E.Lambda [] params body Nothing (Info (mempty, rettype')) loc) rowtypes
  where loc = srclocOf e
        -- A fresh parameter pattern and the variable that refers to it.
        freshParam et = do
          name <- newVName "not_curried"
          return (E.Id name (Info $ E.vacuousShapeAnnotations $ et `setAliases` mempty) loc,
                  E.Var (E.qualName name) (Info (et `setAliases` mempty)) loc)
-- | Bind the 'Int32' variable @name@ to the result of calling the
-- argument-less function @fname@.  The call is always emitted as 'Safe'.
internaliseDimConstant :: SrcLoc -> Name -> VName -> InternaliseM ()
internaliseDimConstant loc fname name =
  letBind_ pat $ I.Apply fname [] [I.Prim I.int32] (Safe, loc, mempty)
  where pat = basicPattern [] [I.Ident name $ I.Prim I.int32]
-- | Recognise calls to intrinsic functions that are compiled directly into
-- core operations rather than into function calls.
--
-- Returns 'Nothing' unless both of these hold:
--
-- * the tag of @qname@ is at most 'maxIntrinsicTag';
-- * its base name, together with the number and form of @args@, matches
--   one of the @handle@ clauses below.
--
-- On success the result is an action that takes a name hint, emits the
-- core code and returns the resulting values.
--
-- The @handle@ clauses are tried in order, and several rely on a failing
-- guard falling through to later clauses, so their order matters.
isOverloadedFunction :: E.QualName VName -> [E.Exp] -> SrcLoc
                     -> Maybe (String -> InternaliseM [SubExp])
isOverloadedFunction qname args loc = do
  guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag
  handle args $ baseString $ qualLeaf qname
  where
    -- Integer conversions and unary numeric helpers.
    handle [x] "sign_i8"  = Just $ toSigned I.Int8 x
    handle [x] "sign_i16" = Just $ toSigned I.Int16 x
    handle [x] "sign_i32" = Just $ toSigned I.Int32 x
    handle [x] "sign_i64" = Just $ toSigned I.Int64 x

    handle [x] "unsign_i8"  = Just $ toUnsigned I.Int8 x
    handle [x] "unsign_i16" = Just $ toUnsigned I.Int16 x
    handle [x] "unsign_i32" = Just $ toUnsigned I.Int32 x
    handle [x] "unsign_i64" = Just $ toUnsigned I.Int64 x

    handle [x] "sgn" = Just $ signumF x
    handle [x] "abs" = Just $ absF x
    handle [x] "!"   = Just $ notF x
    handle [x] "~"   = Just $ complementF x

    handle [x] "opaque" = Just $ \desc ->
      mapM (letSubExp desc . BasicOp . Opaque) =<< internaliseExp "opaque_arg" x

    -- Core operators, looked up by their pretty-printed names.
    handle [x] s
      | Just unop <- find ((==s) . pretty) allUnOps = Just $ \desc -> do
          x' <- internaliseExp1 "x" x
          fmap pure $ letSubExp desc $ I.BasicOp $ I.UnOp unop x'

    handle [x,y] s
      | Just bop <- find ((==s) . pretty) allBinOps = Just $ \desc -> do
          x' <- internaliseExp1 "x" x
          y' <- internaliseExp1 "y" y
          fmap pure $ letSubExp desc $ I.BasicOp $ I.BinOp bop x' y'
      | Just cmp <- find ((==s) . pretty) allCmpOps = Just $ \desc -> do
          x' <- internaliseExp1 "x" x
          y' <- internaliseExp1 "y" y
          fmap pure $ letSubExp desc $ I.BasicOp $ I.CmpOp cmp x' y'

    handle [x] s
      | Just conv <- find ((==s) . pretty) allConvOps = Just $ \desc -> do
          x' <- internaliseExp1 "x" x
          fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x'

    -- Short-circuiting boolean operators become conditionals.
    handle [x,y] "&&" = Just $ \desc ->
      internaliseExp desc $
      E.If x y (E.Literal (E.BoolValue False) noLoc) (Info (E.Prim E.Bool)) noLoc
    handle [x,y] "||" = Just $ \desc ->
      internaliseExp desc $
      E.If x (E.Literal (E.BoolValue True) noLoc) y (Info (E.Prim E.Bool)) noLoc

    -- Structural (in)equality, including on arrays.  Each pair of
    -- internalised values is compared and the results are ANDed together.
    handle [xe,ye] op
      | Just cmp_f <- isEqlOp op = Just $ \desc -> do
          xe' <- internaliseExp "x" xe
          ye' <- internaliseExp "y" ye
          rs <- zipWithM (doComparison desc) xe' ye'
          cmp_f desc =<< letSubExp "eq" =<< foldBinOp I.LogAnd (constant True) rs
      where isEqlOp "!=" = Just $ \desc eq ->
              letTupExp' desc $ I.BasicOp $ I.UnOp I.Not eq
            isEqlOp "==" = Just $ \_ eq ->
              return [eq]
            isEqlOp _ = Nothing

            doComparison desc x y = do
              x_t <- I.subExpType x
              y_t <- I.subExpType y
              case x_t of
                I.Prim t -> letSubExp desc $ I.BasicOp $ I.CmpOp (I.CmpEq t) x y
                _ -> do
                  let x_dims = I.arrayDims x_t
                      y_dims = I.arrayDims y_t
                  dims_match <- forM (zip x_dims y_dims) $ \(x_dim, y_dim) ->
                    letSubExp "dim_eq" $ I.BasicOp $ I.CmpOp (I.CmpEq int32) x_dim y_dim
                  shapes_match <- letSubExp "shapes_match" =<<
                                  foldBinOp I.LogAnd (constant True) dims_match
                  -- This body runs only when the shapes match (see the 'I.If'
                  -- below).  That is why y may be reshaped using x's element
                  -- count.
                  compare_elems_body <- runBodyBinder $ do
                    x_num_elems <- letSubExp "x_num_elems" =<<
                                   foldBinOp (I.Mul Int32) (constant (1::Int32)) x_dims
                    x' <- letExp "x" $ I.BasicOp $ I.SubExp x
                    y' <- letExp "y" $ I.BasicOp $ I.SubExp y
                    x_flat <- letExp "x_flat" $ I.BasicOp $ I.Reshape [I.DimNew x_num_elems] x'
                    y_flat <- letExp "y_flat" $ I.BasicOp $ I.Reshape [I.DimNew x_num_elems] y'
                    cmp_lam <- cmpOpLambda (I.CmpEq (elemType x_t)) (elemType x_t)
                    cmps <- letExp "cmps" $ I.Op $
                            I.Screma x_num_elems (I.mapSOAC cmp_lam) [x_flat, y_flat]
                    and_lam <- binOpLambda I.LogAnd I.Bool
                    reduce <- I.reduceSOAC Commutative and_lam [constant True]
                    all_equal <- letSubExp "all_equal" $ I.Op $ I.Screma x_num_elems reduce [cmps]
                    return $ resultBody [all_equal]
                  letSubExp "arrays_equal" $
                    I.If shapes_match compare_elems_body (resultBody [constant False]) $
                    ifCommon [I.Prim I.Bool]

    -- Any other source-language binary operator on primitive operands.
    handle [x,y] name
      | Just bop <- find ((name==) . pretty) [minBound..maxBound::E.BinOp] =
          Just $ \desc -> do
            x' <- internaliseExp1 "x" x
            y' <- internaliseExp1 "y" y
            case (E.typeOf x, E.typeOf y) of
              (E.Prim t1, E.Prim t2) ->
                internaliseBinOp desc bop x' y' t1 t2
              _ -> fail "Futhark.Internalise.internaliseExp: non-primitive type in BinOp."

    handle [E.TupLit [a, si, v] _] "scatter" = Just $ scatterF a si v

    -- Returns 'Nothing' (no fall-through) unless every array element is a
    -- signed integer literal.
    handle [E.TupLit [e, E.ArrayLit vs _ _] _] "cmp_threshold" = do
        s <- mapM isCharLit vs
        Just $ \desc -> do
          x <- internaliseExp1 "threshold_x" e
          pure <$> letSubExp desc (Op $ CmpThreshold x s)
      where isCharLit (Literal (SignedValue iv) _) = Just $ chr $ fromIntegral $ intToInt64 iv
            isCharLit _ = Nothing

    handle [E.TupLit [n, m, arr] _] "unflatten" = Just $ \desc -> do
      arrs <- internaliseExpToVars "unflatten_arr" arr
      n' <- internaliseExp1 "n" n
      m' <- internaliseExp1 "m" m
      old_dim <- I.arraysSize 0 <$> mapM lookupType arrs
      dim_ok <- assertingOne $ letExp "dim_ok" =<<
                eAssert (eCmpOp (I.CmpEq I.int32)
                         (eBinOp (I.Mul Int32) (eSubExp n') (eSubExp m'))
                         (eSubExp old_dim))
                "new shape has different number of elements than old shape" loc
      certifying dim_ok $ forM arrs $ \arr' -> do
        arr_t <- lookupType arr'
        letSubExp desc $ I.BasicOp $
          I.Reshape (reshapeOuter [DimNew n', DimNew m'] 1 $ arrayShape arr_t) arr'

    handle [arr] "flatten" = Just $ \desc -> do
      arrs <- internaliseExpToVars "flatten_arr" arr
      forM arrs $ \arr' -> do
        arr_t <- lookupType arr'
        let n = arraySize 0 arr_t
            m = arraySize 1 arr_t
        k <- letSubExp "flat_dim" $ I.BasicOp $ I.BinOp (Mul Int32) n m
        letSubExp desc $ I.BasicOp $
          I.Reshape (reshapeOuter [DimNew k] 2 $ arrayShape arr_t) arr'

    handle [TupLit [x, y] _] "concat" = Just $ \desc -> do
      xs <- internaliseExpToVars "concat_x" x
      ys <- internaliseExpToVars "concat_y" y
      outer_size <- arraysSize 0 <$> mapM lookupType xs
      let sumdims xsize ysize = letSubExp "conc_tmp" $ I.BasicOp $
                                I.BinOp (I.Add I.Int32) xsize ysize
      ressize <- foldM sumdims outer_size =<<
                 mapM (fmap (arraysSize 0) . mapM lookupType) [ys]

      -- Concatenate one pair of arrays.  This asserts that their row
      -- shapes agree unless either row shape has a zero dimension.  Both
      -- arrays are coerced to the per-dimension maximum of the inner
      -- dimensions.
      let conc xarr yarr = do
            xt <- lookupType xarr
            yt <- lookupType yarr
            let matches n m =
                  letSubExp "match" $
                  I.BasicOp $ I.CmpOp (I.CmpEq I.int32) n m
                emptyRow arr_t =
                  letSubExp "empty_row" =<<
                  foldBinOp I.LogOr (constant False) =<<
                  mapM (matches (intConst Int32 0)) (arrayDims $ rowType arr_t)
            all_match <- letSubExp "all_match" =<<
                         foldBinOp I.LogAnd (constant True) =<<
                         zipWithM matches
                         (arrayDims (rowType xt)) (arrayDims (rowType yt))
            xarr_empty <- emptyRow xt
            yarr_empty <- emptyRow yt
            either_empty <- letSubExp "either_empty" $
                            I.BasicOp $ I.BinOp I.LogOr xarr_empty yarr_empty
            matchcs <- assertingOne $ letExp "concat_ok" =<<
                       eAssert (pure $ I.BasicOp $ I.BinOp I.LogOr either_empty all_match)
                       "row sizes do not match when concatenating" loc
            let updims (j, xd, yd)
                  | j == 0 =
                      return (xd, yd)
                  | otherwise = do
                      d <- letSubExp "dim" $ I.BasicOp $ I.BinOp (SMax Int32) xd yd
                      return (d, d)
            (xdims, ydims) <- unzip <$>
                              mapM updims (zip3 [(0::Int)..] (I.arrayDims xt) (I.arrayDims yt))
            xarr' <- certifying matchcs $ letExp "concat_x_reshaped" $
                     shapeCoerce xdims xarr
            yarr' <- certifying matchcs $ letExp "concat_y_reshaped" $
                     shapeCoerce ydims yarr
            return $ I.BasicOp $ I.Concat 0 xarr' [yarr'] ressize
      letSubExps desc =<< zipWithM conc xs ys

    handle [TupLit [offset, e] _] "rotate" = Just $ \desc -> do
      offset' <- internaliseExp1 "rotation_offset" offset
      internaliseOperation desc e $ \v -> do
        r <- I.arrayRank <$> lookupType v
        let zero = intConst Int32 0
            offsets = offset' : replicate (r-1) zero
        return $ I.Rotate offsets v

    handle [e] "transpose" = Just $ \desc ->
      internaliseOperation desc e $ \v -> do
        r <- I.arrayRank <$> lookupType v
        return $ I.Rearrange ([1,0] ++ [2..r-1]) v

    handle [TupLit [x, y] _] "zip" = Just $ \desc ->
      (++) <$> internaliseExp (desc ++ "_zip_x") x
           <*> internaliseExp (desc ++ "_zip_y") y

    -- These simply internalise their argument.
    handle [x] "unzip" = Just $ flip internaliseExp x
    handle [x] "trace" = Just $ flip internaliseExp x
    handle [x] "break" = Just $ flip internaliseExp x

    handle _ _ = Nothing

    toSigned int_to e desc = do
      e' <- internaliseExp1 "trunc_arg" e
      case E.typeOf e of
        E.Prim E.Bool ->
          letTupExp' desc $ I.If e' (resultBody [intConst int_to 1])
                                    (resultBody [intConst int_to 0]) $
                                    ifCommon [I.Prim $ I.IntType int_to]
        E.Prim (E.Signed int_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.SExt int_from int_to) e'
        E.Prim (E.Unsigned int_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e'
        E.Prim (E.FloatType float_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.FPToSI float_from int_to) e'
        _ -> fail "Futhark.Internalise.internaliseExp: non-numeric type in ToSigned"

    toUnsigned int_to e desc = do
      e' <- internaliseExp1 "trunc_arg" e
      case E.typeOf e of
        E.Prim E.Bool ->
          letTupExp' desc $ I.If e' (resultBody [intConst int_to 1])
                                    (resultBody [intConst int_to 0]) $
                                    ifCommon [I.Prim $ I.IntType int_to]
        E.Prim (E.Signed int_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e'
        E.Prim (E.Unsigned int_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e'
        E.Prim (E.FloatType float_from) ->
          letTupExp' desc $ I.BasicOp $ I.ConvOp (I.FPToUI float_from int_to) e'
        _ -> fail "Futhark.Internalise.internaliseExp: non-numeric type in ToUnsigned"

    signumF e desc = do
      e' <- internaliseExp1 "signum_arg" e
      case E.typeOf e of
        E.Prim (E.Signed t) ->
          letTupExp' desc $ I.BasicOp $ I.UnOp (I.SSignum t) e'
        E.Prim (E.Unsigned t) ->
          letTupExp' desc $ I.BasicOp $ I.UnOp (I.USignum t) e'
        _ -> fail "Futhark.Internalise.internaliseExp: non-integer type in Signum"

    absF e desc = do
      e' <- internaliseExp1 "abs_arg" e
      case E.typeOf e of
        E.Prim (E.Signed t) ->
          letTupExp' desc $ I.BasicOp $ I.UnOp (I.Abs t) e'
        E.Prim (E.Unsigned _) ->
          return [e']
        E.Prim (E.FloatType t) ->
          letTupExp' desc $ I.BasicOp $ I.UnOp (I.FAbs t) e'
        _ -> fail "Futhark.Internalise.internaliseExp: non-numeric type in Abs"

    notF e desc = do
      e' <- internaliseExp1 "not_arg" e
      letTupExp' desc $ I.BasicOp $ I.UnOp I.Not e'

    complementF e desc = do
      e' <- internaliseExp1 "complement_arg" e
      et <- subExpType e'
      case et of
        I.Prim (I.IntType t) ->
          letTupExp' desc $ I.BasicOp $ I.UnOp (I.Complement t) e'
        _ ->
          fail "Futhark.Internalise.internaliseExp: non-integer type in Complement"

    scatterF a si v desc = do
      si' <- letExp "write_si" . BasicOp . SubExp =<< internaliseExp1 "write_arg_i" si
      svs <- internaliseExpToVars "write_arg_v" v
      sas <- internaliseExpToVars "write_arg_a" a

      si_shape <- I.arrayShape <$> lookupType si'
      let si_w = shapeSize 0 si_shape
      sv_ts <- mapM lookupType svs

      -- Coerce each value array to the length of the index array, after
      -- asserting that the two lengths agree.
      svs' <- forM (zip svs sv_ts) $ \(sv,sv_t) -> do
        let sv_shape = I.arrayShape sv_t
            sv_w = arraySize 0 sv_t
        cmp <- letSubExp "write_cmp" $ I.BasicOp $
               I.CmpOp (I.CmpEq I.int32) si_w sv_w
        c <- assertingOne $
             letExp "write_cert" $ I.BasicOp $
             I.Assert cmp "length of index and value array does not match" (loc, mempty)
        certifying c $ letExp (baseString sv ++ "_write_sv") $
          I.BasicOp $ I.Reshape (reshapeOuter [DimCoercion si_w] 1 sv_shape) sv

      indexType <- rowType <$> lookupType si'
      indexName <- newVName "write_index"
      valueNames <- replicateM (length sv_ts) $ newVName "write_value"

      sa_ts <- mapM lookupType sas
      let bodyTypes = replicate (length sv_ts) indexType ++ map rowType sa_ts
          paramTypes = indexType : map rowType sv_ts
          bodyNames = indexName : valueNames
          bodyParams = zipWith I.Param bodyNames paramTypes

      -- The lambda returns the index once per value array, followed by the
      -- values themselves.
      body <- localScope (scopeOfLParams bodyParams) $ insertStmsM $ do
        let outs = replicate (length valueNames) indexName ++ valueNames
        results <- forM outs $ \name ->
          letSubExp "write_res" $ I.BasicOp $ I.SubExp $ I.Var name
        ensureResultShape asserting "scatter value has wrong size" loc
          bodyTypes $ resultBody results

      let lam = I.Lambda { I.lambdaParams = bodyParams
                         , I.lambdaReturnType = bodyTypes
                         , I.lambdaBody = body
                         }
          sivs = si' : svs'

      let sa_ws = map (arraySize 0) sa_ts
      letTupExp' desc $ I.Op $ I.Scatter si_w lam sivs $ zip3 sa_ws (repeat 1) sas
-- | If @name@ is a known function that is not already bound in the current
-- scope, call it with its constant arguments and return the resulting
-- variables.  Otherwise return 'Nothing'.
--
-- Fails if the function's return type cannot be instantiated for those
-- arguments.
lookupConstant :: SrcLoc -> VName -> InternaliseM (Maybe [SubExp])
lookupConstant loc name = do
  is_const <- lookupFunction' name
  scope <- askScope
  case is_const of
    Just (fname, constparams, _, _, _, _, mk_rettype)
      | name `M.notMember` scope -> callConstant fname constparams mk_rettype
    _ -> return Nothing
  where callConstant fname constparams mk_rettype = do
          (constargs, const_ds, const_ts) <- unzip3 <$> constFunctionArgs loc constparams
          safety <- askSafety
          case mk_rettype $ zip constargs $ map I.fromDecl const_ts of
            Nothing -> fail $ "lookupConstant: " ++
                       unwords (pretty name : zipWith (curry pretty) constargs const_ts) ++
                       " failed"
            Just rettype ->
              fmap (Just . map I.Var) $ letTupExp (baseString name) $
              I.Apply fname (zip constargs const_ds) rettype (safety, loc, mempty)
-- | Produce the constant arguments of a function.  For each
-- (function, parameter) pair, emit a call to the argument-less function
-- and pass the 'Int32' result as an 'I.Observe'd argument.
constFunctionArgs :: SrcLoc -> ConstParams -> InternaliseM [(SubExp, I.Diet, I.DeclType)]
constFunctionArgs loc constparams =
  forM constparams $ \(fname, name) -> do
    safety <- askSafety
    se <- letSubExp (baseString name ++ "_arg") $
          I.Apply fname [] [I.Prim I.int32] (safety, loc, [])
    return (se, I.Observe, I.Prim I.int32)
-- | Emit a call to @fname@ on the already-internalised arguments @args@.
--
-- The full argument list is, in order:
--
-- * the constant arguments;
-- * the closure variables;
-- * the shape arguments derived from @args@;
-- * @args@ themselves.
--
-- These are coerced to the parameter shapes with 'ensureArgShapes',
-- asserting on mismatch.  Returns the call's results and their types.
-- Fails if the return type cannot be instantiated for the arguments.
funcall :: String -> QualName VName -> [SubExp] -> SrcLoc
        -> InternaliseM ([SubExp], [I.ExtType])
funcall desc (QualName _ fname) args loc = do
  (fname', constparams, closure, shapes, value_paramts, fun_params, rettype_fun) <-
    lookupFunction fname
  (constargs, const_ds, _) <- unzip3 <$> constFunctionArgs loc constparams
  argts <- mapM subExpType args
  closure_ts <- mapM lookupType closure
  let shapeargs = argShapes shapes value_paramts argts
      asInt32 = const $ I.Prim int32
      n_implicit = length closure + length shapeargs
      diets = concat [ const_ds
                     , replicate n_implicit I.Observe
                     , map I.diet value_paramts ]
      paramts = concat [ map asInt32 constargs
                       , closure_ts
                       , map asInt32 shapeargs
                       , map I.fromDecl value_paramts ]
      all_args = concat [constargs, map I.Var closure, shapeargs, args]
  args' <- ensureArgShapes asserting "function arguments of wrong shape"
           loc (map I.paramName fun_params) paramts all_args
  argts' <- mapM subExpType args'
  case rettype_fun $ zip args' argts' of
    Nothing -> fail $ "Cannot apply " ++ pretty fname ++ " to arguments\n " ++
               pretty args' ++ "\nof types\n " ++
               pretty argts' ++
               "\nFunction has parameters\n " ++ pretty fun_params
    Just ts -> do
      safety <- askSafety
      ses <- letTupExp' desc $ I.Apply fname' (zip args' diets) ts (safety, loc, mempty)
      return (ses, map I.fromDecl ts)
-- | The safety level for emitted calls.  It is 'I.Safe' when either
-- @envDoBoundsChecks@ or @envSafe@ is set in the environment, and
-- 'I.Unsafe' otherwise.
askSafety :: InternaliseM Safety
askSafety = pick <$> asks envDoBoundsChecks <*> asks envSafe
  where pick check safe
          | check || safe = I.Safe
          | otherwise     = I.Unsafe
-- | Partition the arrays @arrs@ into @k@ classes using SOACs.
--
-- @lam@ is mapped over @arrs@.  Its first result is taken as each
-- element's class and the next @k@ results as per-class increments.
-- (Presumably the increment is 1 for the element's own class and 0
-- otherwise; confirm against callers.)
--
-- Running sums of the increments give each element's position within its
-- class, and the last running sum of each class is that class's size.
-- The elements are then scattered into fresh arrays whose outer size is
-- the sum of the class sizes.
--
-- Returns the partitioned arrays and a singleton list containing an array
-- of the @k@ class sizes.
partitionWithSOACS :: Int -> I.Lambda -> [I.VName] -> InternaliseM ([I.SubExp], [I.SubExp])
partitionWithSOACS k lam arrs = do
arr_ts <- mapM lookupType arrs
let w = arraysSize 0 arr_ts
-- Apply the classification lambda to every element.
classes_and_increments <- letTupExp "increments" $ I.Op $ I.Screma w (mapSOAC lam) arrs
(classes, increments) <- case classes_and_increments of
classes : increments -> return (classes, take k increments)
_ -> fail "partitionWithSOACS"
-- A lambda that adds two k-tuples of Int32 pointwise, used as the scan
-- operator below.
add_lam_x_params <-
replicateM k $ I.Param <$> newVName "x" <*> pure (I.Prim int32)
add_lam_y_params <-
replicateM k $ I.Param <$> newVName "y" <*> pure (I.Prim int32)
add_lam_body <- runBodyBinder $
localScope (scopeOfLParams $ add_lam_x_params++add_lam_y_params) $
fmap resultBody $ forM (zip add_lam_x_params add_lam_y_params) $ \(x,y) ->
letSubExp "z" $ I.BasicOp $ I.BinOp (I.Add Int32)
(I.Var $ I.paramName x) (I.Var $ I.paramName y)
let add_lam = I.Lambda { I.lambdaBody = add_lam_body
, I.lambdaParams = add_lam_x_params ++ add_lam_y_params
, I.lambdaReturnType = replicate k $ I.Prim int32
}
nes = replicate (length increments) $ constant (0::Int32)
-- Running sums of the increments: one array per class.
scan <- I.scanSOAC add_lam nes
all_offsets <- letTupExp "offsets" $ I.Op $ I.Screma w scan increments
-- The size of each class is the last element of its running-sum array.
-- For empty input the size is 0 instead, since index w-1 would then be
-- out of bounds.
last_index <- letSubExp "last_index" $ I.BasicOp $ I.BinOp (I.Sub Int32) w $ constant (1::Int32)
nonempty_body <- runBodyBinder $ fmap resultBody $ forM all_offsets $ \offset_array ->
letSubExp "last_offset" $ I.BasicOp $ I.Index offset_array [I.DimFix last_index]
let empty_body = resultBody $ replicate k $ constant (0::Int32)
is_empty <- letSubExp "is_empty" $ I.BasicOp $ I.CmpOp (CmpEq int32) w $ constant (0::Int32)
sizes <- letTupExp "partition_size" $
I.If is_empty empty_body nonempty_body $
ifCommon $ replicate k $ I.Prim int32
sum_of_partition_sizes <- letSubExp "sum_of_partition_sizes" =<<
foldBinOp (Add Int32) (constant (0::Int32)) (map I.Var sizes)
-- Destination arrays.  The outer size is the total of all class sizes;
-- the inner dimensions are those of the corresponding input.
blanks <- forM arr_ts $ \arr_t ->
letExp "partition_dest" $ I.BasicOp $
Scratch (elemType arr_t) (sum_of_partition_sizes : drop 1 (I.arrayDims arr_t))
-- Scatter lambda.  It takes (class, one running sum per class, values)
-- and returns the destination offset once per array, followed by the
-- values.
write_lam <- do
c_param <- I.Param <$> newVName "c" <*> pure (I.Prim int32)
offset_params <- replicateM k $ I.Param <$> newVName "offset" <*> pure (I.Prim int32)
value_params <- forM arr_ts $ \arr_t ->
I.Param <$> newVName "v" <*> pure (I.rowType arr_t)
(offset, offset_stms) <- collectStms $ mkOffsetLambdaBody (map I.Var sizes)
(I.Var $ I.paramName c_param) 0 offset_params
return I.Lambda { I.lambdaParams = c_param : offset_params ++ value_params
, I.lambdaReturnType = replicate (length arr_ts) (I.Prim int32) ++
map I.rowType arr_ts
, I.lambdaBody = mkBody offset_stms $
replicate (length arr_ts) offset ++
map (I.Var . I.paramName) value_params
}
results <- letTupExp "partition_res" $ I.Op $ I.Scatter w
write_lam (classes : all_offsets ++ arrs) $
zip3 (repeat sum_of_partition_sizes) (repeat 1) blanks
-- Pack the k class sizes into a single Int32 array.
sizes' <- letSubExp "partition_sizes" $ I.BasicOp $
I.ArrayLit (map I.Var sizes) $ I.Prim int32
return (map I.Var results, [sizes'])
where
-- Compute the destination offset for class @c@.  For the class index @i@
-- equal to @c@, this is that class's running sum minus 1, plus the sizes
-- of all earlier classes.  If @c@ matches none of the classes the result
-- is -1 (presumably discarded by Scatter as out of bounds; confirm).
mkOffsetLambdaBody :: [SubExp]
-> SubExp
-> Int
-> [I.LParam]
-> InternaliseM SubExp
mkOffsetLambdaBody _ _ _ [] =
return $ constant (-1::Int32)
mkOffsetLambdaBody sizes c i (p:ps) = do
is_this_one <- letSubExp "is_this_one" $ I.BasicOp $ I.CmpOp (CmpEq int32) c (constant i)
next_one <- mkOffsetLambdaBody sizes c (i+1) ps
this_one <- letSubExp "this_offset" =<<
foldBinOp (Add Int32) (constant (-1::Int32))
(I.Var (I.paramName p) : take i sizes)
letSubExp "total_res" $ I.If is_this_one
(resultBody [this_one]) (resultBody [next_one]) $ ifCommon [I.Prim int32]
-- | Render a source type expression as error-message parts.  Dimensions
-- are rendered by 'dimDeclForError', so named sizes can be shown by value.
typeExpForError :: ConstParams -> E.TypeExp VName -> InternaliseM [ErrorMsgPart SubExp]
typeExpForError cm te =
  case te of
    E.TEVar qn _ ->
      pure [ErrorString $ pretty qn]
    E.TEUnique inner _ ->
      ("*" :) <$> recur inner
    E.TEArray inner d _ -> do
      d' <- dimDeclForError cm d
      inner' <- recur inner
      pure $ ["[", d', "]"] ++ inner'
    E.TETuple tes _ -> do
      tes' <- mapM recur tes
      pure $ ["("] ++ intercalate [", "] tes' ++ [")"]
    E.TERecord fields _ -> do
      fields' <- forM fields $ \(k, fte) ->
        (ErrorString (pretty k ++ ": ") :) <$> recur fte
      pure $ ["{"] ++ intercalate [", "] fields' ++ ["}"]
    E.TEArrow _ t1 t2 _ -> do
      t1' <- recur t1
      t2' <- recur t2
      pure $ t1' ++ [" -> "] ++ t2'
    E.TEApply t arg _ -> do
      t' <- recur t
      arg' <- case arg of
        TypeArgExpType argt -> recur argt
        TypeArgExpDim d _   -> pure <$> dimDeclForError cm d
      pure $ t' ++ [" "] ++ arg'
    E.TEEnum{} ->
      pure [ErrorString $ pretty te]
  where recur = typeExpForError cm
-- | Render a dimension declaration as an error-message part.
--
-- A named dimension is shown by its runtime value, using the first of
-- these that applies:
--
-- * a single-valued substitution from the environment;
-- * a constant parameter whose function name is the dimension's name
--   followed by @f@;
-- * the dimension's own variable.
--
-- Constant dimensions are shown literally, and 'AnyDim' as nothing.
dimDeclForError :: ConstParams -> E.DimDecl VName -> InternaliseM (ErrorMsgPart SubExp)
dimDeclForError cm dd =
  case dd of
    NamedDim d -> do
      let v = E.qualLeaf d
          fname = nameFromString $ pretty v ++ "f"
      substs <- asks $ M.lookup v . envSubsts
      pure $ ErrorInt32 $ case (substs, lookup fname cm) of
        (Just [se], _) -> se
        (_, Just cv)   -> I.Var cv
        _              -> I.Var v
    ConstDim n ->
      pure $ ErrorString $ pretty n
    AnyDim ->
      pure ""