195aff23eSKrzysztof Drewniak //===- TestIntRangeInference.cpp - Create consts from range inference ---===//
295aff23eSKrzysztof Drewniak //
395aff23eSKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
495aff23eSKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
595aff23eSKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
695aff23eSKrzysztof Drewniak //
795aff23eSKrzysztof Drewniak //===----------------------------------------------------------------------===//
895aff23eSKrzysztof Drewniak // TODO: This pass is needed to test integer range inference until that
995aff23eSKrzysztof Drewniak // functionality has been integrated into SCCP.
1095aff23eSKrzysztof Drewniak //===----------------------------------------------------------------------===//
1195aff23eSKrzysztof Drewniak 
1295aff23eSKrzysztof Drewniak #include "mlir/Analysis/IntRangeAnalysis.h"
1395aff23eSKrzysztof Drewniak #include "mlir/Interfaces/SideEffectInterfaces.h"
1495aff23eSKrzysztof Drewniak #include "mlir/Pass/Pass.h"
1595aff23eSKrzysztof Drewniak #include "mlir/Pass/PassRegistry.h"
1695aff23eSKrzysztof Drewniak #include "mlir/Support/TypeID.h"
1795aff23eSKrzysztof Drewniak #include "mlir/Transforms/FoldUtils.h"
1895aff23eSKrzysztof Drewniak 
1995aff23eSKrzysztof Drewniak using namespace mlir;
2095aff23eSKrzysztof Drewniak 
2195aff23eSKrzysztof Drewniak /// Patterned after SCCP
2295aff23eSKrzysztof Drewniak static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
2395aff23eSKrzysztof Drewniak                                          OpBuilder &b, OperationFolder &folder,
2495aff23eSKrzysztof Drewniak                                          Value value) {
2595aff23eSKrzysztof Drewniak   Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
2695aff23eSKrzysztof Drewniak   if (!maybeInferredRange)
2795aff23eSKrzysztof Drewniak     return failure();
28*3b7c3a65SKazu Hirata   const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
2995aff23eSKrzysztof Drewniak   Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
30*3b7c3a65SKazu Hirata   if (!maybeConstValue.hasValue())
3195aff23eSKrzysztof Drewniak     return failure();
3295aff23eSKrzysztof Drewniak 
3395aff23eSKrzysztof Drewniak   Operation *maybeDefiningOp = value.getDefiningOp();
3495aff23eSKrzysztof Drewniak   Dialect *valueDialect =
3595aff23eSKrzysztof Drewniak       maybeDefiningOp ? maybeDefiningOp->getDialect()
3695aff23eSKrzysztof Drewniak                       : value.getParentRegion()->getParentOp()->getDialect();
3795aff23eSKrzysztof Drewniak   Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
3895aff23eSKrzysztof Drewniak   Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
3995aff23eSKrzysztof Drewniak                                               value.getType(), value.getLoc());
4095aff23eSKrzysztof Drewniak   if (!constant)
4195aff23eSKrzysztof Drewniak     return failure();
4295aff23eSKrzysztof Drewniak 
4395aff23eSKrzysztof Drewniak   value.replaceAllUsesWith(constant);
4495aff23eSKrzysztof Drewniak   return success();
4595aff23eSKrzysztof Drewniak }
4695aff23eSKrzysztof Drewniak 
4795aff23eSKrzysztof Drewniak static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
4895aff23eSKrzysztof Drewniak                     MutableArrayRef<Region> initialRegions) {
4995aff23eSKrzysztof Drewniak   SmallVector<Block *> worklist;
5095aff23eSKrzysztof Drewniak   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
5195aff23eSKrzysztof Drewniak     for (Region &region : regions)
5295aff23eSKrzysztof Drewniak       for (Block &block : llvm::reverse(region))
5395aff23eSKrzysztof Drewniak         worklist.push_back(&block);
5495aff23eSKrzysztof Drewniak   };
5595aff23eSKrzysztof Drewniak 
5695aff23eSKrzysztof Drewniak   OpBuilder builder(context);
5795aff23eSKrzysztof Drewniak   OperationFolder folder(context);
5895aff23eSKrzysztof Drewniak 
5995aff23eSKrzysztof Drewniak   addToWorklist(initialRegions);
6095aff23eSKrzysztof Drewniak   while (!worklist.empty()) {
6195aff23eSKrzysztof Drewniak     Block *block = worklist.pop_back_val();
6295aff23eSKrzysztof Drewniak 
6395aff23eSKrzysztof Drewniak     for (Operation &op : llvm::make_early_inc_range(*block)) {
6495aff23eSKrzysztof Drewniak       builder.setInsertionPoint(&op);
6595aff23eSKrzysztof Drewniak 
6695aff23eSKrzysztof Drewniak       // Replace any result with constants.
6795aff23eSKrzysztof Drewniak       bool replacedAll = op.getNumResults() != 0;
6895aff23eSKrzysztof Drewniak       for (Value res : op.getResults())
6995aff23eSKrzysztof Drewniak         replacedAll &=
7095aff23eSKrzysztof Drewniak             succeeded(replaceWithConstant(analysis, builder, folder, res));
7195aff23eSKrzysztof Drewniak 
7295aff23eSKrzysztof Drewniak       // If all of the results of the operation were replaced, try to erase
7395aff23eSKrzysztof Drewniak       // the operation completely.
7495aff23eSKrzysztof Drewniak       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
7595aff23eSKrzysztof Drewniak         assert(op.use_empty() && "expected all uses to be replaced");
7695aff23eSKrzysztof Drewniak         op.erase();
7795aff23eSKrzysztof Drewniak         continue;
7895aff23eSKrzysztof Drewniak       }
7995aff23eSKrzysztof Drewniak 
8095aff23eSKrzysztof Drewniak       // Add any the regions of this operation to the worklist.
8195aff23eSKrzysztof Drewniak       addToWorklist(op.getRegions());
8295aff23eSKrzysztof Drewniak     }
8395aff23eSKrzysztof Drewniak 
8495aff23eSKrzysztof Drewniak     // Replace any block arguments with constants.
8595aff23eSKrzysztof Drewniak     builder.setInsertionPointToStart(block);
8695aff23eSKrzysztof Drewniak     for (BlockArgument arg : block->getArguments())
8795aff23eSKrzysztof Drewniak       (void)replaceWithConstant(analysis, builder, folder, arg);
8895aff23eSKrzysztof Drewniak   }
8995aff23eSKrzysztof Drewniak }
9095aff23eSKrzysztof Drewniak 
9195aff23eSKrzysztof Drewniak namespace {
9295aff23eSKrzysztof Drewniak struct TestIntRangeInference
9395aff23eSKrzysztof Drewniak     : PassWrapper<TestIntRangeInference, OperationPass<>> {
9495aff23eSKrzysztof Drewniak   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
9595aff23eSKrzysztof Drewniak 
9695aff23eSKrzysztof Drewniak   StringRef getArgument() const final { return "test-int-range-inference"; }
9795aff23eSKrzysztof Drewniak   StringRef getDescription() const final {
9895aff23eSKrzysztof Drewniak     return "Test integer range inference analysis";
9995aff23eSKrzysztof Drewniak   }
10095aff23eSKrzysztof Drewniak 
10195aff23eSKrzysztof Drewniak   void runOnOperation() override {
10295aff23eSKrzysztof Drewniak     Operation *op = getOperation();
10395aff23eSKrzysztof Drewniak     IntRangeAnalysis analysis(op);
10495aff23eSKrzysztof Drewniak     rewrite(analysis, op->getContext(), op->getRegions());
10595aff23eSKrzysztof Drewniak   }
10695aff23eSKrzysztof Drewniak };
10795aff23eSKrzysztof Drewniak } // end anonymous namespace
10895aff23eSKrzysztof Drewniak 
10995aff23eSKrzysztof Drewniak namespace mlir {
11095aff23eSKrzysztof Drewniak namespace test {
11195aff23eSKrzysztof Drewniak void registerTestIntRangeInference() {
11295aff23eSKrzysztof Drewniak   PassRegistration<TestIntRangeInference>();
11395aff23eSKrzysztof Drewniak }
11495aff23eSKrzysztof Drewniak } // end namespace test
11595aff23eSKrzysztof Drewniak } // end namespace mlir
116