1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Interfaces/LoopLikeInterface.h"
16 #include "mlir/Transforms/SideEffectUtils.h"
17 #include "llvm/Support/Debug.h"
18 #include <queue>
19 
20 #define DEBUG_TYPE "licm"
21 
22 using namespace mlir;
23 
24 /// Checks whether the given op can be hoisted by checking that
25 /// - the op and none of its contained operations depend on values inside of the
26 ///   loop (by means of calling definedOutside).
27 /// - the op has no side-effects.
28 static bool canBeHoisted(Operation *op,
29                          function_ref<bool(Value)> definedOutside) {
30   // Do not move terminators.
31   if (op->hasTrait<OpTrait::IsTerminator>())
32     return false;
33 
34   // Walk the nested operations and check that all used values are either
35   // defined outside of the loop or in a nested region, but not at the level of
36   // the loop body.
37   auto walkFn = [&](Operation *child) {
38     for (Value operand : child->getOperands()) {
39       // Ignore values defined in a nested region.
40       if (op->isAncestor(operand.getParentRegion()->getParentOp()))
41         continue;
42       if (!definedOutside(operand))
43         return WalkResult::interrupt();
44     }
45     return WalkResult::advance();
46   };
47   return !op->walk(walkFn).wasInterrupted();
48 }
49 
50 size_t mlir::moveLoopInvariantCode(
51     RegionRange regions,
52     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
53     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
54     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
55   size_t numMoved = 0;
56 
57   for (Region *region : regions) {
58     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" << *region->getParentOp());
59 
60     std::queue<Operation *> worklist;
61     // Add top-level operations in the loop body to the worklist.
62     for (Operation &op : region->getOps())
63       worklist.push(&op);
64 
65     auto definedOutside = [&](Value value) {
66       return isDefinedOutsideRegion(value, region);
67     };
68 
69     while (!worklist.empty()) {
70       Operation *op = worklist.front();
71       worklist.pop();
72       // Skip ops that have already been moved. Check if the op can be hoisted.
73       if (op->getParentRegion() != region)
74         continue;
75 
76       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op);
77       if (!shouldMoveOutOfRegion(op, region) ||
78           !canBeHoisted(op, definedOutside))
79         continue;
80 
81       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op);
82       moveOutOfRegion(op, region);
83       ++numMoved;
84 
85       // Since the op has been moved, we need to check its users within the
86       // top-level of the loop body.
87       for (Operation *user : op->getUsers())
88         if (user->getParentRegion() == region)
89           worklist.push(user);
90     }
91   }
92 
93   return numMoved;
94 }
95 
96 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
97   return moveLoopInvariantCode(
98       {&loopLike.getLoopBody()},
99       [&](Value value, Region *) {
100         return loopLike.isDefinedOutsideOfLoop(value);
101       },
102       [&](Operation *op, Region *) { return isSideEffectFree(op); },
103       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
104 }
105