1 //===- LoopLikeInterface.cpp - Loop-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/LoopLikeInterface.h" 10 #include "mlir/Interfaces/SideEffectInterfaces.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 #include "llvm/Support/Debug.h" 13 #include <queue> 14 15 using namespace mlir; 16 17 #define DEBUG_TYPE "loop-like" 18 19 //===----------------------------------------------------------------------===// 20 // LoopLike Interfaces 21 //===----------------------------------------------------------------------===// 22 23 /// Include the definitions of the loop-like interfaces. 24 #include "mlir/Interfaces/LoopLikeInterface.cpp.inc" 25 26 //===----------------------------------------------------------------------===// 27 // LoopLike Utilities 28 //===----------------------------------------------------------------------===// 29 30 /// Returns true if the given operation is side-effect free as are all of its 31 /// nested operations. 32 /// 33 /// TODO: There is a duplicate function in ControlFlowSink. Move 34 /// `moveLoopInvariantCode` to TransformUtils and then factor out this function. 35 static bool isSideEffectFree(Operation *op) { 36 if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) { 37 // If the op has side-effects, it cannot be moved. 38 if (!memInterface.hasNoEffect()) 39 return false; 40 // If the op does not have recursive side effects, then it can be moved. 41 if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) 42 return true; 43 } else if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) { 44 // Otherwise, if the op does not implement the memory effect interface and 45 // it does not have recursive side effects, then it cannot be known that the 46 // op is moveable. 47 return false; 48 } 49 50 // Recurse into the regions and ensure that all nested ops can also be moved. 51 for (Region ®ion : op->getRegions()) 52 for (Operation &op : region.getOps()) 53 if (!isSideEffectFree(&op)) 54 return false; 55 return true; 56 } 57 58 /// Checks whether the given op can be hoisted by checking that 59 /// - the op and none of its contained operations depend on values inside of the 60 /// loop (by means of calling definedOutside). 61 /// - the op has no side-effects. 62 static bool canBeHoisted(Operation *op, 63 function_ref<bool(Value)> definedOutside) { 64 if (!isSideEffectFree(op)) 65 return false; 66 67 // Do not move terminators. 68 if (op->hasTrait<OpTrait::IsTerminator>()) 69 return false; 70 71 // Walk the nested operations and check that all used values are either 72 // defined outside of the loop or in a nested region, but not at the level of 73 // the loop body. 74 auto walkFn = [&](Operation *child) { 75 for (Value operand : child->getOperands()) { 76 // Ignore values defined in a nested region. 77 if (op->isAncestor(operand.getParentRegion()->getParentOp())) 78 continue; 79 if (!definedOutside(operand)) 80 return WalkResult::interrupt(); 81 } 82 return WalkResult::advance(); 83 }; 84 return !op->walk(walkFn).wasInterrupted(); 85 } 86 87 void mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) { 88 Region *loopBody = &looplike.getLoopBody(); 89 90 std::queue<Operation *> worklist; 91 // Add top-level operations in the loop body to the worklist. 92 for (Operation &op : loopBody->getOps()) 93 worklist.push(&op); 94 95 auto definedOutside = [&](Value value) { 96 return looplike.isDefinedOutsideOfLoop(value); 97 }; 98 99 while (!worklist.empty()) { 100 Operation *op = worklist.front(); 101 worklist.pop(); 102 // Skip ops that have already been moved. Check if the op can be hoisted. 103 if (op->getParentRegion() != loopBody || !canBeHoisted(op, definedOutside)) 104 continue; 105 106 looplike.moveOutOfLoop(op); 107 108 // Since the op has been moved, we need to check its users within the 109 // top-level of the loop body. 110 for (Operation *user : op->getUsers()) 111 if (user->getParentRegion() == loopBody) 112 worklist.push(user); 113 } 114 } 115