1 //===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the conversion patterns from SCF ops to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/IR/Module.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 struct ScfToSPIRVContextImpl {
23   // Map between the spirv region control flow operation (spv.loop or
24   // spv.selection) to the VariableOp created to store the region results. The
25   // order of the VariableOp matches the order of the results.
26   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
27 };
28 } // namespace mlir
29 
30 /// We use ScfToSPIRVContext to store information about the lowering of the scf
31 /// region that need to be used later on. When we lower scf.for/scf.if we create
32 /// VariableOp to store the results. We need to keep track of the VariableOp
33 /// created as we need to insert stores into them when lowering Yield. Those
34 /// StoreOp cannot be created earlier as they may use a different type than
35 /// yield operands.
36 ScfToSPIRVContext::ScfToSPIRVContext() {
37   impl = std::make_unique<ScfToSPIRVContextImpl>();
38 }
39 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
40 
41 namespace {
42 /// Common class for all vector to GPU patterns.
43 template <typename OpTy>
44 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
45 public:
46   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
47                           ScfToSPIRVContextImpl *scfToSPIRVContext)
48       : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
49         scfToSPIRVContext(scfToSPIRVContext) {}
50 
51 protected:
52   ScfToSPIRVContextImpl *scfToSPIRVContext;
53 };
54 
55 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
56 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
57 public:
58   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
59 
60   LogicalResult
61   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override;
63 };
64 
65 /// Pattern to convert a scf::IfOp within kernel functions into
66 /// spirv::SelectionOp.
67 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
68 public:
69   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
70 
71   LogicalResult
72   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
73                   ConversionPatternRewriter &rewriter) const override;
74 };
75 
76 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
77 public:
78   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
79 
80   LogicalResult
81   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
82                   ConversionPatternRewriter &rewriter) const override;
83 };
84 } // namespace
85 
86 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
87 /// We create VariableOp to handle the results value of the control flow region.
88 /// spv.loop/spv.selection currently don't yield value. Right after the loop
89 /// we load the value from the allocation and use it as the SCF op result.
90 template <typename ScfOp, typename OpTy>
91 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
92                                   SPIRVTypeConverter &typeConverter,
93                                   ConversionPatternRewriter &rewriter,
94                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
95                                   ArrayRef<Type> returnTypes) {
96 
97   Location loc = scfOp.getLoc();
98   auto &allocas = scfToSPIRVContext->outputVars[newOp];
99   SmallVector<Value, 8> resultValue;
100   for (Type convertedType : returnTypes) {
101     auto pointerType =
102         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
103     rewriter.setInsertionPoint(newOp);
104     auto alloc = rewriter.create<spirv::VariableOp>(
105         loc, pointerType, spirv::StorageClass::Function,
106         /*initializer=*/nullptr);
107     allocas.push_back(alloc);
108     rewriter.setInsertionPointAfter(newOp);
109     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
110     resultValue.push_back(loadResult);
111   }
112   rewriter.replaceOp(scfOp, resultValue);
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // scf::ForOp.
117 //===----------------------------------------------------------------------===//
118 
119 LogicalResult
120 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
121                                  ConversionPatternRewriter &rewriter) const {
122   // scf::ForOp can be lowered to the structured control flow represented by
123   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
124   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
125   // single back edge from the continue to header block, and a single exit from
126   // header to merge.
127   scf::ForOpAdaptor forOperands(operands);
128   auto loc = forOp.getLoc();
129   auto loopControl = rewriter.getI32IntegerAttr(
130       static_cast<uint32_t>(spirv::LoopControl::None));
131   auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
132   loopOp.addEntryAndMergeBlock();
133 
134   OpBuilder::InsertionGuard guard(rewriter);
135   // Create the block for the header.
136   auto *header = new Block();
137   // Insert the header.
138   loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
139 
140   // Create the new induction variable to use.
141   BlockArgument newIndVar =
142       header->addArgument(forOperands.lowerBound().getType());
143   for (Value arg : forOperands.initArgs())
144     header->addArgument(arg.getType());
145   Block *body = forOp.getBody();
146 
147   // Apply signature conversion to the body of the forOp. It has a single block,
148   // with argument which is the induction variable. That has to be replaced with
149   // the new induction variable.
150   TypeConverter::SignatureConversion signatureConverter(
151       body->getNumArguments());
152   signatureConverter.remapInput(0, newIndVar);
153   for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
154     signatureConverter.remapInput(i, header->getArgument(i));
155   body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
156                                            signatureConverter);
157 
158   // Move the blocks from the forOp into the loopOp. This is the body of the
159   // loopOp.
160   rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
161                               std::next(loopOp.body().begin(), 2));
162 
163   SmallVector<Value, 8> args(1, forOperands.lowerBound());
164   args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
165   // Branch into it from the entry.
166   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
167   rewriter.create<spirv::BranchOp>(loc, header, args);
168 
169   // Generate the rest of the loop header.
170   rewriter.setInsertionPointToEnd(header);
171   auto *mergeBlock = loopOp.getMergeBlock();
172   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
173       loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
174 
175   rewriter.create<spirv::BranchConditionalOp>(
176       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
177 
178   // Generate instructions to increment the step of the induction variable and
179   // branch to the header.
180   Block *continueBlock = loopOp.getContinueBlock();
181   rewriter.setInsertionPointToEnd(continueBlock);
182 
183   // Add the step to the induction variable and branch to the header.
184   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
185       loc, newIndVar.getType(), newIndVar, forOperands.step());
186   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
187 
188   // Infer the return types from the init operands. Vector type may get
189   // converted to CooperativeMatrix or to Vector type, to avoid having complex
190   // extra logic to figure out the right type we just infer it from the Init
191   // operands.
192   SmallVector<Type, 8> initTypes;
193   for (auto arg : forOperands.initArgs())
194     initTypes.push_back(arg.getType());
195   replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
196                         scfToSPIRVContext, initTypes);
197   return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // scf::IfOp.
202 //===----------------------------------------------------------------------===//
203 
204 LogicalResult
205 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
206                                 ConversionPatternRewriter &rewriter) const {
207   // When lowering `scf::IfOp` we explicitly create a selection header block
208   // before the control flow diverges and a merge block where control flow
209   // subsequently converges.
210   scf::IfOpAdaptor ifOperands(operands);
211   auto loc = ifOp.getLoc();
212 
213   // Create `spv.selection` operation, selection header block and merge block.
214   auto selectionControl = rewriter.getI32IntegerAttr(
215       static_cast<uint32_t>(spirv::SelectionControl::None));
216   auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
217   selectionOp.addMergeBlock();
218   auto *mergeBlock = selectionOp.getMergeBlock();
219 
220   OpBuilder::InsertionGuard guard(rewriter);
221   auto *selectionHeaderBlock = new Block();
222   selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
223 
224   // Inline `then` region before the merge block and branch to it.
225   auto &thenRegion = ifOp.thenRegion();
226   auto *thenBlock = &thenRegion.front();
227   rewriter.setInsertionPointToEnd(&thenRegion.back());
228   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
229   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
230 
231   auto *elseBlock = mergeBlock;
232   // If `else` region is not empty, inline that region before the merge block
233   // and branch to it.
234   if (!ifOp.elseRegion().empty()) {
235     auto &elseRegion = ifOp.elseRegion();
236     elseBlock = &elseRegion.front();
237     rewriter.setInsertionPointToEnd(&elseRegion.back());
238     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
239     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
240   }
241 
242   // Create a `spv.BranchConditional` operation for selection header block.
243   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
244   rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
245                                               thenBlock, ArrayRef<Value>(),
246                                               elseBlock, ArrayRef<Value>());
247 
248   SmallVector<Type, 8> returnTypes;
249   for (auto result : ifOp.results()) {
250     auto convertedType = typeConverter.convertType(result.getType());
251     returnTypes.push_back(convertedType);
252   }
253   replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
254                         scfToSPIRVContext, returnTypes);
255   return success();
256 }
257 
258 /// Yield is lowered to stores to the VariableOp created during lowering of the
259 /// parent region. For loops we also need to update the branch looping back to
260 /// the header with the loop carried values.
261 LogicalResult TerminatorOpConversion::matchAndRewrite(
262     scf::YieldOp terminatorOp, ArrayRef<Value> operands,
263     ConversionPatternRewriter &rewriter) const {
264   // If the region is return values, store each value into the associated
265   // VariableOp created during lowering of the parent region.
266   if (!operands.empty()) {
267     auto loc = terminatorOp.getLoc();
268     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()];
269     assert(allocas.size() == operands.size());
270     for (unsigned i = 0, e = operands.size(); i < e; i++)
271       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
272     if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) {
273       // For loops we also need to update the branch jumping back to the header.
274       auto br =
275           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
276       SmallVector<Value, 8> args(br.getBlockArguments());
277       args.append(operands.begin(), operands.end());
278       rewriter.setInsertionPoint(br);
279       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
280                                        args);
281       rewriter.eraseOp(br);
282     }
283   }
284   rewriter.eraseOp(terminatorOp);
285   return success();
286 }
287 
288 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
289                                       SPIRVTypeConverter &typeConverter,
290                                       ScfToSPIRVContext &scfToSPIRVContext,
291                                       OwningRewritePatternList &patterns) {
292   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
293       context, typeConverter, scfToSPIRVContext.getImpl());
294 }
295