//===- TopologicalSortUtils.cpp - Topological sort utilities ----*- C++ -*-===//
2*c8457eb5SMogball //
3*c8457eb5SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c8457eb5SMogball // See https://llvm.org/LICENSE.txt for license information.
5*c8457eb5SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c8457eb5SMogball //
7*c8457eb5SMogball //===----------------------------------------------------------------------===//
8*c8457eb5SMogball 
9*c8457eb5SMogball #include "mlir/Transforms/TopologicalSortUtils.h"
10*c8457eb5SMogball #include "mlir/IR/OpDefinition.h"
11*c8457eb5SMogball 
12*c8457eb5SMogball using namespace mlir;
13*c8457eb5SMogball 
sortTopologically(Block * block,llvm::iterator_range<Block::iterator> ops,function_ref<bool (Value,Operation *)> isOperandReady)14*c8457eb5SMogball bool mlir::sortTopologically(
15*c8457eb5SMogball     Block *block, llvm::iterator_range<Block::iterator> ops,
16*c8457eb5SMogball     function_ref<bool(Value, Operation *)> isOperandReady) {
17*c8457eb5SMogball   if (ops.empty())
18*c8457eb5SMogball     return true;
19*c8457eb5SMogball 
20*c8457eb5SMogball   // The set of operations that have not yet been scheduled.
21*c8457eb5SMogball   DenseSet<Operation *> unscheduledOps;
22*c8457eb5SMogball   // Mark all operations as unscheduled.
23*c8457eb5SMogball   for (Operation &op : ops)
24*c8457eb5SMogball     unscheduledOps.insert(&op);
25*c8457eb5SMogball 
26*c8457eb5SMogball   Block::iterator nextScheduledOp = ops.begin();
27*c8457eb5SMogball   Block::iterator end = ops.end();
28*c8457eb5SMogball 
29*c8457eb5SMogball   // An operation is ready to be scheduled if all its operands are ready. An
30*c8457eb5SMogball   // operation is ready if:
31*c8457eb5SMogball   const auto isReady = [&](Value value, Operation *top) {
32*c8457eb5SMogball     // - the user-provided callback marks it as ready,
33*c8457eb5SMogball     if (isOperandReady && isOperandReady(value, top))
34*c8457eb5SMogball       return true;
35*c8457eb5SMogball     Operation *parent = value.getDefiningOp();
36*c8457eb5SMogball     // - it is a block argument,
37*c8457eb5SMogball     if (!parent)
38*c8457eb5SMogball       return true;
39*c8457eb5SMogball     Operation *ancestor = block->findAncestorOpInBlock(*parent);
40*c8457eb5SMogball     // - it is an implicit capture,
41*c8457eb5SMogball     if (!ancestor)
42*c8457eb5SMogball       return true;
43*c8457eb5SMogball     // - it is defined in a nested region, or
44*c8457eb5SMogball     if (ancestor == top)
45*c8457eb5SMogball       return true;
46*c8457eb5SMogball     // - its ancestor in the block is scheduled.
47*c8457eb5SMogball     return !unscheduledOps.contains(ancestor);
48*c8457eb5SMogball   };
49*c8457eb5SMogball 
50*c8457eb5SMogball   bool allOpsScheduled = true;
51*c8457eb5SMogball   while (!unscheduledOps.empty()) {
52*c8457eb5SMogball     bool scheduledAtLeastOnce = false;
53*c8457eb5SMogball 
54*c8457eb5SMogball     // Loop over the ops that are not sorted yet, try to find the ones "ready",
55*c8457eb5SMogball     // i.e. the ones for which there aren't any operand produced by an op in the
56*c8457eb5SMogball     // set, and "schedule" it (move it before the `nextScheduledOp`).
57*c8457eb5SMogball     for (Operation &op :
58*c8457eb5SMogball          llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
59*c8457eb5SMogball       // An operation is recursively ready to be scheduled of it and its nested
60*c8457eb5SMogball       // operations are ready.
61*c8457eb5SMogball       WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
62*c8457eb5SMogball         return llvm::all_of(
63*c8457eb5SMogball                    nestedOp->getOperands(),
64*c8457eb5SMogball                    [&](Value operand) { return isReady(operand, &op); })
65*c8457eb5SMogball                    ? WalkResult::advance()
66*c8457eb5SMogball                    : WalkResult::interrupt();
67*c8457eb5SMogball       });
68*c8457eb5SMogball       if (readyToSchedule.wasInterrupted())
69*c8457eb5SMogball         continue;
70*c8457eb5SMogball 
71*c8457eb5SMogball       // Schedule the operation by moving it to the start.
72*c8457eb5SMogball       unscheduledOps.erase(&op);
73*c8457eb5SMogball       op.moveBefore(block, nextScheduledOp);
74*c8457eb5SMogball       scheduledAtLeastOnce = true;
75*c8457eb5SMogball       // Move the iterator forward if we schedule the operation at the front.
76*c8457eb5SMogball       if (&op == &*nextScheduledOp)
77*c8457eb5SMogball         ++nextScheduledOp;
78*c8457eb5SMogball     }
79*c8457eb5SMogball     // If no operations were scheduled, give up and advance the iterator.
80*c8457eb5SMogball     if (!scheduledAtLeastOnce) {
81*c8457eb5SMogball       allOpsScheduled = false;
82*c8457eb5SMogball       unscheduledOps.erase(&*nextScheduledOp);
83*c8457eb5SMogball       ++nextScheduledOp;
84*c8457eb5SMogball     }
85*c8457eb5SMogball   }
86*c8457eb5SMogball 
87*c8457eb5SMogball   return allOpsScheduled;
88*c8457eb5SMogball }
89*c8457eb5SMogball 
sortTopologically(Block * block,function_ref<bool (Value,Operation *)> isOperandReady)90*c8457eb5SMogball bool mlir::sortTopologically(
91*c8457eb5SMogball     Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
92*c8457eb5SMogball   if (block->empty())
93*c8457eb5SMogball     return true;
94*c8457eb5SMogball   if (block->back().hasTrait<OpTrait::IsTerminator>())
95*c8457eb5SMogball     return sortTopologically(block, block->without_terminator(),
96*c8457eb5SMogball                              isOperandReady);
97*c8457eb5SMogball   return sortTopologically(block, *block, isOperandReady);
98*c8457eb5SMogball }
99