1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 
20 using namespace mlir;
21 
22 namespace mlir {
23 struct ScfToSPIRVContextImpl {
24   // Map between the spirv region control flow operation (spv.loop or
25   // spv.selection) to the VariableOp created to store the region results. The
26   // order of the VariableOp matches the order of the results.
27   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
28 };
29 } // namespace mlir
30 
31 /// We use ScfToSPIRVContext to store information about the lowering of the scf
32 /// region that need to be used later on. When we lower scf.for/scf.if we create
33 /// VariableOp to store the results. We need to keep track of the VariableOp
34 /// created as we need to insert stores into them when lowering Yield. Those
35 /// StoreOp cannot be created earlier as they may use a different type than
36 /// yield operands.
37 ScfToSPIRVContext::ScfToSPIRVContext() {
38   impl = std::make_unique<ScfToSPIRVContextImpl>();
39 }
40 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
41 
42 namespace {
43 /// Common class for all vector to GPU patterns.
44 template <typename OpTy>
45 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
46 public:
47   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
48                           ScfToSPIRVContextImpl *scfToSPIRVContext)
49       : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
50         scfToSPIRVContext(scfToSPIRVContext) {}
51 
52 protected:
53   ScfToSPIRVContextImpl *scfToSPIRVContext;
54 };
55 
56 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
57 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
58 public:
59   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
60 
61   LogicalResult
62   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
63                   ConversionPatternRewriter &rewriter) const override;
64 };
65 
66 /// Pattern to convert a scf::IfOp within kernel functions into
67 /// spirv::SelectionOp.
68 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
69 public:
70   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
71 
72   LogicalResult
73   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
74                   ConversionPatternRewriter &rewriter) const override;
75 };
76 
77 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
78 public:
79   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
80 
81   LogicalResult
82   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
83                   ConversionPatternRewriter &rewriter) const override;
84 };
85 } // namespace
86 
87 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
88 /// We create VariableOp to handle the results value of the control flow region.
89 /// spv.loop/spv.selection currently don't yield value. Right after the loop
90 /// we load the value from the allocation and use it as the SCF op result.
91 template <typename ScfOp, typename OpTy>
92 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
93                                   SPIRVTypeConverter &typeConverter,
94                                   ConversionPatternRewriter &rewriter,
95                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
96                                   ArrayRef<Type> returnTypes) {
97 
98   Location loc = scfOp.getLoc();
99   auto &allocas = scfToSPIRVContext->outputVars[newOp];
100   // Clearing the allocas is necessary in case a dialect conversion path failed
101   // previously, and this is the second attempt of this conversion.
102   allocas.clear();
103   SmallVector<Value, 8> resultValue;
104   for (Type convertedType : returnTypes) {
105     auto pointerType =
106         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
107     rewriter.setInsertionPoint(newOp);
108     auto alloc = rewriter.create<spirv::VariableOp>(
109         loc, pointerType, spirv::StorageClass::Function,
110         /*initializer=*/nullptr);
111     allocas.push_back(alloc);
112     rewriter.setInsertionPointAfter(newOp);
113     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
114     resultValue.push_back(loadResult);
115   }
116   rewriter.replaceOp(scfOp, resultValue);
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // scf::ForOp.
121 //===----------------------------------------------------------------------===//
122 
/// Lowers scf.for into a spv.loop with the block layout
///   entry -> header -> (inlined scf.for body ...) -> merge.
/// The header carries the induction variable plus one argument per init
/// value. It compares the induction variable against the upper bound and
/// branches either into the body or to the merge block. The continue block
/// increments the induction variable and branches back to the header.
/// Loop-carried values are added to that back-edge branch later, when the
/// nested scf.yield is lowered by TerminatorOpConversion.
LogicalResult
ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // scf::ForOp can be lowered to the structured control flow represented by
  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
  // single back edge from the continue to header block, and a single exit from
  // header to merge.
  scf::ForOpAdaptor forOperands(operands);
  auto loc = forOp.getLoc();
  auto loopControl = rewriter.getI32IntegerAttr(
      static_cast<uint32_t>(spirv::LoopControl::None));
  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
  // Creates the entry block (front) and the merge block of the loop region.
  loopOp.addEntryAndMergeBlock();

  OpBuilder::InsertionGuard guard(rewriter);
  // Create the block for the header.
  auto *header = new Block();
  // Insert the header right after the entry block. The region's block list
  // takes ownership of the raw-allocated block here.
  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);

  // Create the new induction variable to use.
  // Header arguments: [induction variable, init args...], typed after the
  // converted operands.
  BlockArgument newIndVar =
      header->addArgument(forOperands.lowerBound().getType());
  for (Value arg : forOperands.initArgs())
    header->addArgument(arg.getType());
  Block *body = forOp.getBody();

  // Apply signature conversion to the body of the forOp. It has a single block,
  // with argument which is the induction variable. That has to be replaced with
  // the new induction variable.
  // The remaining body arguments (the iteration values) are remapped to the
  // corresponding header arguments, so the body reads loop-carried values
  // straight from the header.
  TypeConverter::SignatureConversion signatureConverter(
      body->getNumArguments());
  signatureConverter.remapInput(0, newIndVar);
  for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
    signatureConverter.remapInput(i, header->getArgument(i));
  // `body` now refers to the block produced by the signature conversion.
  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
                                           signatureConverter);

  // Move the blocks from the forOp into the loopOp. This is the body of the
  // loopOp.
  // They are placed after the header (position 2), i.e. before the merge block.
  rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
                              std::next(loopOp.body().begin(), 2));

  SmallVector<Value, 8> args(1, forOperands.lowerBound());
  args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
  // Branch into it from the entry.
  // The entry passes the lower bound and the init values as header arguments.
  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
  rewriter.create<spirv::BranchOp>(loc, header, args);

  // Generate the rest of the loop header.
  // Signed comparison: keep iterating while indvar < upperBound.
  rewriter.setInsertionPointToEnd(header);
  auto *mergeBlock = loopOp.getMergeBlock();
  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());

  rewriter.create<spirv::BranchConditionalOp>(
      loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());

  // Generate instructions to increment the step of the induction variable and
  // branch to the header.
  // NOTE(review): getContinueBlock() is expected to be the block just before
  // the merge block, i.e. the last block inlined from the scf.for body. The
  // code below is appended after that block's scf.yield, which has not been
  // erased here.
  Block *continueBlock = loopOp.getContinueBlock();
  rewriter.setInsertionPointToEnd(continueBlock);

  // Add the step to the induction variable and branch to the header.
  // Only the updated induction variable is passed here. TerminatorOpConversion
  // rebuilds this branch with the yielded loop-carried values appended.
  Value updatedIndVar = rewriter.create<spirv::IAddOp>(
      loc, newIndVar.getType(), newIndVar, forOperands.step());
  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);

  // Infer the return types from the init operands. Vector type may get
  // converted to CooperativeMatrix or to Vector type, to avoid having complex
  // extra logic to figure out the right type we just infer it from the Init
  // operands.
  SmallVector<Type, 8> initTypes;
  for (auto arg : forOperands.initArgs())
    initTypes.push_back(arg.getType());
  replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
                        scfToSPIRVContext, initTypes);
  return success();
}
203 
204 //===----------------------------------------------------------------------===//
205 // scf::IfOp.
206 //===----------------------------------------------------------------------===//
207 
208 LogicalResult
209 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
210                                 ConversionPatternRewriter &rewriter) const {
211   // When lowering `scf::IfOp` we explicitly create a selection header block
212   // before the control flow diverges and a merge block where control flow
213   // subsequently converges.
214   scf::IfOpAdaptor ifOperands(operands);
215   auto loc = ifOp.getLoc();
216 
217   // Create `spv.selection` operation, selection header block and merge block.
218   auto selectionControl = rewriter.getI32IntegerAttr(
219       static_cast<uint32_t>(spirv::SelectionControl::None));
220   auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
221   auto *mergeBlock =
222       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
223   rewriter.create<spirv::MergeOp>(loc);
224 
225   OpBuilder::InsertionGuard guard(rewriter);
226   auto *selectionHeaderBlock =
227       rewriter.createBlock(&selectionOp.body().front());
228 
229   // Inline `then` region before the merge block and branch to it.
230   auto &thenRegion = ifOp.thenRegion();
231   auto *thenBlock = &thenRegion.front();
232   rewriter.setInsertionPointToEnd(&thenRegion.back());
233   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
234   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
235 
236   auto *elseBlock = mergeBlock;
237   // If `else` region is not empty, inline that region before the merge block
238   // and branch to it.
239   if (!ifOp.elseRegion().empty()) {
240     auto &elseRegion = ifOp.elseRegion();
241     elseBlock = &elseRegion.front();
242     rewriter.setInsertionPointToEnd(&elseRegion.back());
243     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
244     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
245   }
246 
247   // Create a `spv.BranchConditional` operation for selection header block.
248   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
249   rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
250                                               thenBlock, ArrayRef<Value>(),
251                                               elseBlock, ArrayRef<Value>());
252 
253   SmallVector<Type, 8> returnTypes;
254   for (auto result : ifOp.results()) {
255     auto convertedType = typeConverter.convertType(result.getType());
256     returnTypes.push_back(convertedType);
257   }
258   replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
259                         scfToSPIRVContext, returnTypes);
260   return success();
261 }
262 
263 /// Yield is lowered to stores to the VariableOp created during lowering of the
264 /// parent region. For loops we also need to update the branch looping back to
265 /// the header with the loop carried values.
266 LogicalResult TerminatorOpConversion::matchAndRewrite(
267     scf::YieldOp terminatorOp, ArrayRef<Value> operands,
268     ConversionPatternRewriter &rewriter) const {
269   // If the region is return values, store each value into the associated
270   // VariableOp created during lowering of the parent region.
271   if (!operands.empty()) {
272     auto loc = terminatorOp.getLoc();
273     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
274     assert(allocas.size() == operands.size());
275     for (unsigned i = 0, e = operands.size(); i < e; i++)
276       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
277     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
278       // For loops we also need to update the branch jumping back to the header.
279       auto br =
280           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
281       SmallVector<Value, 8> args(br.getBlockArguments());
282       args.append(operands.begin(), operands.end());
283       rewriter.setInsertionPoint(br);
284       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
285                                        args);
286       rewriter.eraseOp(br);
287     }
288   }
289   rewriter.eraseOp(terminatorOp);
290   return success();
291 }
292 
293 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
294                                       SPIRVTypeConverter &typeConverter,
295                                       ScfToSPIRVContext &scfToSPIRVContext,
296                                       OwningRewritePatternList &patterns) {
297   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
298       context, typeConverter, scfToSPIRVContext.getImpl());
299 }
300