1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/Dialect/SCF/Transforms/Passes.h"
12 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 
15 using namespace mlir;
16 using namespace mlir::scf;
17 
18 namespace {
19 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
20 public:
21   using OpConversionPattern::OpConversionPattern;
22   LogicalResult
matchAndRewrite(ForOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const23   matchAndRewrite(ForOp op, OpAdaptor adaptor,
24                   ConversionPatternRewriter &rewriter) const override {
25     SmallVector<Type, 6> newResultTypes;
26     for (auto type : op.getResultTypes()) {
27       Type newType = typeConverter->convertType(type);
28       if (!newType)
29         return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
30       newResultTypes.push_back(newType);
31     }
32 
33     // Clone the op without the regions and inline the regions from the old op.
34     //
35     // This is a little bit tricky. We have two concerns here:
36     //
37     // 1. We cannot update the op in place because the dialect conversion
38     // framework does not track type changes for ops updated in place, so it
39     // won't insert appropriate materializations on the changed result types.
40     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
41     // clone the op.
42     //
43     // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
44     // inefficient to recursively clone the regions, there is a correctness
45     // issue: if we clone with the regions, then the dialect conversion
46     // framework thinks that we just inserted all the cloned child ops. But what
47     // we want is to "take" the child regions and let the dialect conversion
48     // framework continue recursively into ops inside those regions (which are
49     // already in its worklist; inlining them into the new op's regions doesn't
50     // remove the child ops from the worklist).
51     ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
52     // Take the region from the old op and put it in the new op.
53     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
54                                 newOp.getLoopBody().end());
55 
56     // Now, update all the types.
57 
58     // Convert the type of the entry block of the ForOp's body.
59     if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
60                                            *getTypeConverter()))) {
61       return rewriter.notifyMatchFailure(op, "could not convert body types");
62     }
63     // Change the clone to use the updated operands. We could have cloned with
64     // a BlockAndValueMapping, but this seems a bit more direct.
65     newOp->setOperands(adaptor.getOperands());
66     // Update the result types to the new converted types.
67     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
68       std::get<0>(t).setType(std::get<1>(t));
69 
70     rewriter.replaceOp(op, newOp.getResults());
71     return success();
72   }
73 };
74 } // namespace
75 
76 namespace {
77 class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
78 public:
79   using OpConversionPattern::OpConversionPattern;
80   LogicalResult
matchAndRewrite(IfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const81   matchAndRewrite(IfOp op, OpAdaptor adaptor,
82                   ConversionPatternRewriter &rewriter) const override {
83     // TODO: Generalize this to any type conversion, not just 1:1.
84     //
85     // We need to implement something more sophisticated here that tracks which
86     // types convert to which other types and does the appropriate
87     // materialization logic.
88     // For example, it's possible that one result type converts to 0 types and
89     // another to 2 types, so newResultTypes would at least be the right size to
90     // not crash in the llvm::zip call below, but then we would set the the
91     // wrong type on the SSA values! These edge cases are also why we cannot
92     // safely use the TypeConverter::convertTypes helper here.
93     SmallVector<Type, 6> newResultTypes;
94     for (auto type : op.getResultTypes()) {
95       Type newType = typeConverter->convertType(type);
96       if (!newType)
97         return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
98       newResultTypes.push_back(newType);
99     }
100 
101     // See comments in the ForOp pattern for why we clone without regions and
102     // then inline.
103     IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
104     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
105                                 newOp.getThenRegion().end());
106     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
107                                 newOp.getElseRegion().end());
108 
109     // Update the operands and types.
110     newOp->setOperands(adaptor.getOperands());
111     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
112       std::get<0>(t).setType(std::get<1>(t));
113     rewriter.replaceOp(op, newOp.getResults());
114     return success();
115   }
116 };
117 } // namespace
118 
119 namespace {
120 // When the result types of a ForOp/IfOp get changed, the operand types of the
121 // corresponding yield op need to be changed. In order to trigger the
122 // appropriate type conversions / materializations, we need a dummy pattern.
123 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
124 public:
125   using OpConversionPattern::OpConversionPattern;
126   LogicalResult
matchAndRewrite(scf::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const127   matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
128                   ConversionPatternRewriter &rewriter) const override {
129     rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
130     return success();
131   }
132 };
133 } // namespace
134 
135 namespace {
136 class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
137 public:
138   using OpConversionPattern<WhileOp>::OpConversionPattern;
139 
140   LogicalResult
matchAndRewrite(WhileOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const141   matchAndRewrite(WhileOp op, OpAdaptor adaptor,
142                   ConversionPatternRewriter &rewriter) const override {
143     auto *converter = getTypeConverter();
144     assert(converter);
145     SmallVector<Type> newResultTypes;
146     if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
147       return failure();
148 
149     auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
150                                           adaptor.getOperands());
151     for (auto i : {0u, 1u}) {
152       auto &dstRegion = newOp.getRegion(i);
153       rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
154       if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
155         return rewriter.notifyMatchFailure(op, "could not convert body types");
156     }
157     rewriter.replaceOp(op, newOp.getResults());
158     return success();
159   }
160 };
161 } // namespace
162 
163 namespace {
164 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
165 public:
166   using OpConversionPattern<ConditionOp>::OpConversionPattern;
167   LogicalResult
matchAndRewrite(ConditionOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const168   matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
169                   ConversionPatternRewriter &rewriter) const override {
170     rewriter.updateRootInPlace(
171         op, [&]() { op->setOperands(adaptor.getOperands()); });
172     return success();
173   }
174 };
175 } // namespace
176 
populateSCFStructuralTypeConversionsAndLegality(TypeConverter & typeConverter,RewritePatternSet & patterns,ConversionTarget & target)177 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
178     TypeConverter &typeConverter, RewritePatternSet &patterns,
179     ConversionTarget &target) {
180   patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
181                ConvertWhileOpTypes, ConvertConditionOpTypes>(
182       typeConverter, patterns.getContext());
183   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
184     return typeConverter.isLegal(op->getResultTypes());
185   });
186   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
187     // We only have conversions for a subset of ops that use scf.yield
188     // terminators.
189     if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
190       return true;
191     return typeConverter.isLegal(op.getOperandTypes());
192   });
193   target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
194       [&](Operation *op) { return typeConverter.isLegal(op); });
195 }
196