1 //===- TestIntRangeInference.cpp - Create consts from range inference ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // TODO: This pass is needed to test integer range inference until that
9 // functionality has been integrated into SCCP.
10 //===----------------------------------------------------------------------===//
11 
12 #include "mlir/Analysis/IntRangeAnalysis.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassRegistry.h"
16 #include "mlir/Support/TypeID.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 
19 using namespace mlir;
20 
21 /// Patterned after SCCP
22 static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
23                                          OpBuilder &b, OperationFolder &folder,
24                                          Value value) {
25   Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
26   if (!maybeInferredRange)
27     return failure();
28   const ConstantIntRanges &inferredRange = maybeInferredRange.value();
29   Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
30   if (!maybeConstValue.has_value())
31     return failure();
32 
33   Operation *maybeDefiningOp = value.getDefiningOp();
34   Dialect *valueDialect =
35       maybeDefiningOp ? maybeDefiningOp->getDialect()
36                       : value.getParentRegion()->getParentOp()->getDialect();
37   Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
38   Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
39                                               value.getType(), value.getLoc());
40   if (!constant)
41     return failure();
42 
43   value.replaceAllUsesWith(constant);
44   return success();
45 }
46 
47 static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
48                     MutableArrayRef<Region> initialRegions) {
49   SmallVector<Block *> worklist;
50   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
51     for (Region &region : regions)
52       for (Block &block : llvm::reverse(region))
53         worklist.push_back(&block);
54   };
55 
56   OpBuilder builder(context);
57   OperationFolder folder(context);
58 
59   addToWorklist(initialRegions);
60   while (!worklist.empty()) {
61     Block *block = worklist.pop_back_val();
62 
63     for (Operation &op : llvm::make_early_inc_range(*block)) {
64       builder.setInsertionPoint(&op);
65 
66       // Replace any result with constants.
67       bool replacedAll = op.getNumResults() != 0;
68       for (Value res : op.getResults())
69         replacedAll &=
70             succeeded(replaceWithConstant(analysis, builder, folder, res));
71 
72       // If all of the results of the operation were replaced, try to erase
73       // the operation completely.
74       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
75         assert(op.use_empty() && "expected all uses to be replaced");
76         op.erase();
77         continue;
78       }
79 
80       // Add any the regions of this operation to the worklist.
81       addToWorklist(op.getRegions());
82     }
83 
84     // Replace any block arguments with constants.
85     builder.setInsertionPointToStart(block);
86     for (BlockArgument arg : block->getArguments())
87       (void)replaceWithConstant(analysis, builder, folder, arg);
88   }
89 }
90 
91 namespace {
92 struct TestIntRangeInference
93     : PassWrapper<TestIntRangeInference, OperationPass<>> {
94   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
95 
96   StringRef getArgument() const final { return "test-int-range-inference"; }
97   StringRef getDescription() const final {
98     return "Test integer range inference analysis";
99   }
100 
101   void runOnOperation() override {
102     Operation *op = getOperation();
103     IntRangeAnalysis analysis(op);
104     rewrite(analysis, op->getContext(), op->getRegions());
105   }
106 };
107 } // end anonymous namespace
108 
109 namespace mlir {
110 namespace test {
111 void registerTestIntRangeInference() {
112   PassRegistration<TestIntRangeInference>();
113 }
114 } // end namespace test
115 } // end namespace mlir
116