1 //===- TestIntRangeInference.cpp - Create consts from range inference ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // TODO: This pass is needed to test integer range inference until that 9 // functionality has been integrated into SCCP. 10 //===----------------------------------------------------------------------===// 11 12 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 13 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 14 #include "mlir/Interfaces/SideEffectInterfaces.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassRegistry.h" 17 #include "mlir/Support/TypeID.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 20 using namespace mlir; 21 using namespace mlir::dataflow; 22 23 /// Patterned after SCCP 24 static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, 25 OperationFolder &folder, Value value) { 26 auto *maybeInferredRange = 27 solver.lookupState<IntegerValueRangeLattice>(value); 28 if (!maybeInferredRange || maybeInferredRange->isUninitialized()) 29 return failure(); 30 const ConstantIntRanges &inferredRange = 31 maybeInferredRange->getValue().getValue(); 32 Optional<APInt> maybeConstValue = inferredRange.getConstantValue(); 33 if (!maybeConstValue.hasValue()) 34 return failure(); 35 36 Operation *maybeDefiningOp = value.getDefiningOp(); 37 Dialect *valueDialect = 38 maybeDefiningOp ? maybeDefiningOp->getDialect() 39 : value.getParentRegion()->getParentOp()->getDialect(); 40 Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); 41 Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, 42 value.getType(), value.getLoc()); 43 if (!constant) 44 return failure(); 45 46 value.replaceAllUsesWith(constant); 47 return success(); 48 } 49 50 static void rewrite(DataFlowSolver &solver, MLIRContext *context, 51 MutableArrayRef<Region> initialRegions) { 52 SmallVector<Block *> worklist; 53 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 54 for (Region ®ion : regions) 55 for (Block &block : llvm::reverse(region)) 56 worklist.push_back(&block); 57 }; 58 59 OpBuilder builder(context); 60 OperationFolder folder(context); 61 62 addToWorklist(initialRegions); 63 while (!worklist.empty()) { 64 Block *block = worklist.pop_back_val(); 65 66 for (Operation &op : llvm::make_early_inc_range(*block)) { 67 builder.setInsertionPoint(&op); 68 69 // Replace any result with constants. 70 bool replacedAll = op.getNumResults() != 0; 71 for (Value res : op.getResults()) 72 replacedAll &= 73 succeeded(replaceWithConstant(solver, builder, folder, res)); 74 75 // If all of the results of the operation were replaced, try to erase 76 // the operation completely. 77 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 78 assert(op.use_empty() && "expected all uses to be replaced"); 79 op.erase(); 80 continue; 81 } 82 83 // Add any the regions of this operation to the worklist. 84 addToWorklist(op.getRegions()); 85 } 86 87 // Replace any block arguments with constants. 88 builder.setInsertionPointToStart(block); 89 for (BlockArgument arg : block->getArguments()) 90 (void)replaceWithConstant(solver, builder, folder, arg); 91 } 92 } 93 94 namespace { 95 struct TestIntRangeInference 96 : PassWrapper<TestIntRangeInference, OperationPass<>> { 97 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) 98 99 StringRef getArgument() const final { return "test-int-range-inference"; } 100 StringRef getDescription() const final { 101 return "Test integer range inference analysis"; 102 } 103 104 void runOnOperation() override { 105 Operation *op = getOperation(); 106 DataFlowSolver solver; 107 solver.load<DeadCodeAnalysis>(); 108 solver.load<IntegerRangeAnalysis>(); 109 if (failed(solver.initializeAndRun(op))) 110 return signalPassFailure(); 111 rewrite(solver, op->getContext(), op->getRegions()); 112 } 113 }; 114 } // end anonymous namespace 115 116 namespace mlir { 117 namespace test { 118 void registerTestIntRangeInference() { 119 PassRegistration<TestIntRangeInference>(); 120 } 121 } // end namespace test 122 } // end namespace mlir 123