//===- LoopInvariantCodeMotion.cpp - Loop invariant code motion -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop invariant code motion.
197905da65SAmit Sabne // 207905da65SAmit Sabne //===----------------------------------------------------------------------===// 217905da65SAmit Sabne 227905da65SAmit Sabne #include "mlir/AffineOps/AffineOps.h" 237905da65SAmit Sabne #include "mlir/Analysis/AffineAnalysis.h" 247905da65SAmit Sabne #include "mlir/Analysis/AffineStructures.h" 257905da65SAmit Sabne #include "mlir/Analysis/LoopAnalysis.h" 267905da65SAmit Sabne #include "mlir/Analysis/SliceAnalysis.h" 277905da65SAmit Sabne #include "mlir/Analysis/Utils.h" 287905da65SAmit Sabne #include "mlir/IR/AffineExpr.h" 297905da65SAmit Sabne #include "mlir/IR/AffineMap.h" 307905da65SAmit Sabne #include "mlir/IR/Builders.h" 317905da65SAmit Sabne #include "mlir/Pass/Pass.h" 327905da65SAmit Sabne #include "mlir/StandardOps/Ops.h" 337905da65SAmit Sabne #include "mlir/Transforms/LoopUtils.h" 347905da65SAmit Sabne #include "mlir/Transforms/Passes.h" 357905da65SAmit Sabne #include "mlir/Transforms/Utils.h" 367905da65SAmit Sabne #include "llvm/ADT/DenseMap.h" 377905da65SAmit Sabne #include "llvm/ADT/DenseSet.h" 38*7a43da60SAmit Sabne #include "llvm/ADT/SmallPtrSet.h" 397905da65SAmit Sabne #include "llvm/Support/CommandLine.h" 407905da65SAmit Sabne #include "llvm/Support/Debug.h" 417905da65SAmit Sabne #include "llvm/Support/raw_ostream.h" 427905da65SAmit Sabne 437905da65SAmit Sabne #define DEBUG_TYPE "licm" 447905da65SAmit Sabne 457905da65SAmit Sabne using namespace mlir; 467905da65SAmit Sabne 477905da65SAmit Sabne namespace { 487905da65SAmit Sabne 497905da65SAmit Sabne /// Loop invariant code motion (LICM) pass. 507905da65SAmit Sabne /// TODO(asabne) : The pass is missing zero-trip tests. 517905da65SAmit Sabne /// TODO(asabne) : Check for the presence of side effects before hoisting. 
527905da65SAmit Sabne struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> { 537905da65SAmit Sabne void runOnFunction() override; 547905da65SAmit Sabne void runOnAffineForOp(AffineForOp forOp); 557905da65SAmit Sabne }; 567905da65SAmit Sabne } // end anonymous namespace 577905da65SAmit Sabne 58*7a43da60SAmit Sabne static bool 59*7a43da60SAmit Sabne checkInvarianceOfNestedIfOps(Operation *op, Value *indVar, 60*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &definedOps, 61*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist); 62*7a43da60SAmit Sabne static bool isOpLoopInvariant(Operation &op, Value *indVar, 63*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &definedOps, 64*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist); 65*7a43da60SAmit Sabne 66*7a43da60SAmit Sabne static bool 67*7a43da60SAmit Sabne areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar, 68*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &definedOps, 69*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist); 70*7a43da60SAmit Sabne 71*7a43da60SAmit Sabne static bool isMemRefDereferencingOp(Operation &op) { 72*7a43da60SAmit Sabne // TODO(asabne): Support DMA Ops. 73*7a43da60SAmit Sabne if (isa<LoadOp>(op) || isa<StoreOp>(op)) { 74*7a43da60SAmit Sabne return true; 75*7a43da60SAmit Sabne } 76*7a43da60SAmit Sabne return false; 77*7a43da60SAmit Sabne } 78*7a43da60SAmit Sabne 797905da65SAmit Sabne FunctionPassBase *mlir::createLoopInvariantCodeMotionPass() { 807905da65SAmit Sabne return new LoopInvariantCodeMotion(); 817905da65SAmit Sabne } 827905da65SAmit Sabne 83*7a43da60SAmit Sabne // Returns true if the individual op is loop invariant. 
84*7a43da60SAmit Sabne bool isOpLoopInvariant(Operation &op, Value *indVar, 85*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &definedOps, 86*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist) { 87*7a43da60SAmit Sabne LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); 88*7a43da60SAmit Sabne 89*7a43da60SAmit Sabne if (isa<AffineIfOp>(op)) { 90*7a43da60SAmit Sabne if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) { 91*7a43da60SAmit Sabne return false; 92*7a43da60SAmit Sabne } 93*7a43da60SAmit Sabne } else if (isa<AffineForOp>(op)) { 94*7a43da60SAmit Sabne // If the body of a predicated region has a for loop, we don't hoist the 95*7a43da60SAmit Sabne // 'affine.if'. 96*7a43da60SAmit Sabne return false; 97*7a43da60SAmit Sabne } else if (isa<DmaStartOp>(op) || isa<DmaWaitOp>(op)) { 98*7a43da60SAmit Sabne // TODO(asabne): Support DMA ops. 99*7a43da60SAmit Sabne return false; 100*7a43da60SAmit Sabne } else if (!isa<ConstantOp>(op)) { 101*7a43da60SAmit Sabne if (isMemRefDereferencingOp(op)) { 102*7a43da60SAmit Sabne Value *memref = isa<LoadOp>(op) ? cast<LoadOp>(op).getMemRef() 103*7a43da60SAmit Sabne : cast<StoreOp>(op).getMemRef(); 104*7a43da60SAmit Sabne for (auto *user : memref->getUsers()) { 105*7a43da60SAmit Sabne // If this memref has a user that is a DMA, give up because these 106*7a43da60SAmit Sabne // operations write to this memref. 107*7a43da60SAmit Sabne if (isa<DmaStartOp>(op) || isa<DmaWaitOp>(op)) { 108*7a43da60SAmit Sabne return false; 109*7a43da60SAmit Sabne } 110*7a43da60SAmit Sabne // If the memref used by the load/store is used in a store elsewhere in 111*7a43da60SAmit Sabne // the loop nest, we do not hoist. Similarly, if the memref used in a 112*7a43da60SAmit Sabne // load is also being stored too, we do not hoist the load. 
113*7a43da60SAmit Sabne if (isa<StoreOp>(user) || (isa<LoadOp>(user) && isa<StoreOp>(op))) { 114*7a43da60SAmit Sabne if (&op != user) { 115*7a43da60SAmit Sabne SmallVector<AffineForOp, 8> userIVs; 116*7a43da60SAmit Sabne getLoopIVs(*user, &userIVs); 117*7a43da60SAmit Sabne // Check that userIVs don't contain the for loop around the op. 118*7a43da60SAmit Sabne if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) { 119*7a43da60SAmit Sabne return false; 120*7a43da60SAmit Sabne } 121*7a43da60SAmit Sabne } 122*7a43da60SAmit Sabne } 123*7a43da60SAmit Sabne } 124*7a43da60SAmit Sabne } 125*7a43da60SAmit Sabne 126*7a43da60SAmit Sabne // Insert this op in the defined ops list. 127*7a43da60SAmit Sabne definedOps.insert(&op); 128*7a43da60SAmit Sabne 129*7a43da60SAmit Sabne if (op.getNumOperands() == 0 && !isa<AffineTerminatorOp>(op)) { 130*7a43da60SAmit Sabne LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); 131*7a43da60SAmit Sabne return false; 132*7a43da60SAmit Sabne } 133*7a43da60SAmit Sabne for (unsigned int i = 0; i < op.getNumOperands(); ++i) { 134*7a43da60SAmit Sabne auto *operandSrc = op.getOperand(i)->getDefiningOp(); 135*7a43da60SAmit Sabne 136*7a43da60SAmit Sabne LLVM_DEBUG( 137*7a43da60SAmit Sabne op.getOperand(i)->print(llvm::dbgs() << "\nIterating on operand\n")); 138*7a43da60SAmit Sabne 139*7a43da60SAmit Sabne // If the loop IV is the operand, this op isn't loop invariant. 
140*7a43da60SAmit Sabne if (indVar == op.getOperand(i)) { 141*7a43da60SAmit Sabne LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n"); 142*7a43da60SAmit Sabne return false; 143*7a43da60SAmit Sabne } 144*7a43da60SAmit Sabne 145*7a43da60SAmit Sabne if (operandSrc != nullptr) { 146*7a43da60SAmit Sabne LLVM_DEBUG(llvm::dbgs() 147*7a43da60SAmit Sabne << *operandSrc << "\nIterating on operand src\n"); 148*7a43da60SAmit Sabne 149*7a43da60SAmit Sabne // If the value was defined in the loop (outside of the 150*7a43da60SAmit Sabne // if/else region), and that operation itself wasn't meant to 151*7a43da60SAmit Sabne // be hoisted, then mark this operation loop dependent. 152*7a43da60SAmit Sabne if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { 153*7a43da60SAmit Sabne return false; 154*7a43da60SAmit Sabne } 155*7a43da60SAmit Sabne } 156*7a43da60SAmit Sabne } 157*7a43da60SAmit Sabne } 158*7a43da60SAmit Sabne 159*7a43da60SAmit Sabne // If no operand was loop variant, mark this op for motion. 160*7a43da60SAmit Sabne opsToHoist.insert(&op); 161*7a43da60SAmit Sabne return true; 162*7a43da60SAmit Sabne } 163*7a43da60SAmit Sabne 164*7a43da60SAmit Sabne // Checks if all ops in a region (i.e. list of blocks) are loop invariant. 165*7a43da60SAmit Sabne bool areAllOpsInTheBlockListInvariant( 166*7a43da60SAmit Sabne Region &blockList, Value *indVar, SmallPtrSetImpl<Operation *> &definedOps, 167*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist) { 168*7a43da60SAmit Sabne 169*7a43da60SAmit Sabne for (auto &b : blockList) { 170*7a43da60SAmit Sabne for (auto &op : b) { 171*7a43da60SAmit Sabne if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { 172*7a43da60SAmit Sabne return false; 173*7a43da60SAmit Sabne } 174*7a43da60SAmit Sabne } 175*7a43da60SAmit Sabne } 176*7a43da60SAmit Sabne 177*7a43da60SAmit Sabne return true; 178*7a43da60SAmit Sabne } 179*7a43da60SAmit Sabne 180*7a43da60SAmit Sabne // Returns true if the affine.if op can be hoisted. 
181*7a43da60SAmit Sabne bool checkInvarianceOfNestedIfOps(Operation *op, Value *indVar, 182*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &definedOps, 183*7a43da60SAmit Sabne SmallPtrSetImpl<Operation *> &opsToHoist) { 184*7a43da60SAmit Sabne assert(isa<AffineIfOp>(op)); 185*7a43da60SAmit Sabne auto ifOp = cast<AffineIfOp>(op); 186*7a43da60SAmit Sabne 187*7a43da60SAmit Sabne if (!areAllOpsInTheBlockListInvariant(ifOp.getThenBlocks(), indVar, 188*7a43da60SAmit Sabne definedOps, opsToHoist)) { 189*7a43da60SAmit Sabne return false; 190*7a43da60SAmit Sabne } 191*7a43da60SAmit Sabne 192*7a43da60SAmit Sabne if (!areAllOpsInTheBlockListInvariant(ifOp.getElseBlocks(), indVar, 193*7a43da60SAmit Sabne definedOps, opsToHoist)) { 194*7a43da60SAmit Sabne return false; 195*7a43da60SAmit Sabne } 196*7a43da60SAmit Sabne 197*7a43da60SAmit Sabne return true; 198*7a43da60SAmit Sabne } 199*7a43da60SAmit Sabne 2007905da65SAmit Sabne void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { 2017905da65SAmit Sabne auto *loopBody = forOp.getBody(); 202*7a43da60SAmit Sabne auto *indVar = forOp.getInductionVar(); 2037905da65SAmit Sabne 204*7a43da60SAmit Sabne SmallPtrSet<Operation *, 8> definedOps; 2057905da65SAmit Sabne // This is the place where hoisted instructions would reside. 2067905da65SAmit Sabne FuncBuilder b(forOp.getOperation()); 2077905da65SAmit Sabne 208*7a43da60SAmit Sabne SmallPtrSet<Operation *, 8> opsToHoist; 2097905da65SAmit Sabne SmallVector<Operation *, 8> opsToMove; 2107905da65SAmit Sabne 2117905da65SAmit Sabne for (auto &op : *loopBody) { 212*7a43da60SAmit Sabne // We don't hoist for loops. 
213*7a43da60SAmit Sabne if (!isa<AffineForOp>(op)) { 214*7a43da60SAmit Sabne if (!isa<AffineTerminatorOp>(op)) { 215*7a43da60SAmit Sabne if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { 2167905da65SAmit Sabne opsToMove.push_back(&op); 2177905da65SAmit Sabne } 2187905da65SAmit Sabne } 219*7a43da60SAmit Sabne } 220*7a43da60SAmit Sabne } 2217905da65SAmit Sabne 222*7a43da60SAmit Sabne // For all instructions that we found to be invariant, place sequentially 2237905da65SAmit Sabne // right before the for loop. 2247905da65SAmit Sabne for (auto *op : opsToMove) { 2257905da65SAmit Sabne op->moveBefore(forOp); 2267905da65SAmit Sabne } 2277905da65SAmit Sabne 228*7a43da60SAmit Sabne LLVM_DEBUG(forOp.getOperation()->print(llvm::dbgs() << "Modified loop\n")); 2297905da65SAmit Sabne 2307905da65SAmit Sabne // If the for loop body has a single operation (the terminator), erase it. 2317905da65SAmit Sabne if (forOp.getBody()->getOperations().size() == 1) { 232d5b60ee8SRiver Riddle assert(isa<AffineTerminatorOp>(forOp.getBody()->front())); 2337905da65SAmit Sabne forOp.erase(); 2347905da65SAmit Sabne } 2357905da65SAmit Sabne } 2367905da65SAmit Sabne 2377905da65SAmit Sabne void LoopInvariantCodeMotion::runOnFunction() { 2380134b5dfSChris Lattner // Walk through all loops in a function in innermost-loop-first order. This 2390134b5dfSChris Lattner // way, we first LICM from the inner loop, and place the ops in 2400134b5dfSChris Lattner // the outer loop, which in turn can be further LICM'ed. 
2410134b5dfSChris Lattner getFunction().walk<AffineForOp>([&](AffineForOp op) { 2424aa9235aSAmit Sabne LLVM_DEBUG(op.getOperation()->print(llvm::dbgs() << "\nOriginal loop\n")); 2434aa9235aSAmit Sabne runOnAffineForOp(op); 2440134b5dfSChris Lattner }); 2457905da65SAmit Sabne } 2467905da65SAmit Sabne 2477905da65SAmit Sabne static PassRegistration<LoopInvariantCodeMotion> 248258e8d9cSNicolas Vasilache pass("affine-loop-invariant-code-motion", 2497905da65SAmit Sabne "Hoist loop invariant instructions outside of the loop"); 250