1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Context
25 //===----------------------------------------------------------------------===//
26 
27 namespace mlir {
28 struct ScfToSPIRVContextImpl {
29   // Map between the spirv region control flow operation (spv.mlir.loop or
30   // spv.mlir.selection) to the VariableOp created to store the region results.
31   // The order of the VariableOp matches the order of the results.
32   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
33 };
34 } // namespace mlir
35 
36 /// We use ScfToSPIRVContext to store information about the lowering of the scf
37 /// region that need to be used later on. When we lower scf.for/scf.if we create
38 /// VariableOp to store the results. We need to keep track of the VariableOp
39 /// created as we need to insert stores into them when lowering Yield. Those
40 /// StoreOp cannot be created earlier as they may use a different type than
41 /// yield operands.
42 ScfToSPIRVContext::ScfToSPIRVContext() {
43   impl = std::make_unique<ScfToSPIRVContextImpl>();
44 }
45 
46 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
47 
48 //===----------------------------------------------------------------------===//
49 // Pattern Declarations
50 //===----------------------------------------------------------------------===//
51 
52 namespace {
53 /// Common class for all vector to GPU patterns.
54 template <typename OpTy>
55 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
56 public:
57   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
58                           ScfToSPIRVContextImpl *scfToSPIRVContext)
59       : OpConversionPattern<OpTy>::OpConversionPattern(context),
60         scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
61 
62 protected:
63   ScfToSPIRVContextImpl *scfToSPIRVContext;
64   // FIXME: We explicitly keep a reference of the type converter here instead of
65   // passing it to OpConversionPattern during construction. This effectively
66   // bypasses the conversion framework's automation on type conversion. This is
67   // needed right now because the conversion framework will unconditionally
68   // legalize all types used by SCF ops upon discovering them, for example, the
69   // types of loop carried values. We use SPIR-V variables for those loop
70   // carried values. Depending on the available capabilities, the SPIR-V
71   // variable can be different, for example, cooperative matrix or normal
72   // variable. We'd like to detach the conversion of the loop carried values
73   // from the SCF ops (which is mainly a region). So we need to "mark" types
74   // used by SCF ops as legal, if to use the conversion framework for type
75   // conversion. There isn't a straightforward way to do that yet, as when
76   // converting types, ops aren't taken into consideration. Therefore, we just
77   // bypass the framework's type conversion for now.
78   SPIRVTypeConverter &typeConverter;
79 };
80 
81 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
82 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
83 public:
84   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
85 
86   LogicalResult
87   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
88                   ConversionPatternRewriter &rewriter) const override;
89 };
90 
91 /// Pattern to convert a scf::IfOp within kernel functions into
92 /// spirv::SelectionOp.
93 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
94 public:
95   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
96 
97   LogicalResult
98   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
99                   ConversionPatternRewriter &rewriter) const override;
100 };
101 
102 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
103 public:
104   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
105 
106   LogicalResult
107   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 } // namespace
111 
112 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
113 /// We create VariableOp to handle the results value of the control flow region.
114 /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after
115 /// the loop we load the value from the allocation and use it as the SCF op
116 /// result.
117 template <typename ScfOp, typename OpTy>
118 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
119                                   ConversionPatternRewriter &rewriter,
120                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
121                                   ArrayRef<Type> returnTypes) {
122 
123   Location loc = scfOp.getLoc();
124   auto &allocas = scfToSPIRVContext->outputVars[newOp];
125   // Clearing the allocas is necessary in case a dialect conversion path failed
126   // previously, and this is the second attempt of this conversion.
127   allocas.clear();
128   SmallVector<Value, 8> resultValue;
129   for (Type convertedType : returnTypes) {
130     auto pointerType =
131         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
132     rewriter.setInsertionPoint(newOp);
133     auto alloc = rewriter.create<spirv::VariableOp>(
134         loc, pointerType, spirv::StorageClass::Function,
135         /*initializer=*/nullptr);
136     allocas.push_back(alloc);
137     rewriter.setInsertionPointAfter(newOp);
138     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
139     resultValue.push_back(loadResult);
140   }
141   rewriter.replaceOp(scfOp, resultValue);
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // scf::ForOp
146 //===----------------------------------------------------------------------===//
147 
148 LogicalResult
149 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
150                                  ConversionPatternRewriter &rewriter) const {
151   // scf::ForOp can be lowered to the structured control flow represented by
152   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
153   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
154   // single back edge from the continue to header block, and a single exit from
155   // header to merge.
156   scf::ForOpAdaptor forOperands(operands);
157   auto loc = forOp.getLoc();
158   auto loopControl = rewriter.getI32IntegerAttr(
159       static_cast<uint32_t>(spirv::LoopControl::None));
160   auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
161   loopOp.addEntryAndMergeBlock();
162 
163   OpBuilder::InsertionGuard guard(rewriter);
164   // Create the block for the header.
165   auto *header = new Block();
166   // Insert the header.
167   loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
168 
169   // Create the new induction variable to use.
170   BlockArgument newIndVar =
171       header->addArgument(forOperands.lowerBound().getType());
172   for (Value arg : forOperands.initArgs())
173     header->addArgument(arg.getType());
174   Block *body = forOp.getBody();
175 
176   // Apply signature conversion to the body of the forOp. It has a single block,
177   // with argument which is the induction variable. That has to be replaced with
178   // the new induction variable.
179   TypeConverter::SignatureConversion signatureConverter(
180       body->getNumArguments());
181   signatureConverter.remapInput(0, newIndVar);
182   for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
183     signatureConverter.remapInput(i, header->getArgument(i));
184   body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
185                                            signatureConverter);
186 
187   // Move the blocks from the forOp into the loopOp. This is the body of the
188   // loopOp.
189   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
190                               std::next(loopOp.body().begin(), 2));
191 
192   SmallVector<Value, 8> args(1, forOperands.lowerBound());
193   args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
194   // Branch into it from the entry.
195   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
196   rewriter.create<spirv::BranchOp>(loc, header, args);
197 
198   // Generate the rest of the loop header.
199   rewriter.setInsertionPointToEnd(header);
200   auto *mergeBlock = loopOp.getMergeBlock();
201   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
202       loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
203 
204   rewriter.create<spirv::BranchConditionalOp>(
205       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
206 
207   // Generate instructions to increment the step of the induction variable and
208   // branch to the header.
209   Block *continueBlock = loopOp.getContinueBlock();
210   rewriter.setInsertionPointToEnd(continueBlock);
211 
212   // Add the step to the induction variable and branch to the header.
213   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
214       loc, newIndVar.getType(), newIndVar, forOperands.step());
215   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
216 
217   // Infer the return types from the init operands. Vector type may get
218   // converted to CooperativeMatrix or to Vector type, to avoid having complex
219   // extra logic to figure out the right type we just infer it from the Init
220   // operands.
221   SmallVector<Type, 8> initTypes;
222   for (auto arg : forOperands.initArgs())
223     initTypes.push_back(arg.getType());
224   replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
225   return success();
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // scf::IfOp
230 //===----------------------------------------------------------------------===//
231 
232 LogicalResult
233 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
234                                 ConversionPatternRewriter &rewriter) const {
235   // When lowering `scf::IfOp` we explicitly create a selection header block
236   // before the control flow diverges and a merge block where control flow
237   // subsequently converges.
238   scf::IfOpAdaptor ifOperands(operands);
239   auto loc = ifOp.getLoc();
240 
241   // Create `spv.mlir.selection` operation, selection header block and merge
242   // block.
243   auto selectionControl = rewriter.getI32IntegerAttr(
244       static_cast<uint32_t>(spirv::SelectionControl::None));
245   auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
246   auto *mergeBlock =
247       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
248   rewriter.create<spirv::MergeOp>(loc);
249 
250   OpBuilder::InsertionGuard guard(rewriter);
251   auto *selectionHeaderBlock =
252       rewriter.createBlock(&selectionOp.body().front());
253 
254   // Inline `then` region before the merge block and branch to it.
255   auto &thenRegion = ifOp.thenRegion();
256   auto *thenBlock = &thenRegion.front();
257   rewriter.setInsertionPointToEnd(&thenRegion.back());
258   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
259   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
260 
261   auto *elseBlock = mergeBlock;
262   // If `else` region is not empty, inline that region before the merge block
263   // and branch to it.
264   if (!ifOp.elseRegion().empty()) {
265     auto &elseRegion = ifOp.elseRegion();
266     elseBlock = &elseRegion.front();
267     rewriter.setInsertionPointToEnd(&elseRegion.back());
268     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
269     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
270   }
271 
272   // Create a `spv.BranchConditional` operation for selection header block.
273   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
274   rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
275                                               thenBlock, ArrayRef<Value>(),
276                                               elseBlock, ArrayRef<Value>());
277 
278   SmallVector<Type, 8> returnTypes;
279   for (auto result : ifOp.results()) {
280     auto convertedType = typeConverter.convertType(result.getType());
281     returnTypes.push_back(convertedType);
282   }
283   replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
284                         returnTypes);
285   return success();
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // scf::YieldOp
290 //===----------------------------------------------------------------------===//
291 
292 /// Yield is lowered to stores to the VariableOp created during lowering of the
293 /// parent region. For loops we also need to update the branch looping back to
294 /// the header with the loop carried values.
295 LogicalResult TerminatorOpConversion::matchAndRewrite(
296     scf::YieldOp terminatorOp, ArrayRef<Value> operands,
297     ConversionPatternRewriter &rewriter) const {
298   // If the region is return values, store each value into the associated
299   // VariableOp created during lowering of the parent region.
300   if (!operands.empty()) {
301     auto loc = terminatorOp.getLoc();
302     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
303     assert(allocas.size() == operands.size());
304     for (unsigned i = 0, e = operands.size(); i < e; i++)
305       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
306     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
307       // For loops we also need to update the branch jumping back to the header.
308       auto br =
309           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
310       SmallVector<Value, 8> args(br.getBlockArguments());
311       args.append(operands.begin(), operands.end());
312       rewriter.setInsertionPoint(br);
313       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
314                                        args);
315       rewriter.eraseOp(br);
316     }
317   }
318   rewriter.eraseOp(terminatorOp);
319   return success();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // Hooks
324 //===----------------------------------------------------------------------===//
325 
326 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
327                                       SPIRVTypeConverter &typeConverter,
328                                       ScfToSPIRVContext &scfToSPIRVContext,
329                                       OwningRewritePatternList &patterns) {
330   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
331       context, typeConverter, scfToSPIRVContext.getImpl());
332 }
333