1*ace01605SRiver Riddle //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2*ace01605SRiver Riddle //
3*ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ace01605SRiver Riddle //
7*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8*ace01605SRiver Riddle //
// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10*ace01605SRiver Riddle //
11*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
12*ace01605SRiver Riddle 
13*ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14*ace01605SRiver Riddle #include "../SPIRVCommon/Pattern.h"
15*ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16*ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17*ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18*ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19*ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20*ace01605SRiver Riddle #include "mlir/IR/AffineMap.h"
21*ace01605SRiver Riddle #include "mlir/Support/LogicalResult.h"
22*ace01605SRiver Riddle #include "llvm/ADT/SetVector.h"
23*ace01605SRiver Riddle #include "llvm/Support/Debug.h"
24*ace01605SRiver Riddle 
25*ace01605SRiver Riddle #define DEBUG_TYPE "cf-to-spirv-pattern"
26*ace01605SRiver Riddle 
27*ace01605SRiver Riddle using namespace mlir;
28*ace01605SRiver Riddle 
29*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
30*ace01605SRiver Riddle // Operation conversion
31*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
32*ace01605SRiver Riddle 
33*ace01605SRiver Riddle namespace {
34*ace01605SRiver Riddle 
35*ace01605SRiver Riddle /// Converts cf.br to spv.Branch.
36*ace01605SRiver Riddle struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
37*ace01605SRiver Riddle   using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
38*ace01605SRiver Riddle 
39*ace01605SRiver Riddle   LogicalResult
matchAndRewrite__anon110ecafd0111::BranchOpPattern40*ace01605SRiver Riddle   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
41*ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
42*ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
43*ace01605SRiver Riddle                                                  adaptor.getDestOperands());
44*ace01605SRiver Riddle     return success();
45*ace01605SRiver Riddle   }
46*ace01605SRiver Riddle };
47*ace01605SRiver Riddle 
48*ace01605SRiver Riddle /// Converts cf.cond_br to spv.BranchConditional.
49*ace01605SRiver Riddle struct CondBranchOpPattern final
50*ace01605SRiver Riddle     : public OpConversionPattern<cf::CondBranchOp> {
51*ace01605SRiver Riddle   using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
52*ace01605SRiver Riddle 
53*ace01605SRiver Riddle   LogicalResult
matchAndRewrite__anon110ecafd0111::CondBranchOpPattern54*ace01605SRiver Riddle   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
55*ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
56*ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
57*ace01605SRiver Riddle         op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
58*ace01605SRiver Riddle         op.getFalseDest(), adaptor.getFalseDestOperands());
59*ace01605SRiver Riddle     return success();
60*ace01605SRiver Riddle   }
61*ace01605SRiver Riddle };
62*ace01605SRiver Riddle } // namespace
63*ace01605SRiver Riddle 
64*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
65*ace01605SRiver Riddle // Pattern population
66*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
67*ace01605SRiver Riddle 
populateControlFlowToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)68*ace01605SRiver Riddle void mlir::cf::populateControlFlowToSPIRVPatterns(
69*ace01605SRiver Riddle     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
70*ace01605SRiver Riddle   MLIRContext *context = patterns.getContext();
71*ace01605SRiver Riddle 
72*ace01605SRiver Riddle   patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
73*ace01605SRiver Riddle }
74