1*fa26c7ffSMogball //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===// 2*fa26c7ffSMogball // 3*fa26c7ffSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*fa26c7ffSMogball // See https://llvm.org/LICENSE.txt for license information. 5*fa26c7ffSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*fa26c7ffSMogball // 7*fa26c7ffSMogball //===----------------------------------------------------------------------===// 8*fa26c7ffSMogball // 9*fa26c7ffSMogball // This file contains the implementation of the core LICM algorithm. 10*fa26c7ffSMogball // 11*fa26c7ffSMogball //===----------------------------------------------------------------------===// 12*fa26c7ffSMogball 13*fa26c7ffSMogball #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 14*fa26c7ffSMogball #include "mlir/IR/Operation.h" 15*fa26c7ffSMogball #include "mlir/Interfaces/LoopLikeInterface.h" 16*fa26c7ffSMogball #include "mlir/Transforms/SideEffectUtils.h" 17*fa26c7ffSMogball #include "llvm/Support/Debug.h" 18*fa26c7ffSMogball #include <queue> 19*fa26c7ffSMogball 20*fa26c7ffSMogball #define DEBUG_TYPE "licm" 21*fa26c7ffSMogball 22*fa26c7ffSMogball using namespace mlir; 23*fa26c7ffSMogball 24*fa26c7ffSMogball /// Checks whether the given op can be hoisted by checking that 25*fa26c7ffSMogball /// - the op and none of its contained operations depend on values inside of the 26*fa26c7ffSMogball /// loop (by means of calling definedOutside). 27*fa26c7ffSMogball /// - the op has no side-effects. 28*fa26c7ffSMogball static bool canBeHoisted(Operation *op, 29*fa26c7ffSMogball function_ref<bool(Value)> definedOutside) { 30*fa26c7ffSMogball // Do not move terminators. 31*fa26c7ffSMogball if (op->hasTrait<OpTrait::IsTerminator>()) 32*fa26c7ffSMogball return false; 33*fa26c7ffSMogball 34*fa26c7ffSMogball // Walk the nested operations and check that all used values are either 35*fa26c7ffSMogball // defined outside of the loop or in a nested region, but not at the level of 36*fa26c7ffSMogball // the loop body. 37*fa26c7ffSMogball auto walkFn = [&](Operation *child) { 38*fa26c7ffSMogball for (Value operand : child->getOperands()) { 39*fa26c7ffSMogball // Ignore values defined in a nested region. 40*fa26c7ffSMogball if (op->isAncestor(operand.getParentRegion()->getParentOp())) 41*fa26c7ffSMogball continue; 42*fa26c7ffSMogball if (!definedOutside(operand)) 43*fa26c7ffSMogball return WalkResult::interrupt(); 44*fa26c7ffSMogball } 45*fa26c7ffSMogball return WalkResult::advance(); 46*fa26c7ffSMogball }; 47*fa26c7ffSMogball return !op->walk(walkFn).wasInterrupted(); 48*fa26c7ffSMogball } 49*fa26c7ffSMogball 50*fa26c7ffSMogball size_t mlir::moveLoopInvariantCode( 51*fa26c7ffSMogball RegionRange regions, 52*fa26c7ffSMogball function_ref<bool(Value, Region *)> isDefinedOutsideRegion, 53*fa26c7ffSMogball function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, 54*fa26c7ffSMogball function_ref<void(Operation *, Region *)> moveOutOfRegion) { 55*fa26c7ffSMogball size_t numMoved = 0; 56*fa26c7ffSMogball 57*fa26c7ffSMogball for (Region *region : regions) { 58*fa26c7ffSMogball LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" << *region->getParentOp()); 59*fa26c7ffSMogball 60*fa26c7ffSMogball std::queue<Operation *> worklist; 61*fa26c7ffSMogball // Add top-level operations in the loop body to the worklist. 62*fa26c7ffSMogball for (Operation &op : region->getOps()) 63*fa26c7ffSMogball worklist.push(&op); 64*fa26c7ffSMogball 65*fa26c7ffSMogball auto definedOutside = [&](Value value) { 66*fa26c7ffSMogball return isDefinedOutsideRegion(value, region); 67*fa26c7ffSMogball }; 68*fa26c7ffSMogball 69*fa26c7ffSMogball while (!worklist.empty()) { 70*fa26c7ffSMogball Operation *op = worklist.front(); 71*fa26c7ffSMogball worklist.pop(); 72*fa26c7ffSMogball // Skip ops that have already been moved. Check if the op can be hoisted. 73*fa26c7ffSMogball if (op->getParentRegion() != region) 74*fa26c7ffSMogball continue; 75*fa26c7ffSMogball 76*fa26c7ffSMogball LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op); 77*fa26c7ffSMogball if (!shouldMoveOutOfRegion(op, region) || 78*fa26c7ffSMogball !canBeHoisted(op, definedOutside)) 79*fa26c7ffSMogball continue; 80*fa26c7ffSMogball 81*fa26c7ffSMogball LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op); 82*fa26c7ffSMogball moveOutOfRegion(op, region); 83*fa26c7ffSMogball ++numMoved; 84*fa26c7ffSMogball 85*fa26c7ffSMogball // Since the op has been moved, we need to check its users within the 86*fa26c7ffSMogball // top-level of the loop body. 87*fa26c7ffSMogball for (Operation *user : op->getUsers()) 88*fa26c7ffSMogball if (user->getParentRegion() == region) 89*fa26c7ffSMogball worklist.push(user); 90*fa26c7ffSMogball } 91*fa26c7ffSMogball } 92*fa26c7ffSMogball 93*fa26c7ffSMogball return numMoved; 94*fa26c7ffSMogball } 95*fa26c7ffSMogball 96*fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { 97*fa26c7ffSMogball return moveLoopInvariantCode( 98*fa26c7ffSMogball &loopLike.getLoopBody(), 99*fa26c7ffSMogball [&](Value value, Region *) { 100*fa26c7ffSMogball return loopLike.isDefinedOutsideOfLoop(value); 101*fa26c7ffSMogball }, 102*fa26c7ffSMogball [&](Operation *op, Region *) { return isSideEffectFree(op); }, 103*fa26c7ffSMogball [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); 104*fa26c7ffSMogball } 105