1 //===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/TopologicalSortUtils.h" 10 #include "mlir/IR/OpDefinition.h" 11 12 using namespace mlir; 13 14 bool mlir::sortTopologically( 15 Block *block, llvm::iterator_range<Block::iterator> ops, 16 function_ref<bool(Value, Operation *)> isOperandReady) { 17 if (ops.empty()) 18 return true; 19 20 // The set of operations that have not yet been scheduled. 21 DenseSet<Operation *> unscheduledOps; 22 // Mark all operations as unscheduled. 23 for (Operation &op : ops) 24 unscheduledOps.insert(&op); 25 26 Block::iterator nextScheduledOp = ops.begin(); 27 Block::iterator end = ops.end(); 28 29 // An operation is ready to be scheduled if all its operands are ready. An 30 // operation is ready if: 31 const auto isReady = [&](Value value, Operation *top) { 32 // - the user-provided callback marks it as ready, 33 if (isOperandReady && isOperandReady(value, top)) 34 return true; 35 Operation *parent = value.getDefiningOp(); 36 // - it is a block argument, 37 if (!parent) 38 return true; 39 Operation *ancestor = block->findAncestorOpInBlock(*parent); 40 // - it is an implicit capture, 41 if (!ancestor) 42 return true; 43 // - it is defined in a nested region, or 44 if (ancestor == top) 45 return true; 46 // - its ancestor in the block is scheduled. 47 return !unscheduledOps.contains(ancestor); 48 }; 49 50 bool allOpsScheduled = true; 51 while (!unscheduledOps.empty()) { 52 bool scheduledAtLeastOnce = false; 53 54 // Loop over the ops that are not sorted yet, try to find the ones "ready", 55 // i.e. the ones for which there aren't any operand produced by an op in the 56 // set, and "schedule" it (move it before the `nextScheduledOp`). 57 for (Operation &op : 58 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) { 59 // An operation is recursively ready to be scheduled of it and its nested 60 // operations are ready. 61 WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) { 62 return llvm::all_of( 63 nestedOp->getOperands(), 64 [&](Value operand) { return isReady(operand, &op); }) 65 ? WalkResult::advance() 66 : WalkResult::interrupt(); 67 }); 68 if (readyToSchedule.wasInterrupted()) 69 continue; 70 71 // Schedule the operation by moving it to the start. 72 unscheduledOps.erase(&op); 73 op.moveBefore(block, nextScheduledOp); 74 scheduledAtLeastOnce = true; 75 // Move the iterator forward if we schedule the operation at the front. 76 if (&op == &*nextScheduledOp) 77 ++nextScheduledOp; 78 } 79 // If no operations were scheduled, give up and advance the iterator. 80 if (!scheduledAtLeastOnce) { 81 allOpsScheduled = false; 82 unscheduledOps.erase(&*nextScheduledOp); 83 ++nextScheduledOp; 84 } 85 } 86 87 return allOpsScheduled; 88 } 89 90 bool mlir::sortTopologically( 91 Block *block, function_ref<bool(Value, Operation *)> isOperandReady) { 92 if (block->empty()) 93 return true; 94 if (block->back().hasTrait<OpTrait::IsTerminator>()) 95 return sortTopologically(block, block->without_terminator(), 96 isOperandReady); 97 return sortTopologically(block, *block, isOperandReady); 98 } 99