1930c74f1SLei Zhang //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2fbce9855SThomas Raoux //
3fbce9855SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fbce9855SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
5fbce9855SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fbce9855SThomas Raoux //
7fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
8fbce9855SThomas Raoux //
9930c74f1SLei Zhang // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10fbce9855SThomas Raoux //
11fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
12930c74f1SLei Zhang 
13fbce9855SThomas Raoux #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1501178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
1865fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
197c3ae48fSLei Zhang #include "mlir/Transforms/DialectConversion.h"
20fbce9855SThomas Raoux 
21fbce9855SThomas Raoux using namespace mlir;
22fbce9855SThomas Raoux 
237c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
247c3ae48fSLei Zhang // Context
257c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
267c3ae48fSLei Zhang 
270670f855SThomas Raoux namespace mlir {
280670f855SThomas Raoux struct ScfToSPIRVContextImpl {
2929812a61SKareemErgawy-TomTom   // Map between the spirv region control flow operation (spv.mlir.loop or
303fb384d5SKareemErgawy-TomTom   // spv.mlir.selection) to the VariableOp created to store the region results.
313fb384d5SKareemErgawy-TomTom   // The order of the VariableOp matches the order of the results.
320670f855SThomas Raoux   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
330670f855SThomas Raoux };
340670f855SThomas Raoux } // namespace mlir
350670f855SThomas Raoux 
360670f855SThomas Raoux /// We use ScfToSPIRVContext to store information about the lowering of the scf
370670f855SThomas Raoux /// region that need to be used later on. When we lower scf.for/scf.if we create
380670f855SThomas Raoux /// VariableOp to store the results. We need to keep track of the VariableOp
390670f855SThomas Raoux /// created as we need to insert stores into them when lowering Yield. Those
400670f855SThomas Raoux /// StoreOp cannot be created earlier as they may use a different type than
410670f855SThomas Raoux /// yield operands.
ScfToSPIRVContext()420670f855SThomas Raoux ScfToSPIRVContext::ScfToSPIRVContext() {
430670f855SThomas Raoux   impl = std::make_unique<ScfToSPIRVContextImpl>();
440670f855SThomas Raoux }
457c3ae48fSLei Zhang 
460670f855SThomas Raoux ScfToSPIRVContext::~ScfToSPIRVContext() = default;
470670f855SThomas Raoux 
487c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
497c3ae48fSLei Zhang // Pattern Declarations
507c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
517c3ae48fSLei Zhang 
52fbce9855SThomas Raoux namespace {
530670f855SThomas Raoux /// Common class for all vector to GPU patterns.
540670f855SThomas Raoux template <typename OpTy>
557c3ae48fSLei Zhang class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
560670f855SThomas Raoux public:
SCFToSPIRVPattern(MLIRContext * context,SPIRVTypeConverter & converter,ScfToSPIRVContextImpl * scfToSPIRVContext)570670f855SThomas Raoux   SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
580670f855SThomas Raoux                           ScfToSPIRVContextImpl *scfToSPIRVContext)
59015192c6SRiver Riddle       : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
607c3ae48fSLei Zhang         scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
610670f855SThomas Raoux 
620670f855SThomas Raoux protected:
630670f855SThomas Raoux   ScfToSPIRVContextImpl *scfToSPIRVContext;
647c3ae48fSLei Zhang   // FIXME: We explicitly keep a reference of the type converter here instead of
657c3ae48fSLei Zhang   // passing it to OpConversionPattern during construction. This effectively
667c3ae48fSLei Zhang   // bypasses the conversion framework's automation on type conversion. This is
677c3ae48fSLei Zhang   // needed right now because the conversion framework will unconditionally
687c3ae48fSLei Zhang   // legalize all types used by SCF ops upon discovering them, for example, the
697c3ae48fSLei Zhang   // types of loop carried values. We use SPIR-V variables for those loop
707c3ae48fSLei Zhang   // carried values. Depending on the available capabilities, the SPIR-V
717c3ae48fSLei Zhang   // variable can be different, for example, cooperative matrix or normal
727c3ae48fSLei Zhang   // variable. We'd like to detach the conversion of the loop carried values
737c3ae48fSLei Zhang   // from the SCF ops (which is mainly a region). So we need to "mark" types
747c3ae48fSLei Zhang   // used by SCF ops as legal, if to use the conversion framework for type
757c3ae48fSLei Zhang   // conversion. There isn't a straightforward way to do that yet, as when
767c3ae48fSLei Zhang   // converting types, ops aren't taken into consideration. Therefore, we just
777c3ae48fSLei Zhang   // bypass the framework's type conversion for now.
787c3ae48fSLei Zhang   SPIRVTypeConverter &typeConverter;
790670f855SThomas Raoux };
80fbce9855SThomas Raoux 
81fbce9855SThomas Raoux /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
820670f855SThomas Raoux class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
83fbce9855SThomas Raoux public:
840670f855SThomas Raoux   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
85fbce9855SThomas Raoux 
86fbce9855SThomas Raoux   LogicalResult
87b54c724bSRiver Riddle   matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
88fbce9855SThomas Raoux                   ConversionPatternRewriter &rewriter) const override;
89fbce9855SThomas Raoux };
90fbce9855SThomas Raoux 
91fbce9855SThomas Raoux /// Pattern to convert a scf::IfOp within kernel functions into
92fbce9855SThomas Raoux /// spirv::SelectionOp.
930670f855SThomas Raoux class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
94fbce9855SThomas Raoux public:
950670f855SThomas Raoux   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
96fbce9855SThomas Raoux 
97fbce9855SThomas Raoux   LogicalResult
98b54c724bSRiver Riddle   matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
99fbce9855SThomas Raoux                   ConversionPatternRewriter &rewriter) const override;
100fbce9855SThomas Raoux };
101fbce9855SThomas Raoux 
1020670f855SThomas Raoux class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
103fbce9855SThomas Raoux public:
1040670f855SThomas Raoux   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
105fbce9855SThomas Raoux 
106fbce9855SThomas Raoux   LogicalResult
107b54c724bSRiver Riddle   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
1080670f855SThomas Raoux                   ConversionPatternRewriter &rewriter) const override;
109fbce9855SThomas Raoux };
110526b71e4SButygin 
111526b71e4SButygin class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
112526b71e4SButygin public:
113526b71e4SButygin   using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
114526b71e4SButygin 
115526b71e4SButygin   LogicalResult
116526b71e4SButygin   matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor,
117526b71e4SButygin                   ConversionPatternRewriter &rewriter) const override;
118526b71e4SButygin };
119fbce9855SThomas Raoux } // namespace
120fbce9855SThomas Raoux 
1210670f855SThomas Raoux /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
1220670f855SThomas Raoux /// We create VariableOp to handle the results value of the control flow region.
1233fb384d5SKareemErgawy-TomTom /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after
1243fb384d5SKareemErgawy-TomTom /// the loop we load the value from the allocation and use it as the SCF op
1253fb384d5SKareemErgawy-TomTom /// result.
1260670f855SThomas Raoux template <typename ScfOp, typename OpTy>
replaceSCFOutputValue(ScfOp scfOp,OpTy newOp,ConversionPatternRewriter & rewriter,ScfToSPIRVContextImpl * scfToSPIRVContext,ArrayRef<Type> returnTypes)1270670f855SThomas Raoux static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
1280670f855SThomas Raoux                                   ConversionPatternRewriter &rewriter,
1296a3c69e9SThomas Raoux                                   ScfToSPIRVContextImpl *scfToSPIRVContext,
1306a3c69e9SThomas Raoux                                   ArrayRef<Type> returnTypes) {
1310670f855SThomas Raoux 
1320670f855SThomas Raoux   Location loc = scfOp.getLoc();
1330670f855SThomas Raoux   auto &allocas = scfToSPIRVContext->outputVars[newOp];
13474170a3aSTres Popp   // Clearing the allocas is necessary in case a dialect conversion path failed
13574170a3aSTres Popp   // previously, and this is the second attempt of this conversion.
13674170a3aSTres Popp   allocas.clear();
1370670f855SThomas Raoux   SmallVector<Value, 8> resultValue;
1386a3c69e9SThomas Raoux   for (Type convertedType : returnTypes) {
1390670f855SThomas Raoux     auto pointerType =
1400670f855SThomas Raoux         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
1410670f855SThomas Raoux     rewriter.setInsertionPoint(newOp);
1420670f855SThomas Raoux     auto alloc = rewriter.create<spirv::VariableOp>(
1430670f855SThomas Raoux         loc, pointerType, spirv::StorageClass::Function,
1440670f855SThomas Raoux         /*initializer=*/nullptr);
1450670f855SThomas Raoux     allocas.push_back(alloc);
1460670f855SThomas Raoux     rewriter.setInsertionPointAfter(newOp);
1470670f855SThomas Raoux     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
1480670f855SThomas Raoux     resultValue.push_back(loadResult);
1490670f855SThomas Raoux   }
1500670f855SThomas Raoux   rewriter.replaceOp(scfOp, resultValue);
1510670f855SThomas Raoux }
1520670f855SThomas Raoux 
getBlockIt(Region & region,unsigned index)153526b71e4SButygin static Region::iterator getBlockIt(Region &region, unsigned index) {
154526b71e4SButygin   return std::next(region.begin(), index);
155526b71e4SButygin }
156526b71e4SButygin 
157fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
1587c3ae48fSLei Zhang // scf::ForOp
159fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
160fbce9855SThomas Raoux 
161fbce9855SThomas Raoux LogicalResult
matchAndRewrite(scf::ForOp forOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const162b54c724bSRiver Riddle ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
163fbce9855SThomas Raoux                                  ConversionPatternRewriter &rewriter) const {
164fbce9855SThomas Raoux   // scf::ForOp can be lowered to the structured control flow represented by
165fbce9855SThomas Raoux   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
166fbce9855SThomas Raoux   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
167fbce9855SThomas Raoux   // single back edge from the continue to header block, and a single exit from
168fbce9855SThomas Raoux   // header to merge.
169fbce9855SThomas Raoux   auto loc = forOp.getLoc();
170fee90542SVladislav Vinogradov   auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
171fbce9855SThomas Raoux   loopOp.addEntryAndMergeBlock();
172fbce9855SThomas Raoux 
173fbce9855SThomas Raoux   OpBuilder::InsertionGuard guard(rewriter);
174fbce9855SThomas Raoux   // Create the block for the header.
175fbce9855SThomas Raoux   auto *header = new Block();
176fbce9855SThomas Raoux   // Insert the header.
177526b71e4SButygin   loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
178fbce9855SThomas Raoux 
179fbce9855SThomas Raoux   // Create the new induction variable to use.
180e084679fSRiver Riddle   Value adapLowerBound = adaptor.getLowerBound();
181c0342a2dSJacques Pienaar   BlockArgument newIndVar =
182e084679fSRiver Riddle       header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
183c0342a2dSJacques Pienaar   for (Value arg : adaptor.getInitArgs())
184e084679fSRiver Riddle     header->addArgument(arg.getType(), arg.getLoc());
185fbce9855SThomas Raoux   Block *body = forOp.getBody();
186fbce9855SThomas Raoux 
187fbce9855SThomas Raoux   // Apply signature conversion to the body of the forOp. It has a single block,
188fbce9855SThomas Raoux   // with argument which is the induction variable. That has to be replaced with
189fbce9855SThomas Raoux   // the new induction variable.
190fbce9855SThomas Raoux   TypeConverter::SignatureConversion signatureConverter(
191fbce9855SThomas Raoux       body->getNumArguments());
192fbce9855SThomas Raoux   signatureConverter.remapInput(0, newIndVar);
1930670f855SThomas Raoux   for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
1940670f855SThomas Raoux     signatureConverter.remapInput(i, header->getArgument(i));
1950670f855SThomas Raoux   body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
1960670f855SThomas Raoux                                            signatureConverter);
197fbce9855SThomas Raoux 
198fbce9855SThomas Raoux   // Move the blocks from the forOp into the loopOp. This is the body of the
199fbce9855SThomas Raoux   // loopOp.
200c4a04059SChristian Sigg   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
201526b71e4SButygin                               getBlockIt(loopOp.body(), 2));
202fbce9855SThomas Raoux 
203c0342a2dSJacques Pienaar   SmallVector<Value, 8> args(1, adaptor.getLowerBound());
204c0342a2dSJacques Pienaar   args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
205fbce9855SThomas Raoux   // Branch into it from the entry.
206fbce9855SThomas Raoux   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
2070670f855SThomas Raoux   rewriter.create<spirv::BranchOp>(loc, header, args);
208fbce9855SThomas Raoux 
209fbce9855SThomas Raoux   // Generate the rest of the loop header.
210fbce9855SThomas Raoux   rewriter.setInsertionPointToEnd(header);
211fbce9855SThomas Raoux   auto *mergeBlock = loopOp.getMergeBlock();
212fbce9855SThomas Raoux   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
213c0342a2dSJacques Pienaar       loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
2140670f855SThomas Raoux 
215fbce9855SThomas Raoux   rewriter.create<spirv::BranchConditionalOp>(
216fbce9855SThomas Raoux       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
217fbce9855SThomas Raoux 
218fbce9855SThomas Raoux   // Generate instructions to increment the step of the induction variable and
219fbce9855SThomas Raoux   // branch to the header.
220fbce9855SThomas Raoux   Block *continueBlock = loopOp.getContinueBlock();
221fbce9855SThomas Raoux   rewriter.setInsertionPointToEnd(continueBlock);
222fbce9855SThomas Raoux 
223fbce9855SThomas Raoux   // Add the step to the induction variable and branch to the header.
224fbce9855SThomas Raoux   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
225c0342a2dSJacques Pienaar       loc, newIndVar.getType(), newIndVar, adaptor.getStep());
226fbce9855SThomas Raoux   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
227fbce9855SThomas Raoux 
2286a3c69e9SThomas Raoux   // Infer the return types from the init operands. Vector type may get
2296a3c69e9SThomas Raoux   // converted to CooperativeMatrix or to Vector type, to avoid having complex
2306a3c69e9SThomas Raoux   // extra logic to figure out the right type we just infer it from the Init
2316a3c69e9SThomas Raoux   // operands.
2326a3c69e9SThomas Raoux   SmallVector<Type, 8> initTypes;
233c0342a2dSJacques Pienaar   for (auto arg : adaptor.getInitArgs())
2346a3c69e9SThomas Raoux     initTypes.push_back(arg.getType());
2357c3ae48fSLei Zhang   replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
236fbce9855SThomas Raoux   return success();
237fbce9855SThomas Raoux }
238fbce9855SThomas Raoux 
239fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
2407c3ae48fSLei Zhang // scf::IfOp
241fbce9855SThomas Raoux //===----------------------------------------------------------------------===//
242fbce9855SThomas Raoux 
243fbce9855SThomas Raoux LogicalResult
matchAndRewrite(scf::IfOp ifOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const244b54c724bSRiver Riddle IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
245fbce9855SThomas Raoux                                 ConversionPatternRewriter &rewriter) const {
246fbce9855SThomas Raoux   // When lowering `scf::IfOp` we explicitly create a selection header block
247fbce9855SThomas Raoux   // before the control flow diverges and a merge block where control flow
248fbce9855SThomas Raoux   // subsequently converges.
249fbce9855SThomas Raoux   auto loc = ifOp.getLoc();
250fbce9855SThomas Raoux 
251fee90542SVladislav Vinogradov   // Create `spv.selection` operation, selection header block and merge block.
252fee90542SVladislav Vinogradov   auto selectionOp =
253fee90542SVladislav Vinogradov       rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
25474170a3aSTres Popp   auto *mergeBlock =
25574170a3aSTres Popp       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
25674170a3aSTres Popp   rewriter.create<spirv::MergeOp>(loc);
257fbce9855SThomas Raoux 
258fbce9855SThomas Raoux   OpBuilder::InsertionGuard guard(rewriter);
25974170a3aSTres Popp   auto *selectionHeaderBlock =
26074170a3aSTres Popp       rewriter.createBlock(&selectionOp.body().front());
261fbce9855SThomas Raoux 
262fbce9855SThomas Raoux   // Inline `then` region before the merge block and branch to it.
263c0342a2dSJacques Pienaar   auto &thenRegion = ifOp.getThenRegion();
264fbce9855SThomas Raoux   auto *thenBlock = &thenRegion.front();
265fbce9855SThomas Raoux   rewriter.setInsertionPointToEnd(&thenRegion.back());
266fbce9855SThomas Raoux   rewriter.create<spirv::BranchOp>(loc, mergeBlock);
267fbce9855SThomas Raoux   rewriter.inlineRegionBefore(thenRegion, mergeBlock);
268fbce9855SThomas Raoux 
269fbce9855SThomas Raoux   auto *elseBlock = mergeBlock;
270fbce9855SThomas Raoux   // If `else` region is not empty, inline that region before the merge block
271fbce9855SThomas Raoux   // and branch to it.
272c0342a2dSJacques Pienaar   if (!ifOp.getElseRegion().empty()) {
273c0342a2dSJacques Pienaar     auto &elseRegion = ifOp.getElseRegion();
274fbce9855SThomas Raoux     elseBlock = &elseRegion.front();
275fbce9855SThomas Raoux     rewriter.setInsertionPointToEnd(&elseRegion.back());
276fbce9855SThomas Raoux     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
277fbce9855SThomas Raoux     rewriter.inlineRegionBefore(elseRegion, mergeBlock);
278fbce9855SThomas Raoux   }
279fbce9855SThomas Raoux 
280fbce9855SThomas Raoux   // Create a `spv.BranchConditional` operation for selection header block.
281fbce9855SThomas Raoux   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
282c0342a2dSJacques Pienaar   rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
283fbce9855SThomas Raoux                                               thenBlock, ArrayRef<Value>(),
284fbce9855SThomas Raoux                                               elseBlock, ArrayRef<Value>());
285fbce9855SThomas Raoux 
2866a3c69e9SThomas Raoux   SmallVector<Type, 8> returnTypes;
287c0342a2dSJacques Pienaar   for (auto result : ifOp.getResults()) {
2886a3c69e9SThomas Raoux     auto convertedType = typeConverter.convertType(result.getType());
2896a3c69e9SThomas Raoux     returnTypes.push_back(convertedType);
2906a3c69e9SThomas Raoux   }
2917c3ae48fSLei Zhang   replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
2927c3ae48fSLei Zhang                         returnTypes);
2930670f855SThomas Raoux   return success();
2940670f855SThomas Raoux }
2950670f855SThomas Raoux 
2967c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
2977c3ae48fSLei Zhang // scf::YieldOp
2987c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
2997c3ae48fSLei Zhang 
3000670f855SThomas Raoux /// Yield is lowered to stores to the VariableOp created during lowering of the
3010670f855SThomas Raoux /// parent region. For loops we also need to update the branch looping back to
3020670f855SThomas Raoux /// the header with the loop carried values.
matchAndRewrite(scf::YieldOp terminatorOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3030670f855SThomas Raoux LogicalResult TerminatorOpConversion::matchAndRewrite(
304b54c724bSRiver Riddle     scf::YieldOp terminatorOp, OpAdaptor adaptor,
3050670f855SThomas Raoux     ConversionPatternRewriter &rewriter) const {
306b54c724bSRiver Riddle   ValueRange operands = adaptor.getOperands();
307b54c724bSRiver Riddle 
3080670f855SThomas Raoux   // If the region is return values, store each value into the associated
3090670f855SThomas Raoux   // VariableOp created during lowering of the parent region.
3100670f855SThomas Raoux   if (!operands.empty()) {
3110bf4a82aSChristian Sigg     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
312526b71e4SButygin     if (allocas.size() != operands.size())
313526b71e4SButygin       return failure();
314526b71e4SButygin 
315526b71e4SButygin     auto loc = terminatorOp.getLoc();
3160670f855SThomas Raoux     for (unsigned i = 0, e = operands.size(); i < e; i++)
3170670f855SThomas Raoux       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
3180bf4a82aSChristian Sigg     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
3190670f855SThomas Raoux       // For loops we also need to update the branch jumping back to the header.
3200670f855SThomas Raoux       auto br =
3210670f855SThomas Raoux           cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
3220670f855SThomas Raoux       SmallVector<Value, 8> args(br.getBlockArguments());
3230670f855SThomas Raoux       args.append(operands.begin(), operands.end());
3240670f855SThomas Raoux       rewriter.setInsertionPoint(br);
3250670f855SThomas Raoux       rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
3260670f855SThomas Raoux                                        args);
3270670f855SThomas Raoux       rewriter.eraseOp(br);
3280670f855SThomas Raoux     }
3290670f855SThomas Raoux   }
3300670f855SThomas Raoux   rewriter.eraseOp(terminatorOp);
331fbce9855SThomas Raoux   return success();
332fbce9855SThomas Raoux }
333fbce9855SThomas Raoux 
3347c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
335526b71e4SButygin // scf::WhileOp
336526b71e4SButygin //===----------------------------------------------------------------------===//
337526b71e4SButygin 
338526b71e4SButygin LogicalResult
matchAndRewrite(scf::WhileOp whileOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const339526b71e4SButygin WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
340526b71e4SButygin                                    ConversionPatternRewriter &rewriter) const {
341526b71e4SButygin   auto loc = whileOp.getLoc();
342526b71e4SButygin   auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
343526b71e4SButygin   loopOp.addEntryAndMergeBlock();
344526b71e4SButygin 
345526b71e4SButygin   OpBuilder::InsertionGuard guard(rewriter);
346526b71e4SButygin 
347c0342a2dSJacques Pienaar   Region &beforeRegion = whileOp.getBefore();
348c0342a2dSJacques Pienaar   Region &afterRegion = whileOp.getAfter();
349526b71e4SButygin 
350526b71e4SButygin   Block &entryBlock = *loopOp.getEntryBlock();
351526b71e4SButygin   Block &beforeBlock = beforeRegion.front();
352526b71e4SButygin   Block &afterBlock = afterRegion.front();
353526b71e4SButygin   Block &mergeBlock = *loopOp.getMergeBlock();
354526b71e4SButygin 
355526b71e4SButygin   auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
356526b71e4SButygin   SmallVector<Value> condArgs;
357c0342a2dSJacques Pienaar   if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
358526b71e4SButygin     return failure();
359526b71e4SButygin 
360c0342a2dSJacques Pienaar   Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
361526b71e4SButygin   if (!conditionVal)
362526b71e4SButygin     return failure();
363526b71e4SButygin 
364526b71e4SButygin   auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
365526b71e4SButygin   SmallVector<Value> yieldArgs;
366c0342a2dSJacques Pienaar   if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
367526b71e4SButygin     return failure();
368526b71e4SButygin 
369526b71e4SButygin   // Move the while before block as the initial loop header block.
370526b71e4SButygin   rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
371526b71e4SButygin                               getBlockIt(loopOp.body(), 1));
372526b71e4SButygin 
373526b71e4SButygin   // Move the while after block as the initial loop body block.
374526b71e4SButygin   rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
375526b71e4SButygin                               getBlockIt(loopOp.body(), 2));
376526b71e4SButygin 
377526b71e4SButygin   // Jump from the loop entry block to the loop header block.
378526b71e4SButygin   rewriter.setInsertionPointToEnd(&entryBlock);
379c0342a2dSJacques Pienaar   rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
380526b71e4SButygin 
381526b71e4SButygin   auto condLoc = cond.getLoc();
382526b71e4SButygin 
383526b71e4SButygin   SmallVector<Value> resultValues(condArgs.size());
384526b71e4SButygin 
385526b71e4SButygin   // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
386526b71e4SButygin   // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
387526b71e4SButygin   // variables. But for the scf.while op, the scf.yield op yields a value for
388526b71e4SButygin   // the before region, which may not matching the whole op's result. Instead,
389526b71e4SButygin   // the scf.condition op returns values matching the whole op's results. So we
390526b71e4SButygin   // need to create/load/store variables according to that.
391e4853be2SMehdi Amini   for (const auto &it : llvm::enumerate(condArgs)) {
392526b71e4SButygin     auto res = it.value();
393526b71e4SButygin     auto i = it.index();
394526b71e4SButygin     auto pointerType =
395526b71e4SButygin         spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
396526b71e4SButygin 
397526b71e4SButygin     // Create local variables before the scf.while op.
398526b71e4SButygin     rewriter.setInsertionPoint(loopOp);
399526b71e4SButygin     auto alloc = rewriter.create<spirv::VariableOp>(
400526b71e4SButygin         condLoc, pointerType, spirv::StorageClass::Function,
401526b71e4SButygin         /*initializer=*/nullptr);
402526b71e4SButygin 
403526b71e4SButygin     // Load the final result values after the scf.while op.
404526b71e4SButygin     rewriter.setInsertionPointAfter(loopOp);
405526b71e4SButygin     auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
406526b71e4SButygin     resultValues[i] = loadResult;
407526b71e4SButygin 
408526b71e4SButygin     // Store the current iteration's result value.
409526b71e4SButygin     rewriter.setInsertionPointToEnd(&beforeBlock);
410526b71e4SButygin     rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
411526b71e4SButygin   }
412526b71e4SButygin 
413526b71e4SButygin   rewriter.setInsertionPointToEnd(&beforeBlock);
414526b71e4SButygin   rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
415526b71e4SButygin       cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None);
416526b71e4SButygin 
417526b71e4SButygin   // Convert the scf.yield op to a branch back to the header block.
418526b71e4SButygin   rewriter.setInsertionPointToEnd(&afterBlock);
419526b71e4SButygin   rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
420526b71e4SButygin 
421526b71e4SButygin   rewriter.replaceOp(whileOp, resultValues);
422526b71e4SButygin   return success();
423526b71e4SButygin }
424526b71e4SButygin 
425526b71e4SButygin //===----------------------------------------------------------------------===//
4267c3ae48fSLei Zhang // Hooks
4277c3ae48fSLei Zhang //===----------------------------------------------------------------------===//
4287c3ae48fSLei Zhang 
populateSCFToSPIRVPatterns(SPIRVTypeConverter & typeConverter,ScfToSPIRVContext & scfToSPIRVContext,RewritePatternSet & patterns)4293a506b31SChris Lattner void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
4300670f855SThomas Raoux                                       ScfToSPIRVContext &scfToSPIRVContext,
431dc4e913bSChris Lattner                                       RewritePatternSet &patterns) {
432526b71e4SButygin   patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
433526b71e4SButygin                WhileOpConversion>(patterns.getContext(), typeConverter,
434526b71e4SButygin                                   scfToSPIRVContext.getImpl());
435fbce9855SThomas Raoux }
436