//===- TopologicalSortUtils.cpp - Topological sort utilities -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/TopologicalSortUtils.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;
bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();
  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value, Operation *top) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, top))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    Operation *ancestor = block->findAncestorOpInBlock(*parent);
    // - it is an implicit capture from an enclosing region,
    if (!ancestor)
      return true;
    // - it is defined in a nested region, or
    if (ancestor == top)
      return true;
    // - its defining ancestor in this block has already been scheduled.
    return !unscheduledOps.contains(ancestor);
  };
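  // As an illustration, a caller could use the callback to break a dependence
  // on a known value by treating it as always ready (a hypothetical sketch;
  // `ignoredValue` is assumed to be supplied by the caller):
  //
  //   auto isOperandReady = [=](Value value, Operation *) {
  //     return value == ignoredValue;
  //   };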

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet. Try to find the ones that
    // are "ready", i.e. those none of whose operands are produced by an op
    // still in the set, and "schedule" them (move them before
    // `nextScheduledOp`).
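    // For instance, in the (pseudo-IR) block below, the first pass skips
    // "bar" because %0 is still unscheduled and schedules "foo"; a second
    // pass then schedules "bar":
    //
    //   %1 = "bar"(%0)
    //   %0 = "foo"()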
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      // An operation is recursively ready to be scheduled if it and its
      // nested operations are ready.
      WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
        return llvm::all_of(
                   nestedOp->getOperands(),
                   [&](Value operand) { return isReady(operand, &op); })
                   ? WalkResult::advance()
                   : WalkResult::interrupt();
      });
      if (readyToSchedule.wasInterrupted())
        continue;

      // Schedule the operation by moving it before `nextScheduledOp`.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we scheduled the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up on the op at the front of the
    // range: mark the sort as incomplete and advance past that op.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}
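
// Illustrative usage (a sketch; `block` is assumed to be a valid `Block *`,
// and `isOperandReady` is assumed to default to null in the header):
//
//   // Sort all ops in the block; a trailing terminator is kept in place.
//   bool allScheduled = sortTopologically(block);
//   // Or sort just a sub-range of the block's operations.
//   allScheduled = sortTopologically(block, block->without_terminator());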