13f429e82SAnthony Canino //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
23f429e82SAnthony Canino //
33f429e82SAnthony Canino // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43f429e82SAnthony Canino // See https://llvm.org/LICENSE.txt for license information.
53f429e82SAnthony Canino // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63f429e82SAnthony Canino //
73f429e82SAnthony Canino //===----------------------------------------------------------------------===//
83f429e82SAnthony Canino //
93f429e82SAnthony Canino // This file implements loop range folding.
103f429e82SAnthony Canino //
113f429e82SAnthony Canino //===----------------------------------------------------------------------===//
123f429e82SAnthony Canino 
133f429e82SAnthony Canino #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
16*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
17*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
193f429e82SAnthony Canino #include "mlir/IR/BlockAndValueMapping.h"
203f429e82SAnthony Canino 
213f429e82SAnthony Canino using namespace mlir;
223f429e82SAnthony Canino using namespace mlir::scf;
233f429e82SAnthony Canino 
243f429e82SAnthony Canino namespace {
253f429e82SAnthony Canino struct ForLoopRangeFolding
263f429e82SAnthony Canino     : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
273f429e82SAnthony Canino   void runOnOperation() override;
283f429e82SAnthony Canino };
293f429e82SAnthony Canino } // namespace
303f429e82SAnthony Canino 
runOnOperation()313f429e82SAnthony Canino void ForLoopRangeFolding::runOnOperation() {
323f429e82SAnthony Canino   getOperation()->walk([&](ForOp op) {
333f429e82SAnthony Canino     Value indVar = op.getInductionVar();
343f429e82SAnthony Canino 
353f429e82SAnthony Canino     auto canBeFolded = [&](Value value) {
363f429e82SAnthony Canino       return op.isDefinedOutsideOfLoop(value) || value == indVar;
373f429e82SAnthony Canino     };
383f429e82SAnthony Canino 
393f429e82SAnthony Canino     // Fold until a fixed point is reached
403f429e82SAnthony Canino     while (true) {
413f429e82SAnthony Canino 
423f429e82SAnthony Canino       // If the induction variable is used more than once, we can't fold its
433f429e82SAnthony Canino       // arith ops into the loop range
443f429e82SAnthony Canino       if (!indVar.hasOneUse())
453f429e82SAnthony Canino         break;
463f429e82SAnthony Canino 
473f429e82SAnthony Canino       Operation *user = *indVar.getUsers().begin();
48a54f4eaeSMogball       if (!isa<arith::AddIOp, arith::MulIOp>(user))
493f429e82SAnthony Canino         break;
503f429e82SAnthony Canino 
513f429e82SAnthony Canino       if (!llvm::all_of(user->getOperands(), canBeFolded))
523f429e82SAnthony Canino         break;
533f429e82SAnthony Canino 
543f429e82SAnthony Canino       OpBuilder b(op);
553f429e82SAnthony Canino       BlockAndValueMapping lbMap;
56c0342a2dSJacques Pienaar       lbMap.map(indVar, op.getLowerBound());
573f429e82SAnthony Canino       BlockAndValueMapping ubMap;
58c0342a2dSJacques Pienaar       ubMap.map(indVar, op.getUpperBound());
593f429e82SAnthony Canino       BlockAndValueMapping stepMap;
60c0342a2dSJacques Pienaar       stepMap.map(indVar, op.getStep());
613f429e82SAnthony Canino 
62a54f4eaeSMogball       if (isa<arith::AddIOp>(user)) {
633f429e82SAnthony Canino         Operation *lbFold = b.clone(*user, lbMap);
643f429e82SAnthony Canino         Operation *ubFold = b.clone(*user, ubMap);
653f429e82SAnthony Canino 
663f429e82SAnthony Canino         op.setLowerBound(lbFold->getResult(0));
673f429e82SAnthony Canino         op.setUpperBound(ubFold->getResult(0));
683f429e82SAnthony Canino 
69a54f4eaeSMogball       } else if (isa<arith::MulIOp>(user)) {
703f429e82SAnthony Canino         Operation *ubFold = b.clone(*user, ubMap);
713f429e82SAnthony Canino         Operation *stepFold = b.clone(*user, stepMap);
723f429e82SAnthony Canino 
733f429e82SAnthony Canino         op.setUpperBound(ubFold->getResult(0));
743f429e82SAnthony Canino         op.setStep(stepFold->getResult(0));
753f429e82SAnthony Canino       }
763f429e82SAnthony Canino 
773f429e82SAnthony Canino       ValueRange wrapIndvar(indVar);
783f429e82SAnthony Canino       user->replaceAllUsesWith(wrapIndvar);
793f429e82SAnthony Canino       user->erase();
803f429e82SAnthony Canino     }
813f429e82SAnthony Canino   });
823f429e82SAnthony Canino }
833f429e82SAnthony Canino 
createForLoopRangeFoldingPass()843f429e82SAnthony Canino std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
853f429e82SAnthony Canino   return std::make_unique<ForLoopRangeFolding>();
863f429e82SAnthony Canino }
87