1*95aff23eSKrzysztof Drewniak //===- TestIntRangeInference.cpp - Create consts from range inference ---===//
2*95aff23eSKrzysztof Drewniak //
3*95aff23eSKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*95aff23eSKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
5*95aff23eSKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*95aff23eSKrzysztof Drewniak //
7*95aff23eSKrzysztof Drewniak //===----------------------------------------------------------------------===//
8*95aff23eSKrzysztof Drewniak // TODO: This pass is needed to test integer range inference until that
9*95aff23eSKrzysztof Drewniak // functionality has been integrated into SCCP.
10*95aff23eSKrzysztof Drewniak //===----------------------------------------------------------------------===//
11*95aff23eSKrzysztof Drewniak 
12*95aff23eSKrzysztof Drewniak #include "mlir/Analysis/IntRangeAnalysis.h"
13*95aff23eSKrzysztof Drewniak #include "mlir/Interfaces/SideEffectInterfaces.h"
14*95aff23eSKrzysztof Drewniak #include "mlir/Pass/Pass.h"
15*95aff23eSKrzysztof Drewniak #include "mlir/Pass/PassRegistry.h"
16*95aff23eSKrzysztof Drewniak #include "mlir/Support/TypeID.h"
17*95aff23eSKrzysztof Drewniak #include "mlir/Transforms/FoldUtils.h"
18*95aff23eSKrzysztof Drewniak 
19*95aff23eSKrzysztof Drewniak using namespace mlir;
20*95aff23eSKrzysztof Drewniak 
21*95aff23eSKrzysztof Drewniak /// Patterned after SCCP
22*95aff23eSKrzysztof Drewniak static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
23*95aff23eSKrzysztof Drewniak                                          OpBuilder &b, OperationFolder &folder,
24*95aff23eSKrzysztof Drewniak                                          Value value) {
25*95aff23eSKrzysztof Drewniak   Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
26*95aff23eSKrzysztof Drewniak   if (!maybeInferredRange)
27*95aff23eSKrzysztof Drewniak     return failure();
28*95aff23eSKrzysztof Drewniak   const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
29*95aff23eSKrzysztof Drewniak   Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
30*95aff23eSKrzysztof Drewniak   if (!maybeConstValue.hasValue())
31*95aff23eSKrzysztof Drewniak     return failure();
32*95aff23eSKrzysztof Drewniak 
33*95aff23eSKrzysztof Drewniak   Operation *maybeDefiningOp = value.getDefiningOp();
34*95aff23eSKrzysztof Drewniak   Dialect *valueDialect =
35*95aff23eSKrzysztof Drewniak       maybeDefiningOp ? maybeDefiningOp->getDialect()
36*95aff23eSKrzysztof Drewniak                       : value.getParentRegion()->getParentOp()->getDialect();
37*95aff23eSKrzysztof Drewniak   Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
38*95aff23eSKrzysztof Drewniak   Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
39*95aff23eSKrzysztof Drewniak                                               value.getType(), value.getLoc());
40*95aff23eSKrzysztof Drewniak   if (!constant)
41*95aff23eSKrzysztof Drewniak     return failure();
42*95aff23eSKrzysztof Drewniak 
43*95aff23eSKrzysztof Drewniak   value.replaceAllUsesWith(constant);
44*95aff23eSKrzysztof Drewniak   return success();
45*95aff23eSKrzysztof Drewniak }
46*95aff23eSKrzysztof Drewniak 
47*95aff23eSKrzysztof Drewniak static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
48*95aff23eSKrzysztof Drewniak                     MutableArrayRef<Region> initialRegions) {
49*95aff23eSKrzysztof Drewniak   SmallVector<Block *> worklist;
50*95aff23eSKrzysztof Drewniak   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
51*95aff23eSKrzysztof Drewniak     for (Region &region : regions)
52*95aff23eSKrzysztof Drewniak       for (Block &block : llvm::reverse(region))
53*95aff23eSKrzysztof Drewniak         worklist.push_back(&block);
54*95aff23eSKrzysztof Drewniak   };
55*95aff23eSKrzysztof Drewniak 
56*95aff23eSKrzysztof Drewniak   OpBuilder builder(context);
57*95aff23eSKrzysztof Drewniak   OperationFolder folder(context);
58*95aff23eSKrzysztof Drewniak 
59*95aff23eSKrzysztof Drewniak   addToWorklist(initialRegions);
60*95aff23eSKrzysztof Drewniak   while (!worklist.empty()) {
61*95aff23eSKrzysztof Drewniak     Block *block = worklist.pop_back_val();
62*95aff23eSKrzysztof Drewniak 
63*95aff23eSKrzysztof Drewniak     for (Operation &op : llvm::make_early_inc_range(*block)) {
64*95aff23eSKrzysztof Drewniak       builder.setInsertionPoint(&op);
65*95aff23eSKrzysztof Drewniak 
66*95aff23eSKrzysztof Drewniak       // Replace any result with constants.
67*95aff23eSKrzysztof Drewniak       bool replacedAll = op.getNumResults() != 0;
68*95aff23eSKrzysztof Drewniak       for (Value res : op.getResults())
69*95aff23eSKrzysztof Drewniak         replacedAll &=
70*95aff23eSKrzysztof Drewniak             succeeded(replaceWithConstant(analysis, builder, folder, res));
71*95aff23eSKrzysztof Drewniak 
72*95aff23eSKrzysztof Drewniak       // If all of the results of the operation were replaced, try to erase
73*95aff23eSKrzysztof Drewniak       // the operation completely.
74*95aff23eSKrzysztof Drewniak       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
75*95aff23eSKrzysztof Drewniak         assert(op.use_empty() && "expected all uses to be replaced");
76*95aff23eSKrzysztof Drewniak         op.erase();
77*95aff23eSKrzysztof Drewniak         continue;
78*95aff23eSKrzysztof Drewniak       }
79*95aff23eSKrzysztof Drewniak 
80*95aff23eSKrzysztof Drewniak       // Add any the regions of this operation to the worklist.
81*95aff23eSKrzysztof Drewniak       addToWorklist(op.getRegions());
82*95aff23eSKrzysztof Drewniak     }
83*95aff23eSKrzysztof Drewniak 
84*95aff23eSKrzysztof Drewniak     // Replace any block arguments with constants.
85*95aff23eSKrzysztof Drewniak     builder.setInsertionPointToStart(block);
86*95aff23eSKrzysztof Drewniak     for (BlockArgument arg : block->getArguments())
87*95aff23eSKrzysztof Drewniak       (void)replaceWithConstant(analysis, builder, folder, arg);
88*95aff23eSKrzysztof Drewniak   }
89*95aff23eSKrzysztof Drewniak }
90*95aff23eSKrzysztof Drewniak 
91*95aff23eSKrzysztof Drewniak namespace {
92*95aff23eSKrzysztof Drewniak struct TestIntRangeInference
93*95aff23eSKrzysztof Drewniak     : PassWrapper<TestIntRangeInference, OperationPass<>> {
94*95aff23eSKrzysztof Drewniak   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
95*95aff23eSKrzysztof Drewniak 
96*95aff23eSKrzysztof Drewniak   StringRef getArgument() const final { return "test-int-range-inference"; }
97*95aff23eSKrzysztof Drewniak   StringRef getDescription() const final {
98*95aff23eSKrzysztof Drewniak     return "Test integer range inference analysis";
99*95aff23eSKrzysztof Drewniak   }
100*95aff23eSKrzysztof Drewniak 
101*95aff23eSKrzysztof Drewniak   void runOnOperation() override {
102*95aff23eSKrzysztof Drewniak     Operation *op = getOperation();
103*95aff23eSKrzysztof Drewniak     IntRangeAnalysis analysis(op);
104*95aff23eSKrzysztof Drewniak     rewrite(analysis, op->getContext(), op->getRegions());
105*95aff23eSKrzysztof Drewniak   }
106*95aff23eSKrzysztof Drewniak };
107*95aff23eSKrzysztof Drewniak } // end anonymous namespace
108*95aff23eSKrzysztof Drewniak 
109*95aff23eSKrzysztof Drewniak namespace mlir {
110*95aff23eSKrzysztof Drewniak namespace test {
111*95aff23eSKrzysztof Drewniak void registerTestIntRangeInference() {
112*95aff23eSKrzysztof Drewniak   PassRegistration<TestIntRangeInference>();
113*95aff23eSKrzysztof Drewniak }
114*95aff23eSKrzysztof Drewniak } // end namespace test
115*95aff23eSKrzysztof Drewniak } // end namespace mlir
116