1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/Dialect/SCF/Transforms/Passes.h"
12 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
13 #include "mlir/Transforms/DialectConversion.h"
14
15 using namespace mlir;
16 using namespace mlir::scf;
17
18 namespace {
19 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
20 public:
21 using OpConversionPattern::OpConversionPattern;
22 LogicalResult
matchAndRewrite(ForOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const23 matchAndRewrite(ForOp op, OpAdaptor adaptor,
24 ConversionPatternRewriter &rewriter) const override {
25 SmallVector<Type, 6> newResultTypes;
26 for (auto type : op.getResultTypes()) {
27 Type newType = typeConverter->convertType(type);
28 if (!newType)
29 return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
30 newResultTypes.push_back(newType);
31 }
32
33 // Clone the op without the regions and inline the regions from the old op.
34 //
35 // This is a little bit tricky. We have two concerns here:
36 //
37 // 1. We cannot update the op in place because the dialect conversion
38 // framework does not track type changes for ops updated in place, so it
39 // won't insert appropriate materializations on the changed result types.
40 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
41 // clone the op.
42 //
43 // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
44 // inefficient to recursively clone the regions, there is a correctness
45 // issue: if we clone with the regions, then the dialect conversion
46 // framework thinks that we just inserted all the cloned child ops. But what
47 // we want is to "take" the child regions and let the dialect conversion
48 // framework continue recursively into ops inside those regions (which are
49 // already in its worklist; inlining them into the new op's regions doesn't
50 // remove the child ops from the worklist).
51 ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
52 // Take the region from the old op and put it in the new op.
53 rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
54 newOp.getLoopBody().end());
55
56 // Now, update all the types.
57
58 // Convert the type of the entry block of the ForOp's body.
59 if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
60 *getTypeConverter()))) {
61 return rewriter.notifyMatchFailure(op, "could not convert body types");
62 }
63 // Change the clone to use the updated operands. We could have cloned with
64 // a BlockAndValueMapping, but this seems a bit more direct.
65 newOp->setOperands(adaptor.getOperands());
66 // Update the result types to the new converted types.
67 for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
68 std::get<0>(t).setType(std::get<1>(t));
69
70 rewriter.replaceOp(op, newOp.getResults());
71 return success();
72 }
73 };
74 } // namespace
75
76 namespace {
77 class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
78 public:
79 using OpConversionPattern::OpConversionPattern;
80 LogicalResult
matchAndRewrite(IfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const81 matchAndRewrite(IfOp op, OpAdaptor adaptor,
82 ConversionPatternRewriter &rewriter) const override {
83 // TODO: Generalize this to any type conversion, not just 1:1.
84 //
85 // We need to implement something more sophisticated here that tracks which
86 // types convert to which other types and does the appropriate
87 // materialization logic.
88 // For example, it's possible that one result type converts to 0 types and
89 // another to 2 types, so newResultTypes would at least be the right size to
90 // not crash in the llvm::zip call below, but then we would set the the
91 // wrong type on the SSA values! These edge cases are also why we cannot
92 // safely use the TypeConverter::convertTypes helper here.
93 SmallVector<Type, 6> newResultTypes;
94 for (auto type : op.getResultTypes()) {
95 Type newType = typeConverter->convertType(type);
96 if (!newType)
97 return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
98 newResultTypes.push_back(newType);
99 }
100
101 // See comments in the ForOp pattern for why we clone without regions and
102 // then inline.
103 IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
104 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
105 newOp.getThenRegion().end());
106 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
107 newOp.getElseRegion().end());
108
109 // Update the operands and types.
110 newOp->setOperands(adaptor.getOperands());
111 for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
112 std::get<0>(t).setType(std::get<1>(t));
113 rewriter.replaceOp(op, newOp.getResults());
114 return success();
115 }
116 };
117 } // namespace
118
119 namespace {
120 // When the result types of a ForOp/IfOp get changed, the operand types of the
121 // corresponding yield op need to be changed. In order to trigger the
122 // appropriate type conversions / materializations, we need a dummy pattern.
123 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
124 public:
125 using OpConversionPattern::OpConversionPattern;
126 LogicalResult
matchAndRewrite(scf::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const127 matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
130 return success();
131 }
132 };
133 } // namespace
134
135 namespace {
136 class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
137 public:
138 using OpConversionPattern<WhileOp>::OpConversionPattern;
139
140 LogicalResult
matchAndRewrite(WhileOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const141 matchAndRewrite(WhileOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter) const override {
143 auto *converter = getTypeConverter();
144 assert(converter);
145 SmallVector<Type> newResultTypes;
146 if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
147 return failure();
148
149 auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
150 adaptor.getOperands());
151 for (auto i : {0u, 1u}) {
152 auto &dstRegion = newOp.getRegion(i);
153 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
154 if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
155 return rewriter.notifyMatchFailure(op, "could not convert body types");
156 }
157 rewriter.replaceOp(op, newOp.getResults());
158 return success();
159 }
160 };
161 } // namespace
162
163 namespace {
164 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
165 public:
166 using OpConversionPattern<ConditionOp>::OpConversionPattern;
167 LogicalResult
matchAndRewrite(ConditionOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const168 matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter) const override {
170 rewriter.updateRootInPlace(
171 op, [&]() { op->setOperands(adaptor.getOperands()); });
172 return success();
173 }
174 };
175 } // namespace
176
populateSCFStructuralTypeConversionsAndLegality(TypeConverter & typeConverter,RewritePatternSet & patterns,ConversionTarget & target)177 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
178 TypeConverter &typeConverter, RewritePatternSet &patterns,
179 ConversionTarget &target) {
180 patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
181 ConvertWhileOpTypes, ConvertConditionOpTypes>(
182 typeConverter, patterns.getContext());
183 target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
184 return typeConverter.isLegal(op->getResultTypes());
185 });
186 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
187 // We only have conversions for a subset of ops that use scf.yield
188 // terminators.
189 if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
190 return true;
191 return typeConverter.isLegal(op.getOperandTypes());
192 });
193 target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
194 [&](Operation *op) { return typeConverter.isLegal(op); });
195 }
196