//===- TopologicalSortUtils.cpp - Topological sort utilities -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/TopologicalSortUtils.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

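// Greedy list-scheduling style topological sort: repeatedly scan the
// not-yet-scheduled ops and move every op whose operands are all ready before
// `nextScheduledOp`. If a full scan schedules nothing, the front-most
// unscheduled op is given up on so the loop still terminates, and the function
// reports this by returning false. The worst case is roughly quadratic in the
// number of ops in the range.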
bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value, Operation *top) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, top))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    Operation *ancestor = block->findAncestorOpInBlock(*parent);
    // - it is an implicit capture,
    if (!ancestor)
      return true;
    // - it is defined in a nested region of the op being scheduled, or
    if (ancestor == top)
      return true;
    // - its ancestor in the block is already scheduled.
    return !unscheduledOps.contains(ancestor);
  };

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet; try to find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op still in
    // the unscheduled set, and "schedule" them (move them before
    // `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      // An operation is recursively ready to be scheduled if it and its nested
      // operations are ready.
      WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
        return llvm::all_of(
                   nestedOp->getOperands(),
                   [&](Value operand) { return isReady(operand, &op); })
                   ? WalkResult::advance()
                   : WalkResult::interrupt();
      });
      if (readyToSchedule.wasInterrupted())
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}
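
// Example usage (illustrative sketch, not part of this utility; the helper
// name is hypothetical): sort the ops of a block, with an `isOperandReady`
// callback that marks no extra values as ready so only the default rules
// above apply.
//
//   void sortBlockExample(Block &block) {
//     bool fullySorted = sortTopologically(
//         &block, /*isOperandReady=*/[](Value value, Operation *op) {
//           return false; // No operand is considered ready ahead of time.
//         });
//     if (!fullySorted) {
//       // The block could not be fully sorted: some defs still follow a use.
//     }
//   }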