{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, LambdaCase, TypeSynonymInstances, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen
(
compileProg
, OpCompiler
, ExpCompiler
, CopyCompiler
, StmsCompiler
, Operations (..)
, defaultOperations
, ValueDestination
, arrayDestination
, MemLocation (..)
, MemEntry (..)
, ScalarEntry (..)
, ImpM
, Env (envDefaultSpace, envFunction)
, VTable
, getVTable
, localVTable
, subImpM
, subImpM_
, emit
, emitFunction
, hasFunction
, collect
, comment
, VarEntry (..)
, ArrayEntry (..)
, lookupVar
, lookupArray
, lookupMemory
, compileSubExp
, compileSubExpOfType
, compileSubExpTo
, compilePrimExp
, compileAlloc
, subExpToDimSize
, everythingVolatile
, compileBody
, compileBody'
, compileLoopBody
, defCompileStms
, compileStms
, compileExp
, defCompileExp
, offsetArray
, strideArray
, fullyIndexArray
, fullyIndexArray'
, varIndex
, Imp.dimSizeToExp
, dimSizeToSubExp
, copy
, copyDWIM
, copyDWIMDest
, copyElementWise
, dLParams
, dFParams
, dScope
, dScopes
, dArray
, dPrim, dPrim_, dPrimV
, sFor, sWhile
, sComment
, sIf, sWhen, sUnless
, sOp
, sAlloc
, sArray, sAllocArray, sStaticArray
, sWrite
, (<--)
)
where
import Control.Monad.RWS hiding (mapM, forM)
import Control.Monad.State hiding (mapM, forM, State)
import Control.Monad.Writer hiding (mapM, forM)
import Control.Monad.Except hiding (mapM, forM)
import qualified Control.Monad.Fail as Fail
import Data.Either
import Data.Traversable
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpCode
(Count (..),
Bytes, Elements,
bytes, withElemType)
import Futhark.Representation.ExplicitMemory
import Futhark.Representation.SOACS (SOACS)
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Construct (fullSliceNum)
import Futhark.MonadFreshNames
import Futhark.Error
import Futhark.Util
-- | Callback that compiles a lore-specific 'Op', given the pattern its
-- result is bound to.  It runs in 'ImpM' and returns no value.
type OpCompiler lore op = Pattern lore -> Op lore -> ImpM lore op ()
-- | Callback that compiles a list of statements.  As used by
-- 'defCompileStms', the 'Names' argument holds the names that are still
-- live after the statements, and the trailing 'ImpM' action is run once
-- the last statement has been compiled.
type StmsCompiler lore op = Names -> [Stm lore] -> ImpM lore op () -> ImpM lore op ()
-- | Callback that compiles an expression, given the pattern its result
-- is bound to.  It runs in 'ImpM' and returns no value.
type ExpCompiler lore op = Pattern lore -> Exp lore -> ImpM lore op ()
-- | Callback that copies between two memory locations.  Following the
-- argument order used by 'defaultCopy': element type, destination,
-- source, and an element count.
type CopyCompiler lore op = PrimType
                         -> MemLocation
                         -> MemLocation
                         -> Count Elements
                         -> ImpM lore op ()
-- | The callbacks that parameterise the code generator.  They are copied
-- into the 'Env' by 'newEnv'.
data Operations lore op = Operations { opsExpCompiler :: ExpCompiler lore op
                                       -- ^ How to compile expressions.
                                     , opsOpCompiler :: OpCompiler lore op
                                       -- ^ How to compile 'Op's.
                                     , opsStmsCompiler :: StmsCompiler lore op
                                       -- ^ How to compile statement sequences.
                                     , opsCopyCompiler :: CopyCompiler lore op
                                       -- ^ How to copy between memory locations.
                                     }
-- | Build an 'Operations' from just an 'OpCompiler'; the remaining
-- callbacks are 'defCompileExp', 'defCompileStms' and 'defaultCopy'.
defaultOperations :: (ExplicitMemorish lore, FreeIn op) =>
                     OpCompiler lore op -> Operations lore op
defaultOperations op_compiler =
  Operations
    { opsOpCompiler = op_compiler
    , opsExpCompiler = defCompileExp
    , opsStmsCompiler = defCompileStms
    , opsCopyCompiler = defaultCopy
    }
-- | Where an array lives: a memory block, a shape, and an index
-- function over 'Imp.Exp'.
data MemLocation = MemLocation { memLocationName :: VName
                                 -- ^ Name of the memory block.
                               , memLocationShape :: [Imp.DimSize]
                                 -- ^ Size of each dimension.
                               , memLocationIxFun :: IxFun.IxFun Imp.Exp
                                 -- ^ Index function used to compute
                                 -- offsets into the block.
                               }
                   deriving (Eq, Show)
-- | Symbol table information about an array: its location and its
-- element type.
data ArrayEntry = ArrayEntry {
    entryArrayLocation :: MemLocation
  , entryArrayElemType :: PrimType
  }
  deriving (Show)
-- | The shape recorded in the location of an array entry.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocationShape (entryArrayLocation entry)
-- | Symbol table information about a memory block: its size and the
-- space it resides in.
data MemEntry = MemEntry {
    entryMemSize :: Imp.MemSize
  , entryMemSpace :: Imp.Space
  }
  deriving (Show)
-- | Symbol table information about a scalar: just its primitive type.
newtype ScalarEntry = ScalarEntry {
    entryScalarType :: PrimType
  }
  deriving (Show)
-- | A symbol table entry: an array, a scalar, or a memory block.  The
-- optional 'Exp' is the expression recorded for the variable when it is
-- declared ('defCompileStms' passes the binding expression of each
-- statement); 'subImpM' resets it to 'Nothing'.
data VarEntry lore = ArrayVar (Maybe (Exp lore)) ArrayEntry
                   | ScalarVar (Maybe (Exp lore)) ScalarEntry
                   | MemVar (Maybe (Exp lore)) MemEntry
                   deriving (Show)
-- | Where the results of a computation are to be written: one
-- 'ValueDestination' per result, plus an optional tag
-- ('destinationFromPattern' uses the 'baseTag' of the first pattern
-- name, if there is one).
data Destination = Destination { destinationTag :: Maybe Int
                               , valueDestinations :: [ValueDestination] }
                    deriving (Show)
-- | Where a single result is to be written.
data ValueDestination = ScalarDestination VName
                        -- ^ Assign to this scalar variable.
                      | MemoryDestination VName
                        -- ^ Assign to this memory block variable.
                      | ArrayDestination (Maybe MemLocation)
                        -- ^ Write to this array location.
                        -- 'destinationFromPattern' produces 'Nothing'
                        -- when the array's memory block is itself bound
                        -- in the context of the pattern.
                      deriving (Show)
-- | An 'ArrayDestination' for a known memory location.
arrayDestination :: MemLocation -> ValueDestination
arrayDestination loc = ArrayDestination (Just loc)
-- | The read-only environment of 'ImpM'.
data Env lore op = Env {
    envExpCompiler :: ExpCompiler lore op
    -- ^ Callback invoked by 'compileExp'.
  , envStmsCompiler :: StmsCompiler lore op
    -- ^ Callback invoked by 'compileStms'.
  , envOpCompiler :: OpCompiler lore op
    -- ^ Callback invoked by 'defCompileExp' for 'Op' expressions.
  , envCopyCompiler :: CopyCompiler lore op
    -- ^ Callback invoked by 'copy'.
  , envDefaultSpace :: Imp.Space
    -- ^ Memory space that 'compileOutParams' assigns to array results.
  , envVolatility :: Imp.Volatility
    -- ^ Volatility attached to generated reads and writes; starts as
    -- 'Imp.Nonvolatile' and is overridden by 'everythingVolatile'.
  , envFakeMemory :: [Space]
    -- ^ NOTE(review): the purpose of these spaces is not evident from
    -- the uses in this part of the module - TODO document.
  , envFunction :: Name
    -- ^ Name handed to 'runImpM'; 'compileProg' passes the name of the
    -- function being compiled.
  }
-- | Create the initial environment from the compilation callbacks, the
-- default memory space, the list of fake memory spaces, and the name of
-- the function.  Volatility starts out as 'Imp.Nonvolatile'.
newEnv :: Operations lore op -> Imp.Space -> [Imp.Space] -> Name -> Env lore op
newEnv operations default_space fake_spaces function_name =
  Env
    { envDefaultSpace = default_space
    , envFakeMemory = fake_spaces
    , envFunction = function_name
    , envVolatility = Imp.Nonvolatile
    , envExpCompiler = opsExpCompiler operations
    , envStmsCompiler = opsStmsCompiler operations
    , envOpCompiler = opsOpCompiler operations
    , envCopyCompiler = opsCopyCompiler operations
    }
-- | The symbol table: maps every known variable to its 'VarEntry'.
type VTable lore = M.Map VName (VarEntry lore)
-- | The mutable state of 'ImpM'.
data State lore op = State { stateVTable :: VTable lore
                             -- ^ The current symbol table.
                           , stateFunctions :: Imp.Functions op
                             -- ^ Functions emitted so far (see
                             -- 'emitFunction').
                           , stateNameSource :: VNameSource
                             -- ^ Source of fresh names.
                           }
-- | An initial state with an empty symbol table, no functions, and the
-- given name source.
newState :: VNameSource -> State lore op
newState src =
  State
    { stateVTable = mempty
    , stateFunctions = mempty
    , stateNameSource = src
    }
-- | The code generation monad: a reader over 'Env', a writer
-- accumulating 'Imp.Code', and a state of type 'State', on top of
-- @'Either' 'InternalError'@ for failures.
newtype ImpM lore op a = ImpM (RWST (Env lore op) (Imp.Code op) (State lore op) (Either InternalError) a)
  deriving (Functor, Applicative, Monad,
            MonadState (State lore op),
            MonadReader (Env lore op),
            MonadWriter (Imp.Code op),
            MonadError InternalError)
-- | Failing pattern matches in @do@-blocks are reported with
-- 'compilerBugS', the same helper every other internal inconsistency in
-- this module goes through, rather than with a bare 'error' call.  The
-- message text is unchanged.
instance Fail.MonadFail (ImpM lore op) where
  fail = compilerBugS . ("ImpM.fail: " ++)
-- | Fresh names are drawn from the name source kept in the 'State'.
instance MonadFreshNames (ImpM lore op) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\st -> st { stateNameSource = src })
-- | The scope is reconstructed from the symbol table: every entry is
-- presented as a 'LetInfo' carrying the type derived from the entry.
instance HasScope SOACS (ImpM lore op) where
  askScope = gets $ M.map (LetInfo . typeOfEntry) . stateVTable
    where
      typeOfEntry entry =
        case entry of
          ScalarVar _ scalar ->
            Prim (entryScalarType scalar)
          MemVar _ mem ->
            Mem (dimSizeToSubExp (entryMemSize mem)) (entryMemSpace mem)
          ArrayVar _ arr ->
            let dims = map dimSizeToSubExp (entryArrayShape arr)
            in Array (entryArrayElemType arr) (Shape dims) NoUniqueness
-- | Run an 'ImpM' action from the given state, in an environment built
-- by 'newEnv'.  Yields the result, the final state and the generated
-- code, or an 'InternalError'.
runImpM :: ImpM lore op a
        -> Operations lore op -> Imp.Space -> [Imp.Space] -> Name -> State lore op
        -> Either InternalError (a, State lore op, Imp.Code op)
runImpM (ImpM action) ops space fake fname initial =
  runRWST action env initial
  where env = newEnv ops space fake fname
-- | Like 'subImpM', but keeps only the generated code.
subImpM_ :: Operations lore' op' -> ImpM lore' op' a
         -> ImpM lore op (Imp.Code op')
subImpM_ ops m = do
  (_, code) <- subImpM ops m
  return code
-- | Run an 'ImpM' action that may use a different lore and operation
-- type.  The sub-computation gets the current environment with its four
-- compiler callbacks replaced by those of the given 'Operations', and
-- the current state with every recorded expression dropped from the
-- symbol table and an empty function table.  The generated code is
-- returned rather than emitted, and of the sub-computation's final
-- state only the name source is kept.
subImpM :: Operations lore' op' -> ImpM lore' op' a
        -> ImpM lore op (a, Imp.Code op')
subImpM ops (ImpM m) = do
  env <- ask
  st <- get
  let sub_env = env { envExpCompiler = opsExpCompiler ops
                    , envStmsCompiler = opsStmsCompiler ops
                    , envCopyCompiler = opsCopyCompiler ops
                    , envOpCompiler = opsOpCompiler ops
                    }
      sub_state = st { stateVTable = M.map forgetExp (stateVTable st)
                     , stateFunctions = mempty
                     }
  (x, final_state, code) <- either throwError return $ runRWST m sub_env sub_state
  putNameSource (stateNameSource final_state)
  return (x, code)
  where forgetExp entry =
          case entry of
            ArrayVar _ arr -> ArrayVar Nothing arr
            MemVar _ mem -> MemVar Nothing mem
            ScalarVar _ scalar -> ScalarVar Nothing scalar
-- | Run an action and return the code it emitted; nothing is emitted to
-- the enclosing context.
collect :: ImpM lore op () -> ImpM lore op (Imp.Code op)
collect m =
  pass $ fmap (\((), code) -> (code, const mempty)) (listen m)
-- | Like 'collect', but also returns the result of the action.
collect' :: ImpM lore op a -> ImpM lore op (a, Imp.Code op)
collect' m =
  pass $ fmap (\(x, code) -> ((x, code), const mempty)) (listen m)
-- | Run an action and emit the code it produced wrapped in an
-- 'Imp.Comment' carrying the given description.
comment :: String -> ImpM lore op () -> ImpM lore op ()
comment desc m = collect m >>= emit . Imp.Comment desc
-- | Append a piece of code to the code generated so far.
emit :: Imp.Code op -> ImpM lore op ()
emit code = tell code
-- | Record a function under the given name, in front of the functions
-- already recorded in the state.
emitFunction :: Name -> Imp.Function op -> ImpM lore op ()
emitFunction fname fun = modify $ \st ->
  let Imp.Functions known = stateFunctions st
  in st { stateFunctions = Imp.Functions ((fname, fun) : known) }
-- | Has a function with this name been recorded with 'emitFunction'?
hasFunction :: Name -> ImpM lore op Bool
hasFunction fname = do
  Imp.Functions known <- gets stateFunctions
  return $ any ((fname ==) . fst) known
-- | Compile every function of a program, threading one state through
-- all of them.  On success the result holds the accumulated functions
-- and the name source is advanced; on failure the error is returned and
-- the name source is left as it was.
compileProg :: (ExplicitMemorish lore, MonadFreshNames m) =>
               Operations lore op -> Imp.Space -> [Imp.Space]
            -> Prog lore -> m (Either InternalError (Imp.Functions op))
compileProg ops space fake prog =
  modifyNameSource $ \src ->
    either (failed src) succeeded $
    foldM compileOne (newState src) (progFunctions prog)
  where failed src err = (Left err, src)
        succeeded st = (Right (stateFunctions st), stateNameSource st)
        -- The code written by 'compileFunDef' is discarded here; only
        -- the resulting state is passed on.
        compileOne st fdef =
          (\((), st', _) -> st') <$>
          runImpM (compileFunDef fdef) ops space fake (funDefName fdef) st
-- | Compile one function parameter.  Scalars and memory blocks become
-- 'Imp.Param's ('Left'); arrays become an 'ArrayDecl' ('Right').
compileInParam :: ExplicitMemorish lore =>
                  FParam lore -> ImpM lore op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramAttr fparam of
    MemPrim pt ->
      pure (Left (Imp.ScalarParam pname pt))
    MemMem _ space ->
      pure (Left (Imp.MemParam pname space))
    MemArray pt shape _ (ArrayIn mem ixfun) -> do
      dims <- mapM subExpToDimSize (shapeDims shape)
      let loc = MemLocation mem dims (fmap compilePrimExp ixfun)
      pure (Right (ArrayDecl pname pt loc))
  where pname = paramName fparam
-- | An array parameter: its name, element type, and memory location.
data ArrayDecl = ArrayDecl VName PrimType MemLocation
-- | The variables a parameter's type uses as sizes: the size variable of
-- a memory block, or otherwise the variables among its array dimensions.
fparamSizes :: Typed attr => Param attr -> S.Set VName
fparamSizes fparam =
  case paramType fparam of
    Mem (Var size) _ -> S.singleton size
    t -> S.fromList (subExpVars (arrayDims t))
-- | Compile the parameters of a function.  Returns the 'Imp.Param's for
-- the scalar and memory parameters, the 'ArrayDecl's for the array
-- parameters, and the external value descriptions obtained by matching
-- the entry point types against the value parameters.
compileInParams :: ExplicitMemorish lore =>
                   [FParam lore] -> [EntryPointType]
                -> ImpM lore op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  -- The value parameters are the trailing ones, as many as the entry
  -- point types account for; everything before them is context.
  let (ctx_params, val_params) =
        splitAt (length params - sum (map entryPointSize orig_epts)) params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam (ctx_params++val_params)
  let findArray x = find (isArrayDecl x) arrayds
      -- Every variable used as a size by the type of some parameter.
      sizes = mconcat $ map fparamSizes $ ctx_params++val_params
      -- Size and space of each memory block parameter, by name.
      summaries = M.fromList $ mapMaybe memSummary params
        where memSummary param
                | MemMem (Constant (IntValue (Int64Value size))) space <- paramAttr param =
                    Just (paramName param, (Imp.ConstSize size, space))
                | MemMem (Var size) space <- paramAttr param =
                    Just (paramName param, (Imp.VarSize size, space))
                | otherwise =
                    Nothing
      findMemInfo :: VName -> Maybe (Imp.MemSize, Space)
      findMemInfo = flip M.lookup summaries
      -- Describe one value parameter.  Arrays need their memory block
      -- to be among the parameters; a scalar that is used as a size
      -- gets no description of its own.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            (memsize, memspace) <- findMemInfo mem
            Just $ Imp.ArrayValue mem memsize memspace bt signedness shape
          (_, Prim bt)
            | paramName fparam `S.member` sizes ->
                Nothing
            | otherwise ->
                Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing
      -- Walk the entry point types and the value parameters in step.
      -- An opaque type consumes @n@ parameters; the others consume one.
      -- Stops as soon as either list runs out.
      mkExts (TypeOpaque desc n:epts) fparams =
        let (fparams',rest) = splitAt n fparams
        in Imp.OpaqueValue desc
           (mapMaybe (`mkValueDesc` Imp.TypeDirect) fparams') :
           mkExts epts rest
      mkExts (TypeUnsigned:epts) (fparam:fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeUnsigned) ++
        mkExts epts fparams
      mkExts (TypeDirect:epts) (fparam:fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeDirect) ++
        mkExts epts fparams
      mkExts _ _ = []
  return (inparams, arrayds, mkExts orig_epts val_params)
  where isArrayDecl x (ArrayDecl y _ _) = x == y
-- | Compile the return types of a function.  Returns the external value
-- descriptions of the results, the output parameters of the function,
-- and the 'Destination' the function body writes its results to.
--
-- The helpers run in a 'StateT' over a 'WriterT' over 'ImpM'.  The state
-- is a pair of maps remembering which existential sizes already have an
-- output parameter (memory sizes first, array dimensions second).  The
-- writer collects the output parameters and the destinations of the
-- existential context, the latter keyed by existential index.
compileOutParams :: ExplicitMemorish lore =>
                    [RetType lore] -> [EntryPointType]
                 -> ImpM lore op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams orig_rts orig_epts = do
  ((extvs, dests), (outparams,ctx_dests)) <-
    runWriterT $ evalStateT (mkExts orig_epts orig_rts) (M.empty, M.empty)
  -- Context destinations come first, ordered by existential index.
  let ctx_dests' = map snd $ sortOn fst $ M.toList ctx_dests
  return (extvs, outparams, Destination Nothing $ ctx_dests' <> dests)
  where -- Lift an 'ImpM' action through both transformers.
        imp = lift . lift
        -- Walk entry point types and return types in step.  An opaque
        -- type consumes @n@ return types; the others consume one.
        mkExts (TypeOpaque desc n:epts) rts = do
          let (rts',rest) = splitAt n rts
          (evs, dests) <- unzip <$> zipWithM mkParam rts' (repeat Imp.TypeDirect)
          (more_values, more_dests) <- mkExts epts rest
          return (Imp.OpaqueValue desc evs : more_values,
                  dests ++ more_dests)
        mkExts (TypeUnsigned:epts) (rt:rts) = do
          (ev,dest) <- mkParam rt Imp.TypeUnsigned
          (more_values, more_dests) <- mkExts epts rts
          return (Imp.TransparentValue ev : more_values,
                  dest : more_dests)
        mkExts (TypeDirect:epts) (rt:rts) = do
          (ev,dest) <- mkParam rt Imp.TypeDirect
          (more_values, more_dests) <- mkExts epts rts
          return (Imp.TransparentValue ev : more_values,
                  dest : more_dests)
        mkExts _ _ = return ([], [])
        -- Produce the value description and destination for one result.
        mkParam MemMem{} _ =
          compilerBugS "Functions may not explicitly return memory blocks."
        mkParam (MemPrim t) ept = do
          out <- imp $ newVName "scalar_out"
          tell ([Imp.ScalarParam out t], mempty)
          return (Imp.ScalarValue t ept out, ScalarDestination out)
        mkParam (MemArray t shape _ attr) ept = do
          space <- asks envDefaultSpace
          (memout, memsize) <- case attr of
            -- A new block: add an output parameter for the memory and
            -- register its destination under existential index @x@.
            ReturnsNewBlock _ x x_size _ixfun -> do
              memout <- imp $ newVName "out_mem"
              sizeout <- ensureMemSizeOut x_size
              tell ([Imp.MemParam memout space],
                    M.singleton x $ MemoryDestination memout)
              return (memout, sizeout)
            -- An existing block: look up its size in the symbol table.
            ReturnsInBlock memout _ -> do
              memsize <- imp $ entryMemSize <$> lookupMemory memout
              return (memout, memsize)
          resultshape <- mapM inspectExtSize $ shapeDims shape
          return (Imp.ArrayValue memout memsize space t ept resultshape,
                  ArrayDestination Nothing)
        -- An existential array dimension gets an @int32@ output
        -- parameter the first time it is seen; later occurrences reuse
        -- it.
        inspectExtSize (Ext x) = do
          (memseen,arrseen) <- get
          case M.lookup x arrseen of
            Nothing -> do
              out <- imp $ newVName "out_arrsize"
              tell ([Imp.ScalarParam out int32],
                    M.singleton x $ ScalarDestination out)
              put (memseen, M.insert x out arrseen)
              return $ Imp.VarSize out
            Just out ->
              return $ Imp.VarSize out
        inspectExtSize (Free se) =
          imp $ subExpToDimSize se
        -- Same scheme for existential memory sizes, with an @int64@
        -- output parameter.
        ensureMemSizeOut (Ext x) = do
          (memseen, arrseen) <- get
          case M.lookup x memseen of
            Nothing -> do sizeout <- imp $ newVName "out_memsize"
                          tell ([Imp.ScalarParam sizeout int64],
                                M.singleton x $ ScalarDestination sizeout)
                          put (M.insert x sizeout memseen, arrseen)
                          return $ Imp.VarSize sizeout
            Just sizeout -> return $ Imp.VarSize sizeout
        ensureMemSizeOut (Free v) = imp $ subExpToDimSize v
-- | Compile a function definition and record the result with
-- 'emitFunction'.  If the function is not an entry point, every
-- parameter and result is treated as 'TypeDirect'.
compileFunDef :: ExplicitMemorish lore =>
                 FunDef lore
              -> ImpM lore op ()
compileFunDef (FunDef entry fname rettype params body) = do
  ((outparams, inparams, results, args), body_code) <- collect' compileContents
  let fun = Imp.Function (isJust entry) outparams inparams body_code results args
  emitFunction fname fun
  where
    (params_entry, ret_entry) =
      case entry of
        Just (param_epts, ret_epts) ->
          (param_epts, ret_epts)
        Nothing ->
          ( replicate (length params) TypeDirect
          , replicate (length rettype) TypeDirect
          )

    -- Declares the parameters, then compiles the body so that its
    -- results end up in the output destinations.
    compileContents = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <- compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds
      let Body _ stms res = body
      compileStms (freeIn res) (stmsToList stms) $
        zipWithM_ (\d se -> copyDWIMDest d [] se []) dests res
      return (outparams, inparams, results, args)
-- | Compile a body, copying its results to the destinations given by
-- the pattern once its statements have been compiled.
compileBody :: (ExplicitMemorish lore) => Pattern lore -> Body lore -> ImpM lore op ()
compileBody pat (Body _ stms res) = do
  Destination _ dests <- destinationFromPattern pat
  compileStms (freeIn res) (stmsToList stms) $
    zipWithM_ (\d se -> copyDWIMDest d [] se []) dests res
-- | Like 'compileBody', but the destinations are given as parameters,
-- which are turned into a pattern with 'patternFromParams'.
compileBody' :: (ExplicitMemorish lore, attr ~ LetAttr lore)
             => [Param attr] -> Body lore -> ImpM lore op ()
compileBody' params body = compileBody (patternFromParams params) body
-- | Compile a loop body whose results are to be assigned to the given
-- merge parameters.  Scalar and memory results are first stored in
-- fresh temporaries, and only once all of them have been stored are the
-- temporaries assigned to the merge parameters.  Results of any other
-- kind produce no code here.
compileLoopBody :: [VName] -> Body lore -> ImpM lore op ()
compileLoopBody mergenames (Body _ stms res) = do
  tmpnames <- forM mergenames $ \v -> newVName (baseString v ++ "_tmp")
  compileStms (freeIn res) (stmsToList stms) $ do
    -- Each step emits the store into the temporary right away and hands
    -- back the action performing the final assignment.
    writebacks <- forM (zip3 mergenames tmpnames res) $ \(dest, tmp, se) -> do
      t <- subExpType se
      case t of
        Prim pt -> do
          se' <- compileSubExp se
          emit $ Imp.DeclareScalar tmp pt
          emit $ Imp.SetScalar tmp se'
          return $ emit $ Imp.SetScalar dest $ Imp.var tmp pt
        Mem _ space
          | Var v <- se -> do
              emit $ Imp.DeclareMem tmp space
              emit $ Imp.SetMem tmp v space
              return $ emit $ Imp.SetMem dest tmp space
        _ ->
          return $ return ()
    sequence_ writebacks
-- | Compile statements using the statement compiler found in the
-- environment.
compileStms :: Names -> [Stm lore] -> ImpM lore op () -> ImpM lore op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \compiler -> compiler alive_after_stms all_stms m
-- | The default statement compiler.  Statements are compiled in order,
-- followed by the final action @m@.  After the code for a statement, an
-- 'Imp.Free' is emitted for every previously bound memory block that
-- occurs free in that code but is not live afterwards.
defCompileStms :: (ExplicitMemorish lore, FreeIn op) =>
                  Names -> [Stm lore] -> ImpM lore op () -> ImpM lore op ()
defCompileStms alive_after_stms all_stms m =
  void $ compileStms' mempty all_stms
  where -- @allocs@ holds the memory blocks bound by earlier statements,
        -- paired with their spaces.  The result is the set of names
        -- live from this point on.
        compileStms' allocs (Let pat _ e:bs) = do
          dVars (Just e) (patternElements pat)
          e_code <- collect $ compileExp pat e
          -- The rest is compiled first, because its liveness decides
          -- which frees go between the two pieces of code.
          (live_after, bs_code) <- collect' $ compileStms' (patternAllocs pat <> allocs) bs
          let dies_here v = not (v `S.member` live_after) &&
                            v `S.member` freeIn e_code
              to_free = S.filter (dies_here . fst) allocs
          emit e_code
          mapM_ (emit . uncurry Imp.Free) to_free
          emit bs_code
          return $ freeIn e_code <> live_after
        compileStms' _ [] = do
          code <- collect m
          emit code
          return $ freeIn code <> alive_after_stms
        -- The memory blocks bound by a pattern, with their spaces.
        patternAllocs = S.fromList . mapMaybe isMemPatElem . patternElements
        isMemPatElem pe = case patElemType pe of
                            Mem _ space -> Just (patElemName pe, space)
                            _ -> Nothing
-- | Compile an expression using the expression compiler found in the
-- environment.
compileExp :: Pattern lore -> Exp lore -> ImpM lore op ()
compileExp pat e =
  asks envExpCompiler >>= \compiler -> compiler pat e
-- | The default expression compiler.
defCompileExp :: (ExplicitMemorish lore) =>
                 Pattern lore -> Exp lore -> ImpM lore op ()

-- Both branches are compiled against the same pattern.
defCompileExp pat (If cond tbranch fbranch _) =
  Imp.If
    <$> compileSubExp cond
    <*> collect (compileBody pat tbranch)
    <*> collect (compileBody pat fbranch)
    >>= emit

-- Scalars are passed as expressions and memory blocks by name; any
-- other argument is left out of the call.
defCompileExp pat (Apply fname args _ _) = do
  targets <- funcallTargets =<< destinationFromPattern pat
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where compileArg (se, _) = do
          t <- subExpType se
          return $ case (se, t) of
            (_, Prim pt) -> Just $ Imp.ExpArg $ compileSubExpOfType pt se
            (Var v, Mem{}) -> Just $ Imp.MemArg v
            _ -> Nothing

defCompileExp pat (BasicOp op) =
  defCompileBasicOp pat op

defCompileExp pat (DoLoop ctx val form body) = do
  -- Declare the merge parameters and give the non-array ones their
  -- initial values.
  dFParams mergepat
  forM_ merge $ \(p, se) -> do
    not_array <- subExpNotArray se
    when not_array $
      copyDWIM (paramName p) [] se []

  let loop_body = compileLoopBody mergenames body

  case form of
    ForLoop i it bound loopvars -> do
      bound' <- compileSubExp bound
      dLParams $ map fst loopvars
      sFor i it bound' $ do
        -- Scalar loop variables are read from their arrays at index i
        -- on every iteration; other loop variables need no code.
        forM_ loopvars $ \(p, arr) ->
          case paramType p of
            Prim _ -> copyDWIM (paramName p) [] (Var arr) [varIndex i]
            _ -> return ()
        loop_body
    WhileLoop cond ->
      sWhile (Imp.var cond Bool) loop_body

  -- After the loop, the merge parameters hold the results.
  Destination _ pat_dests <- destinationFromPattern pat
  zipWithM_ (\d p -> copyDWIMDest d [] (Var (paramName p)) []) pat_dests mergepat
  where merge = ctx ++ val
        mergepat = map fst merge
        mergenames = map paramName mergepat

defCompileExp pat (Op op) =
  asks envOpCompiler >>= \compiler -> compiler pat op
-- | The default compiler for 'BasicOp's.  Most clauses expect a pattern
-- with exactly one value element; a combination of pattern and
-- operation that no clause accepts is reported with 'compilerBugS'.
defCompileBasicOp :: ExplicitMemorish lore =>
                     Pattern lore -> BasicOp lore -> ImpM lore op ()

defCompileBasicOp (Pattern _ [pe]) (SubExp se) =
  copyDWIM (patElemName pe) [] se []

defCompileBasicOp (Pattern _ [pe]) (Opaque se) =
  copyDWIM (patElemName pe) [] se []

defCompileBasicOp (Pattern _ [pe]) (UnOp op e) = do
  e' <- compileSubExp e
  patElemName pe <-- Imp.UnOpExp op e'

defCompileBasicOp (Pattern _ [pe]) (ConvOp conv e) = do
  e' <- compileSubExp e
  patElemName pe <-- Imp.ConvOpExp conv e'

defCompileBasicOp (Pattern _ [pe]) (BinOp bop x y) = do
  x' <- compileSubExp x
  y' <- compileSubExp y
  patElemName pe <-- Imp.BinOpExp bop x' y'

defCompileBasicOp (Pattern _ [pe]) (CmpOp bop x y) = do
  x' <- compileSubExp x
  y' <- compileSubExp y
  patElemName pe <-- Imp.CmpOpExp bop x' y'

defCompileBasicOp _ (Assert e msg loc) = do
  e' <- compileSubExp e
  msg' <- traverse compileSubExp msg
  emit $ Imp.Assert e' msg' loc

-- A slice consisting only of fixed indices is an element read.
defCompileBasicOp (Pattern _ [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName pe) [] (Var src) $ map (compileSubExpOfType int32) idxs

-- Any other index operation generates no code.
defCompileBasicOp _ Index{} =
  return ()

-- Only the updated slice of the destination array is written.
defCompileBasicOp (Pattern _ [pe]) (Update _ slice se) = do
  MemLocation mem shape ixfun <- entryArrayLocation <$> lookupArray (patElemName pe)
  let memdest = sliceArray (MemLocation mem shape ixfun) $
                map (fmap (compileSubExpOfType int32)) slice
  copyDWIMDest (ArrayDestination $ Just memdest) [] se []

-- One loop per replicated dimension around a single element copy.
defCompileBasicOp (Pattern _ [pe]) (Replicate (Shape ds) se) = do
  ds' <- mapM compileSubExp ds
  is <- replicateM (length ds) (newVName "i")
  copy_elem <- collect $ copyDWIM (patElemName pe) (map varIndex is) se []
  emit $ foldl (.) id (zipWith (`Imp.For` Int32) is ds') copy_elem

defCompileBasicOp _ Scratch{} =
  return ()

-- Element i is computed as e + i * s in the element type, and then
-- written at index i.
defCompileBasicOp (Pattern [] [pe]) (Iota n e s et) = do
  i <- newVName "i"
  x <- newVName "x"
  n' <- compileSubExp n
  e' <- compileSubExp e
  s' <- compileSubExp s
  let i' = ConvOpExp (SExt Int32 et) $ Imp.var i $ IntType Int32
  dPrim_ x $ IntType et
  sFor i Int32 n' $ do
    x <-- e' + i' * s'
    copyDWIM (patElemName pe) [varIndex i] (Var x) []

defCompileBasicOp (Pattern _ [pe]) (Copy src) =
  copyDWIM (patElemName pe) [] (Var src) []

defCompileBasicOp (Pattern _ [pe]) (Manifest _ src) =
  copyDWIM (patElemName pe) [] (Var src) []

defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  MemLocation destmem destshape destixfun <-
    entryArrayLocation <$> lookupArray (patElemName pe)
  xtype <- lookupType x
  -- Running offset along dimension i of the destination.
  offs_glb <- dPrim "tmp_offs" int32
  emit $ Imp.SetScalar offs_glb 0
  -- The destination index function is permuted so that dimension i
  -- comes first, offset by the running offset, and permuted back.
  let perm = [i] ++ [0..i-1] ++ [i+1..length destshape-1]
      invperm = rearrangeInverse perm
      destloc = MemLocation destmem destshape
                (IxFun.permute (IxFun.offsetIndex (IxFun.permute destixfun perm) $
                                varIndex offs_glb)
                 invperm)
  forM_ (x:ys) $ \y -> do
    yentry <- lookupArray y
    let srcloc = entryArrayLocation yentry
    -- An operand without a dimension i is an internal inconsistency.
    -- It is reported through 'compilerBugS', like the other
    -- inconsistencies in this function, instead of through a lazy
    -- 'error' buried in the generated expression.
    rows <- case drop i $ entryArrayShape yentry of
      [] -> compilerBugS $ "defCompileBasicOp Concat: empty array shape for " ++ pretty y
      r:_ -> return $ innerExp $ Imp.dimSizeToExp r
    copy (elemType xtype) destloc srcloc $ arrayOuterSize yentry
    emit $ Imp.SetScalar offs_glb $ Imp.var offs_glb int32 + rows

defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _)
  -- A non-empty literal made up only of constants is declared as a
  -- static array and copied to the destination in one go.
  | Just vs@(v:_) <- mapM isLiteral es = do
      dest_mem <- entryArrayLocation <$> lookupArray (patElemName pe)
      dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
      let t = primValueType v
      static_array <- newVName "static_array"
      emit $ Imp.DeclareArray static_array dest_space t vs
      let static_src = MemLocation static_array [Imp.ConstSize $ fromIntegral $ length es] $
                       IxFun.iota [fromIntegral $ length es]
          num_bytes = Imp.ConstSize $ fromIntegral (length es) * primByteSize t
          entry = MemVar Nothing $ MemEntry num_bytes dest_space
      addVar static_array entry
      copy t dest_mem static_src $ fromIntegral $ length es
  -- Otherwise the elements are written one at a time.
  | otherwise =
      forM_ (zip [0..] es) $ \(i,e) ->
        copyDWIM (patElemName pe) [constIndex i] e []
  where isLiteral (Constant v) = Just v
        isLiteral _ = Nothing

defCompileBasicOp _ Rearrange{} =
  return ()

defCompileBasicOp _ Rotate{} =
  return ()

defCompileBasicOp _ Reshape{} =
  return ()

defCompileBasicOp _ Repeat{} =
  return ()

defCompileBasicOp pat e =
  compilerBugS $ "ImpGen.defCompileBasicOp: Invalid pattern\n " ++
  pretty pat ++ "\nfor expression\n " ++ pretty e
-- | Enter the given array declarations into the symbol table, with no
-- expression recorded for them.
addArrays :: [ArrayDecl] -> ImpM lore op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name $ ArrayVar Nothing (ArrayEntry location bt)
-- | Enter function parameters into the symbol table.  No declarations
-- are emitted.
addFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore op ()
addFParams fparams =
  forM_ fparams $ \fparam ->
    memBoundToVarEntry Nothing (noUniquenessReturns (paramAttr fparam))
      >>= addVar (paramName fparam)
-- | Enter a loop variable of the given integer type into the symbol
-- table as a scalar.  No declaration is emitted.
addLoopVar :: VName -> IntType -> ImpM lore op ()
addLoopVar i it = addVar i (ScalarVar Nothing (ScalarEntry (IntType it)))
-- | Declare the variables bound by the given pattern elements,
-- recording the given expression (if any) for each of them.
dVars :: ExplicitMemorish lore =>
         Maybe (Exp lore) -> [PatElem lore] -> ImpM lore op ()
dVars e pes = forM_ pes $ \pe -> dScope e (scopeOfPatElem pe)
-- | Declare the given function parameters (see 'dScope').
dFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore op ()
dFParams fparams = dScope Nothing (scopeOfFParams fparams)
-- | Declare the given lambda parameters (see 'dScope').
dLParams :: ExplicitMemorish lore => [LParam lore] -> ImpM lore op ()
dLParams lparams = dScope Nothing (scopeOfLParams lparams)
-- | Declare a scalar with the given name and type: emits the
-- declaration and enters the variable into the symbol table.
dPrim_ :: VName -> PrimType -> ImpM lore op ()
dPrim_ name t =
  emit (Imp.DeclareScalar name t)
    >> addVar name (ScalarVar Nothing (ScalarEntry t))
-- | Declare a scalar under a fresh name based on the given string, and
-- return that name.
dPrim :: String -> PrimType -> ImpM lore op VName
dPrim name t = do
  fresh <- newVName name
  fresh <$ dPrim_ fresh t
-- | Declare a fresh scalar whose type is that of the given expression,
-- assign the expression to it, and return its name.
dPrimV :: String -> Imp.Exp -> ImpM lore op VName
dPrimV name e =
  dPrim name (primExpType e) >>= \fresh ->
    (fresh <-- e) >> return fresh
-- | Turn memory information about a variable into a symbol table
-- entry, recording the given expression (if any) in it.
memBoundToVarEntry :: Maybe (Exp lore) -> MemBound NoUniqueness
                   -> ImpM lore op (VarEntry lore)
memBoundToVarEntry e bound =
  case bound of
    MemPrim pt ->
      return $ ScalarVar e (ScalarEntry pt)
    MemMem size space -> do
      size' <- subExpToDimSize size
      return $ MemVar e (MemEntry size' space)
    MemArray pt shape _ (ArrayIn mem ixfun) -> do
      dims <- mapM subExpToDimSize (shapeDims shape)
      let location = MemLocation mem dims (fmap compilePrimExp ixfun)
      return $ ArrayVar e (ArrayEntry location pt)
-- | Declare a variable from its name information.  Scalars and memory
-- blocks get a declaration emitted; arrays do not.  In every case the
-- variable is entered into the symbol table.
dInfo :: Maybe (Exp lore) -> VName -> NameInfo ExplicitMemory
      -> ImpM lore op ()
dInfo e name info = do
  entry <- memBoundToVarEntry e bound
  case entry of
    ScalarVar _ scalar ->
      emit $ Imp.DeclareScalar name (entryScalarType scalar)
    MemVar _ mem ->
      emit $ Imp.DeclareMem name (entryMemSpace mem)
    ArrayVar{} ->
      return ()
  addVar name entry
  where bound =
          case info of
            LetInfo attr -> attr
            FParamInfo attr -> noUniquenessReturns attr
            LParamInfo attr -> attr
            IndexInfo it -> MemPrim (IntType it)
-- | Declare every name in a scope with 'dInfo', in the order given by
-- 'M.toList'.
dScope :: Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore op ()
dScope e scope =
  forM_ (M.toList scope) $ \(name, info) -> dInfo e name info
-- | Declare several scopes, each with its own optional expression.
dScopes :: [(Maybe (Exp lore), Scope ExplicitMemory)] -> ImpM lore op ()
dScopes scopes = forM_ scopes $ \(e, scope) -> dScope e scope
-- | Enter an array with the given element type, shape and memory
-- binding into the symbol table.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore op ()
dArray name bt shape membind =
  memBoundToVarEntry Nothing (MemArray bt shape NoUniqueness membind)
    >>= addVar name
-- | Run an action with the volatility in the environment set to
-- 'Imp.Volatile'.
everythingVolatile :: ImpM lore op a -> ImpM lore op a
everythingVolatile m = local makeVolatile m
  where makeVolatile env = env { envVolatility = Imp.Volatile }
-- | The variables a function call assigns its results to: the names of
-- the scalar and memory destinations, in order.  Array destinations
-- contribute nothing.
funcallTargets :: Destination -> ImpM lore op [VName]
funcallTargets (Destination _ dests) = return (concatMap targetOf dests)
  where targetOf dest =
          case dest of
            ScalarDestination name -> [name]
            MemoryDestination name -> [name]
            ArrayDestination _ -> []
-- | Convert a subexpression to a size.  Variables and 32- or 64-bit
-- integer constants are accepted; any other constant is a compiler bug.
subExpToDimSize :: SubExp -> ImpM lore op Imp.DimSize
subExpToDimSize se =
  case se of
    Var v ->
      return (Imp.VarSize v)
    Constant (IntValue (Int64Value k)) ->
      return (Imp.ConstSize (fromIntegral k))
    Constant (IntValue (Int32Value k)) ->
      return (Imp.ConstSize (fromIntegral k))
    Constant{} ->
      compilerBugS "Size subexp is not an int32 or int64 constant."
-- | Copy a subexpression to the given variable, with no indices on
-- either side.
compileSubExpTo :: VName -> SubExp -> ImpM lore op ()
compileSubExpTo dest se =
  copyDWIM dest no_indices se no_indices
  where no_indices = []
-- | Compile a subexpression of primitive type to an expression.  A
-- variable that is not of primitive type is a compiler bug.
compileSubExp :: SubExp -> ImpM lore op Imp.Exp
compileSubExp se =
  case se of
    Constant v ->
      pure (Imp.ValueExp v)
    Var v ->
      lookupType v >>= \case
        Prim pt -> pure (Imp.var v pt)
        _ -> compilerBugS $ "compileSubExp: SubExp is not a primitive type: " ++ pretty v
-- | Compile a subexpression without consulting the symbol table.  The
-- given type is used for a variable and ignored for a constant.
compileSubExpOfType :: PrimType -> SubExp -> Imp.Exp
compileSubExpOfType t se =
  case se of
    Constant v -> Imp.ValueExp v
    Var v -> Imp.var v t
-- | Turn a primitive expression over variable names into an 'Imp.Exp'
-- by wrapping every leaf in 'Imp.ScalarVar'.
compilePrimExp :: PrimExp VName -> Imp.Exp
compilePrimExp e = Imp.ScalarVar <$> e
-- | An @int32@-typed expression reading the given variable, for use as
-- an index.
varIndex :: VName -> Imp.Exp
varIndex = flip LeafExp int32 . Imp.ScalarVar
-- | An index expression denoting the given constant.
constIndex :: Int -> Imp.Exp
constIndex i = fromIntegral i
-- | Enter a variable into the symbol table, replacing any entry that
-- already exists under the same name.
addVar :: VName -> VarEntry lore -> ImpM lore op ()
addVar name entry = modify extend
  where extend st = st { stateVTable = M.insert name entry (stateVTable st) }
-- | The current symbol table.
getVTable :: ImpM lore op (VTable lore)
getVTable = stateVTable <$> get
-- | Replace the symbol table.
putVTable :: VTable lore -> ImpM lore op ()
putVTable vtable = modify replace
  where replace st = st { stateVTable = vtable }
-- | Run an action with a modified symbol table.  Afterwards the symbol
-- table from before the call is put back, so changes the action makes
-- to it are discarded.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore op a -> ImpM lore op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  m <* putVTable saved
-- | Look up a variable in the symbol table.  An unknown variable is a
-- compiler bug.
lookupVar :: VName -> ImpM lore op (VarEntry lore)
lookupVar name =
  gets (M.lookup name . stateVTable)
    >>= maybe (compilerBugS $ "Unknown variable: " ++ pretty name) return
-- | Look up an array in the symbol table.  A variable that is not an
-- array is a compiler bug.
lookupArray :: VName -> ImpM lore op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> return entry
    _ -> compilerBugS $ "ImpGen.lookupArray: not an array: " ++ pretty name
-- | Look up a memory block in the symbol table.  A variable that is not
-- a memory block is a compiler bug.
lookupMemory :: VName -> ImpM lore op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> return entry
    _ -> compilerBugS $ "Unknown memory block: " ++ pretty name
-- | Compute the destination corresponding to a pattern, with one
-- 'ValueDestination' per pattern element.  The tag is the 'baseTag' of
-- the first pattern name, if there is one.  An array whose memory block
-- is bound in the context of the same pattern gets no location.
destinationFromPattern :: ExplicitMemorish lore => Pattern lore -> ImpM lore op Destination
destinationFromPattern pat = do
  dests <- mapM destOf (patternElements pat)
  return $ Destination tag dests
  where tag = baseTag <$> maybeHead (patternNames pat)
        ctx_names = patternContextNames pat
        destOf pe = do
          let name = patElemName pe
          entry <- lookupVar name
          return $ case entry of
            ArrayVar _ (ArrayEntry loc@(MemLocation mem _ _) _)
              | mem `elem` ctx_names -> ArrayDestination Nothing
              | otherwise -> ArrayDestination (Just loc)
            MemVar{} -> MemoryDestination name
            ScalarVar{} -> ScalarDestination name
-- | Index the named array with the given indices; see
-- 'fullyIndexArray''.
fullyIndexArray :: VName -> [Imp.Exp]
                -> ImpM lore op (VName, Imp.Space, Count Bytes)
fullyIndexArray name indices = do
  ArrayEntry location elem_type <- lookupArray name
  fullyIndexArray' location indices elem_type
-- | Apply the index function of a location to the given indices, for
-- elements of the given type.  Returns the memory block, its space, and
-- the resulting offset as a byte count.
fullyIndexArray' :: MemLocation -> [Imp.Exp] -> PrimType
                 -> ImpM lore op (VName, Imp.Space, Count Bytes)
fullyIndexArray' (MemLocation mem _ ixfun) indices bt =
  withSpace . entryMemSpace <$> lookupMemory mem
  where offset = bytes $ IxFun.index ixfun indices $ primByteSize bt
        withSpace space = (mem, space, offset)
-- | Slice a memory location.  The resulting shape keeps the dimensions
-- that are sliced and drops those that are fixed; the index function is
-- sliced with 'IxFun.slice'.
sliceArray :: MemLocation
           -> Slice Imp.Exp
           -> MemLocation
sliceArray (MemLocation mem shape ixfun) slice =
  MemLocation mem remaining (IxFun.slice ixfun slice)
  where remaining = [ d | (d, DimSlice{}) <- zip shape slice ]
-- | Apply 'IxFun.offsetIndex' with the given offset to the index
-- function of a location; memory block and shape are kept.
offsetArray :: MemLocation
            -> Imp.Exp
            -> MemLocation
offsetArray loc offset =
  loc { memLocationIxFun = IxFun.offsetIndex (memLocationIxFun loc) offset }
-- | Apply 'IxFun.strideIndex' with the given stride to the index
-- function of a location; memory block and shape are kept.
strideArray :: MemLocation
            -> Imp.Exp
            -> MemLocation
strideArray loc stride =
  loc { memLocationIxFun = IxFun.strideIndex (memLocationIxFun loc) stride }
-- | Is the type of this subexpression something other than an array?
subExpNotArray :: SubExp -> ImpM lore op Bool
subExpNotArray se = notArray <$> subExpType se
  where notArray Array{} = False
        notArray _ = True
-- | The size of the outermost dimension of an array, as given by
-- 'arrayDimSize' for dimension 0.
arrayOuterSize :: ArrayEntry -> Count Elements
arrayOuterSize entry = arrayDimSize 0 entry
-- | The size of dimension @i@ of an array.  If the array has no such
-- dimension, the result is the empty product.
arrayDimSize :: Int -> ArrayEntry -> Count Elements
arrayDimSize i entry =
  product $ map Imp.dimSizeToExp $ take 1 $ drop i $ entryArrayShape entry
-- | Copy between memory locations using the copy compiler found in the
-- environment.
copy :: CopyCompiler lore op
copy bt dest src n =
  asks envCopyCompiler >>= \compiler -> compiler bt dest src n
-- | The default copy compiler.  When both index functions pass the
-- 'ixFunMatchesInnerShape' test against their shapes and both are
-- recognised by 'IxFun.linearWithOffset', a single 'Imp.Copy' of
-- @n * row_size@ elements is emitted between the two offsets.
-- Otherwise the copy is delegated to 'copyElementWise'.
defaultCopy :: CopyCompiler lore op
defaultCopy bt dest src n
  | ixFunMatchesInnerShape
      (Shape $ map dimSizeToExp destshape) destIxFun,
    ixFunMatchesInnerShape
      (Shape $ map dimSizeToExp srcshape) srcIxFun,
    Just destoffset <-
      IxFun.linearWithOffset destIxFun bt_size,
    Just srcoffset <-
      IxFun.linearWithOffset srcIxFun bt_size = do
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      emit $ Imp.Copy
        destmem (bytes destoffset) destspace
        srcmem (bytes srcoffset) srcspace $
        (n * row_size) `withElemType` bt
  | otherwise =
      copyElementWise bt dest src n
  where bt_size = primByteSize bt
        -- Number of elements in one row: the product of all source
        -- dimensions except the outermost.
        row_size = product $ map Imp.dimSizeToExp $ drop 1 srcshape
        MemLocation destmem destshape destIxFun = dest
        MemLocation srcmem srcshape srcIxFun = src
-- | Copy one element at a time.  A loop nest is generated with one
-- 'Int32' loop variable per dimension of the destination index
-- function.  The outermost bound is @n@ and the remaining bounds are
-- the inner dimensions of the source shape.  The innermost statement
-- writes the source element to the destination.
copyElementWise :: CopyCompiler lore op
copyElementWise bt (MemLocation destmem _ destIxFun) (MemLocation srcmem srcshape srcIxFun) n = do
  is <- replicateM (IxFun.rank destIxFun) (newVName "i")
  srcspace <- entryMemSpace <$> lookupMemory srcmem
  destspace <- entryMemSpace <$> lookupMemory destmem
  vol <- asks envVolatility
  let ivars = map varIndex is
      elem_size = primByteSize bt
      dest_offset = bytes $ IxFun.index destIxFun ivars elem_size
      src_offset = bytes $ IxFun.index srcIxFun ivars elem_size
      write_elem =
        Imp.Write destmem dest_offset bt destspace vol $
        Imp.index srcmem src_offset bt srcspace vol
      bounds = map innerExp $ n : drop 1 (map Imp.dimSizeToExp srcshape)
      loops = zipWith (\i bound -> Imp.For i Int32 bound) is bounds
  -- The first loop ends up outermost.
  emit $ foldr ($) write_elem loops
-- | Produce (without emitting) the code that copies from an indexed
-- source location to an indexed destination location.
copyArrayDWIM :: PrimType
              -> MemLocation -> [Imp.Exp]
              -> MemLocation -> [Imp.Exp]
              -> ImpM lore op (Imp.Code op)
copyArrayDWIM bt
  destlocation@(MemLocation _ destshape dest_ixfun) destis
  srclocation@(MemLocation _ srcshape src_ixfun) srcis
  -- Both sides have one index per dimension: a single element is read
  -- and written.
  | length srcis == length srcshape, length destis == length destshape = do
      (targetmem, destspace, targetoffset) <-
        fullyIndexArray' destlocation destis bt
      (srcmem, srcspace, srcoffset) <-
        fullyIndexArray' srclocation srcis bt
      vol <- asks envVolatility
      return $ Imp.Write targetmem targetoffset bt destspace vol $
        Imp.index srcmem srcoffset bt srcspace vol
  -- Otherwise both locations are sliced by the given indices and the
  -- slices are copied with 'copy', unless they are the same location,
  -- in which case no code is needed.
  | otherwise = do
      let destlocation' =
            sliceArray destlocation $
            fullSliceNum (IxFun.shape dest_ixfun) $ map DimFix destis
          srclocation' =
            sliceArray srclocation $
            fullSliceNum (IxFun.shape src_ixfun) $ map DimFix srcis
      if destlocation' == srclocation'
        then return mempty
        else collect $ copy bt destlocation' srclocation' $
             product $ map Imp.dimSizeToExp $
             take 1 $ drop (length srcis) srcshape
-- | Copy from a source 'SubExp' (indexed by the last argument) into a
-- 'ValueDestination' (indexed by the second argument), choosing the
-- kind of statement to emit from the shapes of both sides: a scalar
-- assignment, a memory block assignment, a single element write, or an
-- array copy via 'copyArrayDWIM'.
--
-- Combinations that make no sense (indexing a constant or a scalar,
-- writing a non-memory value to a memory destination, reading a memory
-- block as a value, ...) are reported through 'compilerBugS'.
copyDWIMDest :: ValueDestination -> [Imp.Exp] -> SubExp -> [Imp.Exp]
             -> ImpM lore op ()

-- A constant can never be indexed.
copyDWIMDest _ _ (Constant v) (_:_) =
  compilerBugS $
  unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]

-- An un-indexed constant source.
copyDWIMDest pat dest_is (Constant v) [] =
  case pat of
    ScalarDestination name ->
      emit $ Imp.SetScalar name value

    MemoryDestination{} ->
      compilerBugS $
      unwords ["copyDWIMDest: constant source", pretty v, "cannot be written to memory destination."]

    ArrayDestination (Just dest_loc) -> do
      (dest_mem, dest_space, dest_offset) <-
        fullyIndexArray' dest_loc dest_is elem_type
      volatility <- asks envVolatility
      emit $ Imp.Write dest_mem dest_offset elem_type dest_space volatility value

    ArrayDestination Nothing ->
      compilerBugS "copyDWIMDest: ArrayDestination Nothing"
  where elem_type = primValueType v
        value = Imp.ValueExp v

-- A variable source; what to do depends on what the variable denotes.
-- The alternatives are tried in order, so the error cases guarded on
-- non-empty index lists must stay ahead of the corresponding regular
-- cases.
copyDWIMDest dest dest_is (Var src) src_is = do
  src_entry <- lookupVar src
  case (dest, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry _ space)) ->
      emit $ Imp.SetMem mem src space

    (MemoryDestination{}, _) ->
      bug ["copyDWIMDest: cannot write", pretty src, "to memory destination."]

    (_, MemVar{}) ->
      bug ["copyDWIMDest: source", pretty src, "is a memory block."]

    (_, ScalarVar _ (ScalarEntry _))
      | not (null src_is) ->
          bug ["copyDWIMDest: prim-typed source", pretty src, "with nonzero indices."]

    (ScalarDestination name, _)
      | not (null dest_is) ->
          bug ["copyDWIMDest: prim-typed target", pretty name, "with nonzero indices."]

    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt

    (ScalarDestination name, ArrayVar _ arr) -> do
      let elem_type = entryArrayElemType arr
      (mem, space, offset) <-
        fullyIndexArray' (entryArrayLocation arr) src_is elem_type
      volatility <- asks envVolatility
      emit $ Imp.SetScalar name $ Imp.index mem offset elem_type space volatility

    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let elem_type = entryArrayElemType src_arr
          src_loc = entryArrayLocation src_arr
      code <- copyArrayDWIM elem_type dest_loc dest_is src_loc src_is
      emit code

    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt)) -> do
      (dest_mem, dest_space, dest_offset) <-
        fullyIndexArray' dest_loc dest_is bt
      volatility <- asks envVolatility
      emit $ Imp.Write dest_mem dest_offset bt dest_space volatility (Imp.var src bt)

    -- No destination location is known, so no code is emitted.
    (ArrayDestination Nothing, _) ->
      return ()
  where bug parts = compilerBugS $ unwords parts
-- | Like 'copyDWIMDest', but the destination is given by name.  The
-- name is looked up and turned into the matching 'ValueDestination'
-- (scalar, array with a known location, or memory block) before
-- delegating to 'copyDWIMDest'.
copyDWIM :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp]
         -> ImpM lore op ()
copyDWIM dest dest_is src src_is = do
  dest_target <- destinationOf <$> lookupVar dest
  copyDWIMDest dest_target dest_is src src_is
  where destinationOf (ScalarVar _ _) =
          ScalarDestination dest
        destinationOf (ArrayVar _ (ArrayEntry (MemLocation mem shape ixfun) _)) =
          ArrayDestination $ Just $ MemLocation mem shape ixfun
        destinationOf (MemVar _ _) =
          MemoryDestination dest
-- | Compile an allocation statement.  The pattern must bind exactly
-- one value and no context; anything else is a compiler bug.  The
-- allocation itself is emitted only if the requested space is not
-- listed among the environment's fake memory spaces.
compileAlloc :: ExplicitMemorish lore =>
                Pattern lore -> SubExp -> Space
             -> ImpM lore op ()
compileAlloc (Pattern [] [mem]) e space = do
  size <- Imp.bytes <$> compileSubExp e
  is_fake <- asks $ \env -> space `elem` envFakeMemory env
  unless is_fake $
    emit $ Imp.Allocate (patElemName mem) size space
compileAlloc pat _ _ =
  compilerBugS $ "compileAlloc: Invalid pattern: " ++ pretty pat
-- | Convert a dimension size to the corresponding 'SubExp': a constant
-- size becomes a constant, a variable size becomes a variable.
dimSizeToSubExp :: Imp.Size -> SubExp
dimSizeToSubExp size =
  case size of
    Imp.ConstSize k -> constant k
    Imp.VarSize name -> Var name
-- | Convert a dimension size to an expression, by way of
-- 'dimSizeToSubExp' and an @int32@-typed primitive expression.
dimSizeToExp :: Imp.Size -> Imp.Exp
dimSizeToExp size =
  compilePrimExp $ primExpFromSubExp int32 $ dimSizeToSubExp size
-- | Emit a for-loop.  The loop variable is registered with
-- 'addLoopVar' before the body is collected, and the collected body
-- becomes the body of the emitted loop.
sFor :: VName -> IntType -> Imp.Exp -> ImpM lore op () -> ImpM lore op ()
sFor i it bound body = do
  addLoopVar i it
  emit . Imp.For i it bound =<< collect body
-- | Emit a while-loop with the given condition, whose body is the
-- code collected from the given action.
sWhile :: Imp.Exp -> ImpM lore op () -> ImpM lore op ()
sWhile cond body =
  emit . Imp.While cond =<< collect body
-- | Emit the code collected from the given action, wrapped in a
-- comment carrying the given text.
sComment :: String -> ImpM lore op () -> ImpM lore op ()
sComment s code =
  collect code >>= emit . Imp.Comment s
-- | Emit a conditional.  The true branch is collected first, then the
-- false branch, and both become the branches of the emitted
-- statement.
sIf :: Imp.Exp -> ImpM lore op () -> ImpM lore op () -> ImpM lore op ()
sIf cond tbranch fbranch =
  emit =<< Imp.If cond <$> collect tbranch <*> collect fbranch
-- | A conditional whose false branch emits nothing.
sWhen :: Imp.Exp -> ImpM lore op () -> ImpM lore op ()
sWhen cond tbranch = sIf cond tbranch nothing
  where nothing = return ()
-- | A conditional whose true branch emits nothing; the given action
-- becomes the false branch.
sUnless :: Imp.Exp -> ImpM lore op () -> ImpM lore op ()
sUnless cond fbranch = sIf cond nothing fbranch
  where nothing = return ()
-- | Emit a single operation as a statement.
sOp :: op -> ImpM lore op ()
sOp op = emit $ Imp.Op op
-- | Declare a fresh memory block (named after the given string) in the
-- given space, allocate it unless the space is listed among the
-- environment's fake memory spaces, register it in the symbol table,
-- and return its name.
--
-- The size recorded for the block is the size expression itself when
-- that is a scalar variable or a 64-bit integer constant; any other
-- expression is first stored in a fresh @local_buf_size@ variable.
sAlloc :: String -> Count Bytes -> Space -> ImpM lore op VName
sAlloc name size space = do
  mem <- newVName name
  mem_size <- sizeOfBlock $ Imp.innerExp size
  emit $ Imp.DeclareMem mem space
  is_fake <- asks $ elem space . envFakeMemory
  unless is_fake $
    emit $ Imp.Allocate mem size space
  addVar mem $ MemVar Nothing $ MemEntry mem_size space
  return mem
  where sizeOfBlock (Imp.LeafExp (Imp.ScalarVar v) _) =
          return $ Imp.VarSize v
        sizeOfBlock (Imp.ValueExp (IntValue (Int64Value k))) =
          return $ Imp.ConstSize k
        sizeOfBlock e = do
          size_var <- dPrim "local_buf_size" int32
          size_var <-- e
          return $ Imp.VarSize size_var
-- | Declare an array under a fresh name derived from the given string,
-- with the given element type, shape and memory binding.  Returns the
-- fresh name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore op VName
sArray name bt shape membind =
  newVName name >>= \arr ->
    arr <$ dArray arr bt shape membind
-- | Allocate a memory block (via 'sAlloc', named with a @_mem@
-- suffix) sized as the element size times the product of the
-- dimensions, then declare an array in that block with an
-- 'IxFun.iota' index function over the same dimensions.  Returns the
-- name of the array.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore op VName
sAllocArray name pt shape space = do
  mem <- sAlloc (name ++ "_mem") (Imp.bytes total_size) space
  sArray name pt shape $ ArrayIn mem $ IxFun.iota dim_exps
  where dims = shapeDims shape
        elem_size = Imp.LeafExp (Imp.SizeOf pt) int32
        num_elems = product $ map (compileSubExpOfType int32) dims
        total_size = elem_size * num_elems
        dim_exps = map (primExpFromSubExp int32) dims
-- | Declare a one-dimensional array initialised with the given values.
-- A fresh memory block (named with a @_mem@ suffix) is declared via
-- 'Imp.DeclareArray' and registered with a constant size equal to the
-- number of values times the element byte size; the array is then
-- declared in that block with an 'IxFun.iota' index function.
-- Returns the name of the array.
sStaticArray :: String -> Space -> PrimType -> [PrimValue] -> ImpM lore op VName
sStaticArray name space pt vs = do
  mem <- newVName $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem $ MemVar Nothing $ MemEntry mem_size space
  sArray name pt (Shape dims) $
    ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) dims
  where num_elems = length vs
        dims = [constant num_elems]
        mem_size = Imp.ConstSize $ fromIntegral num_elems * primByteSize pt
-- | Emit a write of the given value to the element of the named array
-- at the given indices.  The element type of the write is the type of
-- the value expression.
sWrite :: VName -> [Imp.Exp] -> PrimExp Imp.ExpLeaf -> ImpM lore op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  volatility <- asks envVolatility
  let elem_type = primExpType v
  emit $ Imp.Write mem offset elem_type space volatility v
-- | Emit an assignment of the expression on the right to the scalar
-- variable on the left.
(<--) :: VName -> Imp.Exp -> ImpM lore op ()
(<--) target value = emit (Imp.SetScalar target value)
infixl 3 <--