1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/DialectConversion.h"
20
21 using namespace mlir;
22
23 //===----------------------------------------------------------------------===//
24 // Context
25 //===----------------------------------------------------------------------===//
26
27 namespace mlir {
28 struct ScfToSPIRVContextImpl {
29 // Map between the spirv region control flow operation (spv.mlir.loop or
30 // spv.mlir.selection) to the VariableOp created to store the region results.
31 // The order of the VariableOp matches the order of the results.
32 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
33 };
34 } // namespace mlir
35
36 /// We use ScfToSPIRVContext to store information about the lowering of the scf
37 /// region that need to be used later on. When we lower scf.for/scf.if we create
38 /// VariableOp to store the results. We need to keep track of the VariableOp
39 /// created as we need to insert stores into them when lowering Yield. Those
40 /// StoreOp cannot be created earlier as they may use a different type than
41 /// yield operands.
ScfToSPIRVContext()42 ScfToSPIRVContext::ScfToSPIRVContext() {
43 impl = std::make_unique<ScfToSPIRVContextImpl>();
44 }
45
46 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
47
48 //===----------------------------------------------------------------------===//
49 // Pattern Declarations
50 //===----------------------------------------------------------------------===//
51
52 namespace {
53 /// Common class for all vector to GPU patterns.
54 template <typename OpTy>
55 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
56 public:
SCFToSPIRVPattern(MLIRContext * context,SPIRVTypeConverter & converter,ScfToSPIRVContextImpl * scfToSPIRVContext)57 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
58 ScfToSPIRVContextImpl *scfToSPIRVContext)
59 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
60 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
61
62 protected:
63 ScfToSPIRVContextImpl *scfToSPIRVContext;
64 // FIXME: We explicitly keep a reference of the type converter here instead of
65 // passing it to OpConversionPattern during construction. This effectively
66 // bypasses the conversion framework's automation on type conversion. This is
67 // needed right now because the conversion framework will unconditionally
68 // legalize all types used by SCF ops upon discovering them, for example, the
69 // types of loop carried values. We use SPIR-V variables for those loop
70 // carried values. Depending on the available capabilities, the SPIR-V
71 // variable can be different, for example, cooperative matrix or normal
72 // variable. We'd like to detach the conversion of the loop carried values
73 // from the SCF ops (which is mainly a region). So we need to "mark" types
74 // used by SCF ops as legal, if to use the conversion framework for type
75 // conversion. There isn't a straightforward way to do that yet, as when
76 // converting types, ops aren't taken into consideration. Therefore, we just
77 // bypass the framework's type conversion for now.
78 SPIRVTypeConverter &typeConverter;
79 };
80
81 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
82 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
83 public:
84 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
85
86 LogicalResult
87 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override;
89 };
90
91 /// Pattern to convert a scf::IfOp within kernel functions into
92 /// spirv::SelectionOp.
93 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
94 public:
95 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
96
97 LogicalResult
98 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
99 ConversionPatternRewriter &rewriter) const override;
100 };
101
102 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
103 public:
104 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
105
106 LogicalResult
107 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override;
109 };
110
111 class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
112 public:
113 using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
114
115 LogicalResult
116 matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor,
117 ConversionPatternRewriter &rewriter) const override;
118 };
119 } // namespace
120
121 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
122 /// We create VariableOp to handle the results value of the control flow region.
123 /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after
124 /// the loop we load the value from the allocation and use it as the SCF op
125 /// result.
126 template <typename ScfOp, typename OpTy>
replaceSCFOutputValue(ScfOp scfOp,OpTy newOp,ConversionPatternRewriter & rewriter,ScfToSPIRVContextImpl * scfToSPIRVContext,ArrayRef<Type> returnTypes)127 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
128 ConversionPatternRewriter &rewriter,
129 ScfToSPIRVContextImpl *scfToSPIRVContext,
130 ArrayRef<Type> returnTypes) {
131
132 Location loc = scfOp.getLoc();
133 auto &allocas = scfToSPIRVContext->outputVars[newOp];
134 // Clearing the allocas is necessary in case a dialect conversion path failed
135 // previously, and this is the second attempt of this conversion.
136 allocas.clear();
137 SmallVector<Value, 8> resultValue;
138 for (Type convertedType : returnTypes) {
139 auto pointerType =
140 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
141 rewriter.setInsertionPoint(newOp);
142 auto alloc = rewriter.create<spirv::VariableOp>(
143 loc, pointerType, spirv::StorageClass::Function,
144 /*initializer=*/nullptr);
145 allocas.push_back(alloc);
146 rewriter.setInsertionPointAfter(newOp);
147 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
148 resultValue.push_back(loadResult);
149 }
150 rewriter.replaceOp(scfOp, resultValue);
151 }
152
getBlockIt(Region & region,unsigned index)153 static Region::iterator getBlockIt(Region ®ion, unsigned index) {
154 return std::next(region.begin(), index);
155 }
156
157 //===----------------------------------------------------------------------===//
158 // scf::ForOp
159 //===----------------------------------------------------------------------===//
160
161 LogicalResult
matchAndRewrite(scf::ForOp forOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const162 ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
163 ConversionPatternRewriter &rewriter) const {
164 // scf::ForOp can be lowered to the structured control flow represented by
165 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
166 // latch and the merge block the exit block. The resulting spirv::LoopOp has a
167 // single back edge from the continue to header block, and a single exit from
168 // header to merge.
169 auto loc = forOp.getLoc();
170 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
171 loopOp.addEntryAndMergeBlock();
172
173 OpBuilder::InsertionGuard guard(rewriter);
174 // Create the block for the header.
175 auto *header = new Block();
176 // Insert the header.
177 loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
178
179 // Create the new induction variable to use.
180 Value adapLowerBound = adaptor.getLowerBound();
181 BlockArgument newIndVar =
182 header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
183 for (Value arg : adaptor.getInitArgs())
184 header->addArgument(arg.getType(), arg.getLoc());
185 Block *body = forOp.getBody();
186
187 // Apply signature conversion to the body of the forOp. It has a single block,
188 // with argument which is the induction variable. That has to be replaced with
189 // the new induction variable.
190 TypeConverter::SignatureConversion signatureConverter(
191 body->getNumArguments());
192 signatureConverter.remapInput(0, newIndVar);
193 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
194 signatureConverter.remapInput(i, header->getArgument(i));
195 body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
196 signatureConverter);
197
198 // Move the blocks from the forOp into the loopOp. This is the body of the
199 // loopOp.
200 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
201 getBlockIt(loopOp.body(), 2));
202
203 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
204 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
205 // Branch into it from the entry.
206 rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
207 rewriter.create<spirv::BranchOp>(loc, header, args);
208
209 // Generate the rest of the loop header.
210 rewriter.setInsertionPointToEnd(header);
211 auto *mergeBlock = loopOp.getMergeBlock();
212 auto cmpOp = rewriter.create<spirv::SLessThanOp>(
213 loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
214
215 rewriter.create<spirv::BranchConditionalOp>(
216 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
217
218 // Generate instructions to increment the step of the induction variable and
219 // branch to the header.
220 Block *continueBlock = loopOp.getContinueBlock();
221 rewriter.setInsertionPointToEnd(continueBlock);
222
223 // Add the step to the induction variable and branch to the header.
224 Value updatedIndVar = rewriter.create<spirv::IAddOp>(
225 loc, newIndVar.getType(), newIndVar, adaptor.getStep());
226 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
227
228 // Infer the return types from the init operands. Vector type may get
229 // converted to CooperativeMatrix or to Vector type, to avoid having complex
230 // extra logic to figure out the right type we just infer it from the Init
231 // operands.
232 SmallVector<Type, 8> initTypes;
233 for (auto arg : adaptor.getInitArgs())
234 initTypes.push_back(arg.getType());
235 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
236 return success();
237 }
238
239 //===----------------------------------------------------------------------===//
240 // scf::IfOp
241 //===----------------------------------------------------------------------===//
242
243 LogicalResult
matchAndRewrite(scf::IfOp ifOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const244 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const {
246 // When lowering `scf::IfOp` we explicitly create a selection header block
247 // before the control flow diverges and a merge block where control flow
248 // subsequently converges.
249 auto loc = ifOp.getLoc();
250
251 // Create `spv.selection` operation, selection header block and merge block.
252 auto selectionOp =
253 rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
254 auto *mergeBlock =
255 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
256 rewriter.create<spirv::MergeOp>(loc);
257
258 OpBuilder::InsertionGuard guard(rewriter);
259 auto *selectionHeaderBlock =
260 rewriter.createBlock(&selectionOp.body().front());
261
262 // Inline `then` region before the merge block and branch to it.
263 auto &thenRegion = ifOp.getThenRegion();
264 auto *thenBlock = &thenRegion.front();
265 rewriter.setInsertionPointToEnd(&thenRegion.back());
266 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
267 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
268
269 auto *elseBlock = mergeBlock;
270 // If `else` region is not empty, inline that region before the merge block
271 // and branch to it.
272 if (!ifOp.getElseRegion().empty()) {
273 auto &elseRegion = ifOp.getElseRegion();
274 elseBlock = &elseRegion.front();
275 rewriter.setInsertionPointToEnd(&elseRegion.back());
276 rewriter.create<spirv::BranchOp>(loc, mergeBlock);
277 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
278 }
279
280 // Create a `spv.BranchConditional` operation for selection header block.
281 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
282 rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
283 thenBlock, ArrayRef<Value>(),
284 elseBlock, ArrayRef<Value>());
285
286 SmallVector<Type, 8> returnTypes;
287 for (auto result : ifOp.getResults()) {
288 auto convertedType = typeConverter.convertType(result.getType());
289 returnTypes.push_back(convertedType);
290 }
291 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
292 returnTypes);
293 return success();
294 }
295
296 //===----------------------------------------------------------------------===//
297 // scf::YieldOp
298 //===----------------------------------------------------------------------===//
299
300 /// Yield is lowered to stores to the VariableOp created during lowering of the
301 /// parent region. For loops we also need to update the branch looping back to
302 /// the header with the loop carried values.
matchAndRewrite(scf::YieldOp terminatorOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const303 LogicalResult TerminatorOpConversion::matchAndRewrite(
304 scf::YieldOp terminatorOp, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter) const {
306 ValueRange operands = adaptor.getOperands();
307
308 // If the region is return values, store each value into the associated
309 // VariableOp created during lowering of the parent region.
310 if (!operands.empty()) {
311 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
312 if (allocas.size() != operands.size())
313 return failure();
314
315 auto loc = terminatorOp.getLoc();
316 for (unsigned i = 0, e = operands.size(); i < e; i++)
317 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
318 if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
319 // For loops we also need to update the branch jumping back to the header.
320 auto br =
321 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
322 SmallVector<Value, 8> args(br.getBlockArguments());
323 args.append(operands.begin(), operands.end());
324 rewriter.setInsertionPoint(br);
325 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
326 args);
327 rewriter.eraseOp(br);
328 }
329 }
330 rewriter.eraseOp(terminatorOp);
331 return success();
332 }
333
334 //===----------------------------------------------------------------------===//
335 // scf::WhileOp
336 //===----------------------------------------------------------------------===//
337
338 LogicalResult
matchAndRewrite(scf::WhileOp whileOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const339 WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const {
341 auto loc = whileOp.getLoc();
342 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
343 loopOp.addEntryAndMergeBlock();
344
345 OpBuilder::InsertionGuard guard(rewriter);
346
347 Region &beforeRegion = whileOp.getBefore();
348 Region &afterRegion = whileOp.getAfter();
349
350 Block &entryBlock = *loopOp.getEntryBlock();
351 Block &beforeBlock = beforeRegion.front();
352 Block &afterBlock = afterRegion.front();
353 Block &mergeBlock = *loopOp.getMergeBlock();
354
355 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
356 SmallVector<Value> condArgs;
357 if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
358 return failure();
359
360 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
361 if (!conditionVal)
362 return failure();
363
364 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
365 SmallVector<Value> yieldArgs;
366 if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
367 return failure();
368
369 // Move the while before block as the initial loop header block.
370 rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
371 getBlockIt(loopOp.body(), 1));
372
373 // Move the while after block as the initial loop body block.
374 rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
375 getBlockIt(loopOp.body(), 2));
376
377 // Jump from the loop entry block to the loop header block.
378 rewriter.setInsertionPointToEnd(&entryBlock);
379 rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
380
381 auto condLoc = cond.getLoc();
382
383 SmallVector<Value> resultValues(condArgs.size());
384
385 // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
386 // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
387 // variables. But for the scf.while op, the scf.yield op yields a value for
388 // the before region, which may not matching the whole op's result. Instead,
389 // the scf.condition op returns values matching the whole op's results. So we
390 // need to create/load/store variables according to that.
391 for (const auto &it : llvm::enumerate(condArgs)) {
392 auto res = it.value();
393 auto i = it.index();
394 auto pointerType =
395 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
396
397 // Create local variables before the scf.while op.
398 rewriter.setInsertionPoint(loopOp);
399 auto alloc = rewriter.create<spirv::VariableOp>(
400 condLoc, pointerType, spirv::StorageClass::Function,
401 /*initializer=*/nullptr);
402
403 // Load the final result values after the scf.while op.
404 rewriter.setInsertionPointAfter(loopOp);
405 auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
406 resultValues[i] = loadResult;
407
408 // Store the current iteration's result value.
409 rewriter.setInsertionPointToEnd(&beforeBlock);
410 rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
411 }
412
413 rewriter.setInsertionPointToEnd(&beforeBlock);
414 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
415 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None);
416
417 // Convert the scf.yield op to a branch back to the header block.
418 rewriter.setInsertionPointToEnd(&afterBlock);
419 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
420
421 rewriter.replaceOp(whileOp, resultValues);
422 return success();
423 }
424
425 //===----------------------------------------------------------------------===//
426 // Hooks
427 //===----------------------------------------------------------------------===//
428
populateSCFToSPIRVPatterns(SPIRVTypeConverter & typeConverter,ScfToSPIRVContext & scfToSPIRVContext,RewritePatternSet & patterns)429 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
430 ScfToSPIRVContext &scfToSPIRVContext,
431 RewritePatternSet &patterns) {
432 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
433 WhileOpConversion>(patterns.getContext(), typeConverter,
434 scfToSPIRVContext.getImpl());
435 }
436