1 //===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
10 #include "../PassDetail.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15
16 using namespace mlir;
17
18 namespace {
19
20 /// Removes `unrealized_conversion_cast`s whose results are only used by other
21 /// `unrealized_conversion_cast`s converting back to the original type. This
22 /// pattern is complementary to the folder and can be used to process operations
23 /// starting from the first, i.e. the usual traversal order in dialect
24 /// conversion. The folder, on the other hand, can only apply to the last
25 /// operation in a chain of conversions because it is not expected to walk
26 /// use-def chains. One would need to declare cast ops as dynamically illegal
27 /// with a complex condition in order to eliminate them using the folder alone
28 /// in the dialect conversion infra.
29 struct UnrealizedConversionCastPassthrough
30 : public OpRewritePattern<UnrealizedConversionCastOp> {
31 using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
32
matchAndRewrite__anon5a9f2e870111::UnrealizedConversionCastPassthrough33 LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
34 PatternRewriter &rewriter) const override {
35 // Match the casts that are _only_ used by other casts, with the overall
36 // cast being a trivial noop: A->B->A.
37 auto users = op->getUsers();
38 if (!llvm::all_of(users, [&](Operation *user) {
39 if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
40 return other.getResultTypes() == op.getInputs().getTypes() &&
41 other.getInputs() == op.getOutputs();
42 return false;
43 })) {
44 return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
45 }
46
47 for (Operation *user : users)
48 rewriter.replaceOp(user, op.getInputs());
49
50 rewriter.eraseOp(op);
51 return success();
52 }
53 };
54
55 /// Pass to simplify and eliminate unrealized conversion casts.
56 struct ReconcileUnrealizedCasts
57 : public ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
58 ReconcileUnrealizedCasts() = default;
59
runOnOperation__anon5a9f2e870111::ReconcileUnrealizedCasts60 void runOnOperation() override {
61 RewritePatternSet patterns(&getContext());
62 populateReconcileUnrealizedCastsPatterns(patterns);
63 ConversionTarget target(getContext());
64 target.addIllegalOp<UnrealizedConversionCastOp>();
65 if (failed(applyPartialConversion(getOperation(), target,
66 std::move(patterns))))
67 signalPassFailure();
68 }
69 };
70
71 } // namespace
72
populateReconcileUnrealizedCastsPatterns(RewritePatternSet & patterns)73 void mlir::populateReconcileUnrealizedCastsPatterns(
74 RewritePatternSet &patterns) {
75 patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
76 }
77
createReconcileUnrealizedCastsPass()78 std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
79 return std::make_unique<ReconcileUnrealizedCasts>();
80 }
81