1fa26c7ffSMogball //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2fa26c7ffSMogball //
3fa26c7ffSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fa26c7ffSMogball // See https://llvm.org/LICENSE.txt for license information.
5fa26c7ffSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fa26c7ffSMogball //
7fa26c7ffSMogball //===----------------------------------------------------------------------===//
8fa26c7ffSMogball //
9fa26c7ffSMogball // This file contains the implementation of the core LICM algorithm.
10fa26c7ffSMogball //
11fa26c7ffSMogball //===----------------------------------------------------------------------===//
12fa26c7ffSMogball 
13fa26c7ffSMogball #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14fa26c7ffSMogball #include "mlir/IR/Operation.h"
15fa26c7ffSMogball #include "mlir/Interfaces/LoopLikeInterface.h"
16fa26c7ffSMogball #include "mlir/Transforms/SideEffectUtils.h"
17fa26c7ffSMogball #include "llvm/Support/Debug.h"
18fa26c7ffSMogball #include <queue>
19fa26c7ffSMogball 
20fa26c7ffSMogball #define DEBUG_TYPE "licm"
21fa26c7ffSMogball 
22fa26c7ffSMogball using namespace mlir;
23fa26c7ffSMogball 
24fa26c7ffSMogball /// Checks whether the given op can be hoisted by checking that
25fa26c7ffSMogball /// - the op and none of its contained operations depend on values inside of the
26fa26c7ffSMogball ///   loop (by means of calling definedOutside).
27fa26c7ffSMogball /// - the op has no side-effects.
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)28fa26c7ffSMogball static bool canBeHoisted(Operation *op,
29fa26c7ffSMogball                          function_ref<bool(Value)> definedOutside) {
30fa26c7ffSMogball   // Do not move terminators.
31fa26c7ffSMogball   if (op->hasTrait<OpTrait::IsTerminator>())
32fa26c7ffSMogball     return false;
33fa26c7ffSMogball 
34fa26c7ffSMogball   // Walk the nested operations and check that all used values are either
35fa26c7ffSMogball   // defined outside of the loop or in a nested region, but not at the level of
36fa26c7ffSMogball   // the loop body.
37fa26c7ffSMogball   auto walkFn = [&](Operation *child) {
38fa26c7ffSMogball     for (Value operand : child->getOperands()) {
39fa26c7ffSMogball       // Ignore values defined in a nested region.
40fa26c7ffSMogball       if (op->isAncestor(operand.getParentRegion()->getParentOp()))
41fa26c7ffSMogball         continue;
42fa26c7ffSMogball       if (!definedOutside(operand))
43fa26c7ffSMogball         return WalkResult::interrupt();
44fa26c7ffSMogball     }
45fa26c7ffSMogball     return WalkResult::advance();
46fa26c7ffSMogball   };
47fa26c7ffSMogball   return !op->walk(walkFn).wasInterrupted();
48fa26c7ffSMogball }
49fa26c7ffSMogball 
moveLoopInvariantCode(RegionRange regions,function_ref<bool (Value,Region *)> isDefinedOutsideRegion,function_ref<bool (Operation *,Region *)> shouldMoveOutOfRegion,function_ref<void (Operation *,Region *)> moveOutOfRegion)50fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(
51fa26c7ffSMogball     RegionRange regions,
52fa26c7ffSMogball     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
53fa26c7ffSMogball     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
54fa26c7ffSMogball     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
55fa26c7ffSMogball   size_t numMoved = 0;
56fa26c7ffSMogball 
57fa26c7ffSMogball   for (Region *region : regions) {
58*9bb0f461SCullen Rhodes     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
59*9bb0f461SCullen Rhodes                             << *region->getParentOp() << "\n");
60fa26c7ffSMogball 
61fa26c7ffSMogball     std::queue<Operation *> worklist;
62fa26c7ffSMogball     // Add top-level operations in the loop body to the worklist.
63fa26c7ffSMogball     for (Operation &op : region->getOps())
64fa26c7ffSMogball       worklist.push(&op);
65fa26c7ffSMogball 
66fa26c7ffSMogball     auto definedOutside = [&](Value value) {
67fa26c7ffSMogball       return isDefinedOutsideRegion(value, region);
68fa26c7ffSMogball     };
69fa26c7ffSMogball 
70fa26c7ffSMogball     while (!worklist.empty()) {
71fa26c7ffSMogball       Operation *op = worklist.front();
72fa26c7ffSMogball       worklist.pop();
73fa26c7ffSMogball       // Skip ops that have already been moved. Check if the op can be hoisted.
74fa26c7ffSMogball       if (op->getParentRegion() != region)
75fa26c7ffSMogball         continue;
76fa26c7ffSMogball 
77*9bb0f461SCullen Rhodes       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
78fa26c7ffSMogball       if (!shouldMoveOutOfRegion(op, region) ||
79fa26c7ffSMogball           !canBeHoisted(op, definedOutside))
80fa26c7ffSMogball         continue;
81fa26c7ffSMogball 
82*9bb0f461SCullen Rhodes       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
83fa26c7ffSMogball       moveOutOfRegion(op, region);
84fa26c7ffSMogball       ++numMoved;
85fa26c7ffSMogball 
86fa26c7ffSMogball       // Since the op has been moved, we need to check its users within the
87fa26c7ffSMogball       // top-level of the loop body.
88fa26c7ffSMogball       for (Operation *user : op->getUsers())
89fa26c7ffSMogball         if (user->getParentRegion() == region)
90fa26c7ffSMogball           worklist.push(user);
91fa26c7ffSMogball     }
92fa26c7ffSMogball   }
93fa26c7ffSMogball 
94fa26c7ffSMogball   return numMoved;
95fa26c7ffSMogball }
96fa26c7ffSMogball 
moveLoopInvariantCode(LoopLikeOpInterface loopLike)97fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
98fa26c7ffSMogball   return moveLoopInvariantCode(
99fa26c7ffSMogball       &loopLike.getLoopBody(),
100fa26c7ffSMogball       [&](Value value, Region *) {
101fa26c7ffSMogball         return loopLike.isDefinedOutsideOfLoop(value);
102fa26c7ffSMogball       },
103fa26c7ffSMogball       [&](Operation *op, Region *) { return isSideEffectFree(op); },
104fa26c7ffSMogball       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
105fa26c7ffSMogball }
106