1 //===- TestIntRangeInference.cpp - Create consts from range inference ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // TODO: This pass is needed to test integer range inference until that 9 // functionality has been integrated into SCCP. 10 //===----------------------------------------------------------------------===// 11 12 #include "mlir/Analysis/IntRangeAnalysis.h" 13 #include "mlir/Interfaces/SideEffectInterfaces.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Pass/PassRegistry.h" 16 #include "mlir/Support/TypeID.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 19 using namespace mlir; 20 21 /// Patterned after SCCP 22 static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, 23 OpBuilder &b, OperationFolder &folder, 24 Value value) { 25 Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value); 26 if (!maybeInferredRange) 27 return failure(); 28 const ConstantIntRanges &inferredRange = maybeInferredRange.getValue(); 29 Optional<APInt> maybeConstValue = inferredRange.getConstantValue(); 30 if (!maybeConstValue.hasValue()) 31 return failure(); 32 33 Operation *maybeDefiningOp = value.getDefiningOp(); 34 Dialect *valueDialect = 35 maybeDefiningOp ? maybeDefiningOp->getDialect() 36 : value.getParentRegion()->getParentOp()->getDialect(); 37 Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); 38 Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, 39 value.getType(), value.getLoc()); 40 if (!constant) 41 return failure(); 42 43 value.replaceAllUsesWith(constant); 44 return success(); 45 } 46 47 static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context, 48 MutableArrayRef<Region> initialRegions) { 49 SmallVector<Block *> worklist; 50 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 51 for (Region ®ion : regions) 52 for (Block &block : llvm::reverse(region)) 53 worklist.push_back(&block); 54 }; 55 56 OpBuilder builder(context); 57 OperationFolder folder(context); 58 59 addToWorklist(initialRegions); 60 while (!worklist.empty()) { 61 Block *block = worklist.pop_back_val(); 62 63 for (Operation &op : llvm::make_early_inc_range(*block)) { 64 builder.setInsertionPoint(&op); 65 66 // Replace any result with constants. 67 bool replacedAll = op.getNumResults() != 0; 68 for (Value res : op.getResults()) 69 replacedAll &= 70 succeeded(replaceWithConstant(analysis, builder, folder, res)); 71 72 // If all of the results of the operation were replaced, try to erase 73 // the operation completely. 74 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 75 assert(op.use_empty() && "expected all uses to be replaced"); 76 op.erase(); 77 continue; 78 } 79 80 // Add any the regions of this operation to the worklist. 81 addToWorklist(op.getRegions()); 82 } 83 84 // Replace any block arguments with constants. 85 builder.setInsertionPointToStart(block); 86 for (BlockArgument arg : block->getArguments()) 87 (void)replaceWithConstant(analysis, builder, folder, arg); 88 } 89 } 90 91 namespace { 92 struct TestIntRangeInference 93 : PassWrapper<TestIntRangeInference, OperationPass<>> { 94 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) 95 96 StringRef getArgument() const final { return "test-int-range-inference"; } 97 StringRef getDescription() const final { 98 return "Test integer range inference analysis"; 99 } 100 101 void runOnOperation() override { 102 Operation *op = getOperation(); 103 IntRangeAnalysis analysis(op); 104 rewrite(analysis, op->getContext(), op->getRegions()); 105 } 106 }; 107 } // end anonymous namespace 108 109 namespace mlir { 110 namespace test { 111 void registerTestIntRangeInference() { 112 PassRegistration<TestIntRangeInference>(); 113 } 114 } // end namespace test 115 } // end namespace mlir 116