1 //===- TestIntRangeInference.cpp - Create consts from range inference ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // TODO: This pass is needed to test integer range inference until that
9 // functionality has been integrated into SCCP.
10 //===----------------------------------------------------------------------===//
11
12 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
13 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassRegistry.h"
17 #include "mlir/Support/TypeID.h"
18 #include "mlir/Transforms/FoldUtils.h"
19
20 using namespace mlir;
21 using namespace mlir::dataflow;
22
23 /// Patterned after SCCP
replaceWithConstant(DataFlowSolver & solver,OpBuilder & b,OperationFolder & folder,Value value)24 static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
25 OperationFolder &folder, Value value) {
26 auto *maybeInferredRange =
27 solver.lookupState<IntegerValueRangeLattice>(value);
28 if (!maybeInferredRange || maybeInferredRange->isUninitialized())
29 return failure();
30 const ConstantIntRanges &inferredRange =
31 maybeInferredRange->getValue().getValue();
32 Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
33 if (!maybeConstValue.has_value())
34 return failure();
35
36 Operation *maybeDefiningOp = value.getDefiningOp();
37 Dialect *valueDialect =
38 maybeDefiningOp ? maybeDefiningOp->getDialect()
39 : value.getParentRegion()->getParentOp()->getDialect();
40 Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
41 Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
42 value.getType(), value.getLoc());
43 if (!constant)
44 return failure();
45
46 value.replaceAllUsesWith(constant);
47 return success();
48 }
49
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)50 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
51 MutableArrayRef<Region> initialRegions) {
52 SmallVector<Block *> worklist;
53 auto addToWorklist = [&](MutableArrayRef<Region> regions) {
54 for (Region ®ion : regions)
55 for (Block &block : llvm::reverse(region))
56 worklist.push_back(&block);
57 };
58
59 OpBuilder builder(context);
60 OperationFolder folder(context);
61
62 addToWorklist(initialRegions);
63 while (!worklist.empty()) {
64 Block *block = worklist.pop_back_val();
65
66 for (Operation &op : llvm::make_early_inc_range(*block)) {
67 builder.setInsertionPoint(&op);
68
69 // Replace any result with constants.
70 bool replacedAll = op.getNumResults() != 0;
71 for (Value res : op.getResults())
72 replacedAll &=
73 succeeded(replaceWithConstant(solver, builder, folder, res));
74
75 // If all of the results of the operation were replaced, try to erase
76 // the operation completely.
77 if (replacedAll && wouldOpBeTriviallyDead(&op)) {
78 assert(op.use_empty() && "expected all uses to be replaced");
79 op.erase();
80 continue;
81 }
82
83 // Add any the regions of this operation to the worklist.
84 addToWorklist(op.getRegions());
85 }
86
87 // Replace any block arguments with constants.
88 builder.setInsertionPointToStart(block);
89 for (BlockArgument arg : block->getArguments())
90 (void)replaceWithConstant(solver, builder, folder, arg);
91 }
92 }
93
94 namespace {
95 struct TestIntRangeInference
96 : PassWrapper<TestIntRangeInference, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonee8f13ed0211::TestIntRangeInference97 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
98
99 StringRef getArgument() const final { return "test-int-range-inference"; }
getDescription__anonee8f13ed0211::TestIntRangeInference100 StringRef getDescription() const final {
101 return "Test integer range inference analysis";
102 }
103
runOnOperation__anonee8f13ed0211::TestIntRangeInference104 void runOnOperation() override {
105 Operation *op = getOperation();
106 DataFlowSolver solver;
107 solver.load<DeadCodeAnalysis>();
108 solver.load<IntegerRangeAnalysis>();
109 if (failed(solver.initializeAndRun(op)))
110 return signalPassFailure();
111 rewrite(solver, op->getContext(), op->getRegions());
112 }
113 };
114 } // end anonymous namespace
115
116 namespace mlir {
117 namespace test {
registerTestIntRangeInference()118 void registerTestIntRangeInference() {
119 PassRegistration<TestIntRangeInference>();
120 }
121 } // end namespace test
122 } // end namespace mlir
123