1 //===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
10 #include "../PassDetail.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 /// Removes `unrealized_conversion_cast`s whose results are only used by other
21 /// `unrealized_conversion_cast`s converting back to the original type. This
22 /// pattern is complementary to the folder and can be used to process operations
23 /// starting from the first, i.e. the usual traversal order in dialect
24 /// conversion. The folder, on the other hand, can only apply to the last
25 /// operation in a chain of conversions because it is not expected to walk
26 /// use-def chains. One would need to declare cast ops as dynamically illegal
27 /// with a complex condition in order to eliminate them using the folder alone
28 /// in the dialect conversion infra.
29 struct UnrealizedConversionCastPassthrough
30     : public OpRewritePattern<UnrealizedConversionCastOp> {
31   using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
32 
matchAndRewrite__anon5a9f2e870111::UnrealizedConversionCastPassthrough33   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
34                                 PatternRewriter &rewriter) const override {
35     // Match the casts that are _only_ used by other casts, with the overall
36     // cast being a trivial noop: A->B->A.
37     auto users = op->getUsers();
38     if (!llvm::all_of(users, [&](Operation *user) {
39           if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
40             return other.getResultTypes() == op.getInputs().getTypes() &&
41                    other.getInputs() == op.getOutputs();
42           return false;
43         })) {
44       return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
45     }
46 
47     for (Operation *user : users)
48       rewriter.replaceOp(user, op.getInputs());
49 
50     rewriter.eraseOp(op);
51     return success();
52   }
53 };
54 
55 /// Pass to simplify and eliminate unrealized conversion casts.
56 struct ReconcileUnrealizedCasts
57     : public ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
58   ReconcileUnrealizedCasts() = default;
59 
runOnOperation__anon5a9f2e870111::ReconcileUnrealizedCasts60   void runOnOperation() override {
61     RewritePatternSet patterns(&getContext());
62     populateReconcileUnrealizedCastsPatterns(patterns);
63     ConversionTarget target(getContext());
64     target.addIllegalOp<UnrealizedConversionCastOp>();
65     if (failed(applyPartialConversion(getOperation(), target,
66                                       std::move(patterns))))
67       signalPassFailure();
68   }
69 };
70 
71 } // namespace
72 
populateReconcileUnrealizedCastsPatterns(RewritePatternSet & patterns)73 void mlir::populateReconcileUnrealizedCastsPatterns(
74     RewritePatternSet &patterns) {
75   patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
76 }
77 
createReconcileUnrealizedCastsPass()78 std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
79   return std::make_unique<ReconcileUnrealizedCasts>();
80 }
81