1*fa26c7ffSMogball //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2*fa26c7ffSMogball //
3*fa26c7ffSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*fa26c7ffSMogball // See https://llvm.org/LICENSE.txt for license information.
5*fa26c7ffSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*fa26c7ffSMogball //
7*fa26c7ffSMogball //===----------------------------------------------------------------------===//
8*fa26c7ffSMogball //
9*fa26c7ffSMogball // This file contains the implementation of the core LICM algorithm.
10*fa26c7ffSMogball //
11*fa26c7ffSMogball //===----------------------------------------------------------------------===//
12*fa26c7ffSMogball 
13*fa26c7ffSMogball #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14*fa26c7ffSMogball #include "mlir/IR/Operation.h"
15*fa26c7ffSMogball #include "mlir/Interfaces/LoopLikeInterface.h"
16*fa26c7ffSMogball #include "mlir/Transforms/SideEffectUtils.h"
17*fa26c7ffSMogball #include "llvm/Support/Debug.h"
18*fa26c7ffSMogball #include <queue>
19*fa26c7ffSMogball 
20*fa26c7ffSMogball #define DEBUG_TYPE "licm"
21*fa26c7ffSMogball 
22*fa26c7ffSMogball using namespace mlir;
23*fa26c7ffSMogball 
24*fa26c7ffSMogball /// Checks whether the given op can be hoisted by checking that
25*fa26c7ffSMogball /// - the op and none of its contained operations depend on values inside of the
26*fa26c7ffSMogball ///   loop (by means of calling definedOutside).
27*fa26c7ffSMogball /// - the op has no side-effects.
28*fa26c7ffSMogball static bool canBeHoisted(Operation *op,
29*fa26c7ffSMogball                          function_ref<bool(Value)> definedOutside) {
30*fa26c7ffSMogball   // Do not move terminators.
31*fa26c7ffSMogball   if (op->hasTrait<OpTrait::IsTerminator>())
32*fa26c7ffSMogball     return false;
33*fa26c7ffSMogball 
34*fa26c7ffSMogball   // Walk the nested operations and check that all used values are either
35*fa26c7ffSMogball   // defined outside of the loop or in a nested region, but not at the level of
36*fa26c7ffSMogball   // the loop body.
37*fa26c7ffSMogball   auto walkFn = [&](Operation *child) {
38*fa26c7ffSMogball     for (Value operand : child->getOperands()) {
39*fa26c7ffSMogball       // Ignore values defined in a nested region.
40*fa26c7ffSMogball       if (op->isAncestor(operand.getParentRegion()->getParentOp()))
41*fa26c7ffSMogball         continue;
42*fa26c7ffSMogball       if (!definedOutside(operand))
43*fa26c7ffSMogball         return WalkResult::interrupt();
44*fa26c7ffSMogball     }
45*fa26c7ffSMogball     return WalkResult::advance();
46*fa26c7ffSMogball   };
47*fa26c7ffSMogball   return !op->walk(walkFn).wasInterrupted();
48*fa26c7ffSMogball }
49*fa26c7ffSMogball 
50*fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(
51*fa26c7ffSMogball     RegionRange regions,
52*fa26c7ffSMogball     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
53*fa26c7ffSMogball     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
54*fa26c7ffSMogball     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
55*fa26c7ffSMogball   size_t numMoved = 0;
56*fa26c7ffSMogball 
57*fa26c7ffSMogball   for (Region *region : regions) {
58*fa26c7ffSMogball     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" << *region->getParentOp());
59*fa26c7ffSMogball 
60*fa26c7ffSMogball     std::queue<Operation *> worklist;
61*fa26c7ffSMogball     // Add top-level operations in the loop body to the worklist.
62*fa26c7ffSMogball     for (Operation &op : region->getOps())
63*fa26c7ffSMogball       worklist.push(&op);
64*fa26c7ffSMogball 
65*fa26c7ffSMogball     auto definedOutside = [&](Value value) {
66*fa26c7ffSMogball       return isDefinedOutsideRegion(value, region);
67*fa26c7ffSMogball     };
68*fa26c7ffSMogball 
69*fa26c7ffSMogball     while (!worklist.empty()) {
70*fa26c7ffSMogball       Operation *op = worklist.front();
71*fa26c7ffSMogball       worklist.pop();
72*fa26c7ffSMogball       // Skip ops that have already been moved. Check if the op can be hoisted.
73*fa26c7ffSMogball       if (op->getParentRegion() != region)
74*fa26c7ffSMogball         continue;
75*fa26c7ffSMogball 
76*fa26c7ffSMogball       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op);
77*fa26c7ffSMogball       if (!shouldMoveOutOfRegion(op, region) ||
78*fa26c7ffSMogball           !canBeHoisted(op, definedOutside))
79*fa26c7ffSMogball         continue;
80*fa26c7ffSMogball 
81*fa26c7ffSMogball       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op);
82*fa26c7ffSMogball       moveOutOfRegion(op, region);
83*fa26c7ffSMogball       ++numMoved;
84*fa26c7ffSMogball 
85*fa26c7ffSMogball       // Since the op has been moved, we need to check its users within the
86*fa26c7ffSMogball       // top-level of the loop body.
87*fa26c7ffSMogball       for (Operation *user : op->getUsers())
88*fa26c7ffSMogball         if (user->getParentRegion() == region)
89*fa26c7ffSMogball           worklist.push(user);
90*fa26c7ffSMogball     }
91*fa26c7ffSMogball   }
92*fa26c7ffSMogball 
93*fa26c7ffSMogball   return numMoved;
94*fa26c7ffSMogball }
95*fa26c7ffSMogball 
96*fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
97*fa26c7ffSMogball   return moveLoopInvariantCode(
98*fa26c7ffSMogball       &loopLike.getLoopBody(),
99*fa26c7ffSMogball       [&](Value value, Region *) {
100*fa26c7ffSMogball         return loopLike.isDefinedOutsideOfLoop(value);
101*fa26c7ffSMogball       },
102*fa26c7ffSMogball       [&](Operation *op, Region *) { return isSideEffectFree(op); },
103*fa26c7ffSMogball       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
104*fa26c7ffSMogball }
105