1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Context
25 //===----------------------------------------------------------------------===//
26 
27 namespace mlir {
28 struct ScfToSPIRVContextImpl {
29   // Map between the spirv region control flow operation (spv.mlir.loop or
30   // spv.mlir.selection) to the VariableOp created to store the region results.
31   // The order of the VariableOp matches the order of the results.
32   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
33 };
34 } // namespace mlir
35 
36 /// We use ScfToSPIRVContext to store information about the lowering of the scf
37 /// region that need to be used later on. When we lower scf.for/scf.if we create
38 /// VariableOp to store the results. We need to keep track of the VariableOp
39 /// created as we need to insert stores into them when lowering Yield. Those
40 /// StoreOp cannot be created earlier as they may use a different type than
41 /// yield operands.
42 ScfToSPIRVContext::ScfToSPIRVContext() {
43   impl = std::make_unique<ScfToSPIRVContextImpl>();
44 }
45 
46 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
47 
48 //===----------------------------------------------------------------------===//
49 // Pattern Declarations
50 //===----------------------------------------------------------------------===//
51 
52 namespace {
53 /// Common class for all vector to GPU patterns.
54 template <typename OpTy>
55 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
56 public:
57   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
58                           ScfToSPIRVContextImpl *scfToSPIRVContext)
59       : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
60         scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
61 
62 protected:
63   ScfToSPIRVContextImpl *scfToSPIRVContext;
64   // FIXME: We explicitly keep a reference of the type converter here instead of
65   // passing it to OpConversionPattern during construction. This effectively
66   // bypasses the conversion framework's automation on type conversion. This is
67   // needed right now because the conversion framework will unconditionally
68   // legalize all types used by SCF ops upon discovering them, for example, the
69   // types of loop carried values. We use SPIR-V variables for those loop
70   // carried values. Depending on the available capabilities, the SPIR-V
71   // variable can be different, for example, cooperative matrix or normal
72   // variable. We'd like to detach the conversion of the loop carried values
73   // from the SCF ops (which is mainly a region). So we need to "mark" types
74   // used by SCF ops as legal, if to use the conversion framework for type
75   // conversion. There isn't a straightforward way to do that yet, as when
76   // converting types, ops aren't taken into consideration. Therefore, we just
77   // bypass the framework's type conversion for now.
78   SPIRVTypeConverter &typeConverter;
79 };
80 
81 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
82 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
83 public:
84   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
85 
86   LogicalResult
87   matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override;
89 };
90 
91 /// Pattern to convert a scf::IfOp within kernel functions into
92 /// spirv::SelectionOp.
93 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
94 public:
95   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
96 
97   LogicalResult
98   matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override;
100 };
101 
102 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
103 public:
104   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
105 
106   LogicalResult
107   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
112 public:
113   using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
114 
115   LogicalResult
116   matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor,
117                   ConversionPatternRewriter &rewriter) const override;
118 };
119 } // namespace
120 
121 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
122 /// We create VariableOp to handle the results value of the control flow region.
123 /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after
124 /// the loop we load the value from the allocation and use it as the SCF op
125 /// result.
126 template <typename ScfOp, typename OpTy>
127 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
128                                   ConversionPatternRewriter &rewriter,
129                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
130                                   ArrayRef<Type> returnTypes) {
131 
132   Location loc = scfOp.getLoc();
133   auto &allocas = scfToSPIRVContext->outputVars[newOp];
134   // Clearing the allocas is necessary in case a dialect conversion path failed
135   // previously, and this is the second attempt of this conversion.
136   allocas.clear();
137   SmallVector<Value, 8> resultValue;
138   for (Type convertedType : returnTypes) {
139     auto pointerType =
140         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
141     rewriter.setInsertionPoint(newOp);
142     auto alloc = rewriter.create<spirv::VariableOp>(
143         loc, pointerType, spirv::StorageClass::Function,
144         /*initializer=*/nullptr);
145     allocas.push_back(alloc);
146     rewriter.setInsertionPointAfter(newOp);
147     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
148     resultValue.push_back(loadResult);
149   }
150   rewriter.replaceOp(scfOp, resultValue);
151 }
152 
153 static Region::iterator getBlockIt(Region &region, unsigned index) {
154   return std::next(region.begin(), index);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // scf::ForOp
159 //===----------------------------------------------------------------------===//
160 
161 LogicalResult
162 ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
163                                  ConversionPatternRewriter &rewriter) const {
164   // scf::ForOp can be lowered to the structured control flow represented by
165   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
166   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
167   // single back edge from the continue to header block, and a single exit from
168   // header to merge.
169   auto loc = forOp.getLoc();
170   auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
171   loopOp.addEntryAndMergeBlock();
172 
173   OpBuilder::InsertionGuard guard(rewriter);
174   // Create the block for the header.
175   auto *header = new Block();
176   // Insert the header.
177   loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
178 
179   // Create the new induction variable to use.
180   BlockArgument newIndVar =
181       header->addArgument(adaptor.getLowerBound().getType());
182   for (Value arg : adaptor.getInitArgs())
183     header->addArgument(arg.getType());
184   Block *body = forOp.getBody();
185 
186   // Apply signature conversion to the body of the forOp. It has a single block,
187   // with argument which is the induction variable. That has to be replaced with
188   // the new induction variable.
189   TypeConverter::SignatureConversion signatureConverter(
190       body->getNumArguments());
191   signatureConverter.remapInput(0, newIndVar);
192   for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
193     signatureConverter.remapInput(i, header->getArgument(i));
194   body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
195                                            signatureConverter);
196 
197   // Move the blocks from the forOp into the loopOp. This is the body of the
198   // loopOp.
199   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
200                               getBlockIt(loopOp.body(), 2));
201 
202   SmallVector<Value, 8> args(1, adaptor.getLowerBound());
203   args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
204   // Branch into it from the entry.
205   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
206   rewriter.create<spirv::BranchOp>(loc, header, args);
207 
208   // Generate the rest of the loop header.
209   rewriter.setInsertionPointToEnd(header);
210   auto *mergeBlock = loopOp.getMergeBlock();
211   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
212       loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
213 
214   rewriter.create<spirv::BranchConditionalOp>(
215       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
216 
217   // Generate instructions to increment the step of the induction variable and
218   // branch to the header.
219   Block *continueBlock = loopOp.getContinueBlock();
220   rewriter.setInsertionPointToEnd(continueBlock);
221 
222   // Add the step to the induction variable and branch to the header.
223   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
224       loc, newIndVar.getType(), newIndVar, adaptor.getStep());
225   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
226 
227   // Infer the return types from the init operands. Vector type may get
228   // converted to CooperativeMatrix or to Vector type, to avoid having complex
229   // extra logic to figure out the right type we just infer it from the Init
230   // operands.
231   SmallVector<Type, 8> initTypes;
232   for (auto arg : adaptor.getInitArgs())
233     initTypes.push_back(arg.getType());
234   replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
235   return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // scf::IfOp
240 //===----------------------------------------------------------------------===//
241 
242 LogicalResult
243 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
244                                 ConversionPatternRewriter &rewriter) const {
245   // When lowering `scf::IfOp` we explicitly create a selection header block
246   // before the control flow diverges and a merge block where control flow
247   // subsequently converges.
248   auto loc = ifOp.getLoc();
249 
250   // Create `spv.selection` operation, selection header block and merge block.
251   auto selectionOp =
252       rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
253   auto *mergeBlock =
254       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
255   rewriter.create<spirv::MergeOp>(loc);
256 
257   OpBuilder::InsertionGuard guard(rewriter);
258   auto *selectionHeaderBlock =
259       rewriter.createBlock(&selectionOp.body().front());
260 
261   // Inline `then` region before the merge block and branch to it.
262   auto &thenRegion = ifOp.getThenRegion();
263   auto *thenBlock = &thenRegion.front();
264   rewriter.setInsertionPointToEnd(&thenRegion.back());
265   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
266   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
267 
268   auto *elseBlock = mergeBlock;
269   // If `else` region is not empty, inline that region before the merge block
270   // and branch to it.
271   if (!ifOp.getElseRegion().empty()) {
272     auto &elseRegion = ifOp.getElseRegion();
273     elseBlock = &elseRegion.front();
274     rewriter.setInsertionPointToEnd(&elseRegion.back());
275     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
276     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
277   }
278 
279   // Create a `spv.BranchConditional` operation for selection header block.
280   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
281   rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
282                                               thenBlock, ArrayRef<Value>(),
283                                               elseBlock, ArrayRef<Value>());
284 
285   SmallVector<Type, 8> returnTypes;
286   for (auto result : ifOp.getResults()) {
287     auto convertedType = typeConverter.convertType(result.getType());
288     returnTypes.push_back(convertedType);
289   }
290   replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
291                         returnTypes);
292   return success();
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // scf::YieldOp
297 //===----------------------------------------------------------------------===//
298 
299 /// Yield is lowered to stores to the VariableOp created during lowering of the
300 /// parent region. For loops we also need to update the branch looping back to
301 /// the header with the loop carried values.
302 LogicalResult TerminatorOpConversion::matchAndRewrite(
303     scf::YieldOp terminatorOp, OpAdaptor adaptor,
304     ConversionPatternRewriter &rewriter) const {
305   ValueRange operands = adaptor.getOperands();
306 
307   // If the region is return values, store each value into the associated
308   // VariableOp created during lowering of the parent region.
309   if (!operands.empty()) {
310     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
311     if (allocas.size() != operands.size())
312       return failure();
313 
314     auto loc = terminatorOp.getLoc();
315     for (unsigned i = 0, e = operands.size(); i < e; i++)
316       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
317     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
318       // For loops we also need to update the branch jumping back to the header.
319       auto br =
320           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
321       SmallVector<Value, 8> args(br.getBlockArguments());
322       args.append(operands.begin(), operands.end());
323       rewriter.setInsertionPoint(br);
324       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
325                                        args);
326       rewriter.eraseOp(br);
327     }
328   }
329   rewriter.eraseOp(terminatorOp);
330   return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // scf::WhileOp
335 //===----------------------------------------------------------------------===//
336 
337 LogicalResult
338 WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
339                                    ConversionPatternRewriter &rewriter) const {
340   auto loc = whileOp.getLoc();
341   auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
342   loopOp.addEntryAndMergeBlock();
343 
344   OpBuilder::InsertionGuard guard(rewriter);
345 
346   Region &beforeRegion = whileOp.getBefore();
347   Region &afterRegion = whileOp.getAfter();
348 
349   Block &entryBlock = *loopOp.getEntryBlock();
350   Block &beforeBlock = beforeRegion.front();
351   Block &afterBlock = afterRegion.front();
352   Block &mergeBlock = *loopOp.getMergeBlock();
353 
354   auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
355   SmallVector<Value> condArgs;
356   if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
357     return failure();
358 
359   Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
360   if (!conditionVal)
361     return failure();
362 
363   auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
364   SmallVector<Value> yieldArgs;
365   if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
366     return failure();
367 
368   // Move the while before block as the initial loop header block.
369   rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
370                               getBlockIt(loopOp.body(), 1));
371 
372   // Move the while after block as the initial loop body block.
373   rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
374                               getBlockIt(loopOp.body(), 2));
375 
376   // Jump from the loop entry block to the loop header block.
377   rewriter.setInsertionPointToEnd(&entryBlock);
378   rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
379 
380   auto condLoc = cond.getLoc();
381 
382   SmallVector<Value> resultValues(condArgs.size());
383 
384   // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
385   // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
386   // variables. But for the scf.while op, the scf.yield op yields a value for
387   // the before region, which may not matching the whole op's result. Instead,
388   // the scf.condition op returns values matching the whole op's results. So we
389   // need to create/load/store variables according to that.
390   for (const auto &it : llvm::enumerate(condArgs)) {
391     auto res = it.value();
392     auto i = it.index();
393     auto pointerType =
394         spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
395 
396     // Create local variables before the scf.while op.
397     rewriter.setInsertionPoint(loopOp);
398     auto alloc = rewriter.create<spirv::VariableOp>(
399         condLoc, pointerType, spirv::StorageClass::Function,
400         /*initializer=*/nullptr);
401 
402     // Load the final result values after the scf.while op.
403     rewriter.setInsertionPointAfter(loopOp);
404     auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
405     resultValues[i] = loadResult;
406 
407     // Store the current iteration's result value.
408     rewriter.setInsertionPointToEnd(&beforeBlock);
409     rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
410   }
411 
412   rewriter.setInsertionPointToEnd(&beforeBlock);
413   rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
414       cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None);
415 
416   // Convert the scf.yield op to a branch back to the header block.
417   rewriter.setInsertionPointToEnd(&afterBlock);
418   rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
419 
420   rewriter.replaceOp(whileOp, resultValues);
421   return success();
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // Hooks
426 //===----------------------------------------------------------------------===//
427 
428 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
429                                       ScfToSPIRVContext &scfToSPIRVContext,
430                                       RewritePatternSet &patterns) {
431   patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
432                WhileOpConversion>(patterns.getContext(), typeConverter,
433                                   scfToSPIRVContext.getImpl());
434 }
435