//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/TopologicalSortUtils.h"
10 #include "mlir/IR/OpDefinition.h"
11
12 using namespace mlir;
13
sortTopologically(Block * block,llvm::iterator_range<Block::iterator> ops,function_ref<bool (Value,Operation *)> isOperandReady)14 bool mlir::sortTopologically(
15 Block *block, llvm::iterator_range<Block::iterator> ops,
16 function_ref<bool(Value, Operation *)> isOperandReady) {
17 if (ops.empty())
18 return true;
19
20 // The set of operations that have not yet been scheduled.
21 DenseSet<Operation *> unscheduledOps;
22 // Mark all operations as unscheduled.
23 for (Operation &op : ops)
24 unscheduledOps.insert(&op);
25
26 Block::iterator nextScheduledOp = ops.begin();
27 Block::iterator end = ops.end();
28
29 // An operation is ready to be scheduled if all its operands are ready. An
30 // operation is ready if:
31 const auto isReady = [&](Value value, Operation *top) {
32 // - the user-provided callback marks it as ready,
33 if (isOperandReady && isOperandReady(value, top))
34 return true;
35 Operation *parent = value.getDefiningOp();
36 // - it is a block argument,
37 if (!parent)
38 return true;
39 Operation *ancestor = block->findAncestorOpInBlock(*parent);
40 // - it is an implicit capture,
41 if (!ancestor)
42 return true;
43 // - it is defined in a nested region, or
44 if (ancestor == top)
45 return true;
46 // - its ancestor in the block is scheduled.
47 return !unscheduledOps.contains(ancestor);
48 };
49
50 bool allOpsScheduled = true;
51 while (!unscheduledOps.empty()) {
52 bool scheduledAtLeastOnce = false;
53
54 // Loop over the ops that are not sorted yet, try to find the ones "ready",
55 // i.e. the ones for which there aren't any operand produced by an op in the
56 // set, and "schedule" it (move it before the `nextScheduledOp`).
57 for (Operation &op :
58 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
59 // An operation is recursively ready to be scheduled of it and its nested
60 // operations are ready.
61 WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
62 return llvm::all_of(
63 nestedOp->getOperands(),
64 [&](Value operand) { return isReady(operand, &op); })
65 ? WalkResult::advance()
66 : WalkResult::interrupt();
67 });
68 if (readyToSchedule.wasInterrupted())
69 continue;
70
71 // Schedule the operation by moving it to the start.
72 unscheduledOps.erase(&op);
73 op.moveBefore(block, nextScheduledOp);
74 scheduledAtLeastOnce = true;
75 // Move the iterator forward if we schedule the operation at the front.
76 if (&op == &*nextScheduledOp)
77 ++nextScheduledOp;
78 }
79 // If no operations were scheduled, give up and advance the iterator.
80 if (!scheduledAtLeastOnce) {
81 allOpsScheduled = false;
82 unscheduledOps.erase(&*nextScheduledOp);
83 ++nextScheduledOp;
84 }
85 }
86
87 return allOpsScheduled;
88 }
89
sortTopologically(Block * block,function_ref<bool (Value,Operation *)> isOperandReady)90 bool mlir::sortTopologically(
91 Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
92 if (block->empty())
93 return true;
94 if (block->back().hasTrait<OpTrait::IsTerminator>())
95 return sortTopologically(block, block->without_terminator(),
96 isOperandReady);
97 return sortTopologically(block, *block, isOperandReady);
98 }
99