1 //===- OpStats.cpp - Prints stats of operations in module -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/OperationSupport.h" 13 #include "mlir/Transforms/Passes.h" 14 #include "llvm/ADT/DenseMap.h" 15 #include "llvm/Support/Format.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 20 namespace { 21 struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> { 22 explicit PrintOpStatsPass(raw_ostream &os) : os(os) {} 23 24 explicit PrintOpStatsPass(raw_ostream &os, bool printAsJSON) : os(os) { 25 this->printAsJSON = printAsJSON; 26 } 27 28 // Prints the resultant operation statistics post iterating over the module. 29 void runOnOperation() override; 30 31 // Print summary of op stats. 32 void printSummary(); 33 34 // Print symmary of op stats in JSON. 35 void printSummaryInJSON(); 36 37 private: 38 llvm::StringMap<int64_t> opCount; 39 raw_ostream &os; 40 }; 41 } // namespace 42 43 void PrintOpStatsPass::runOnOperation() { 44 opCount.clear(); 45 46 // Compute the operation statistics for the currently visited operation. 47 getOperation()->walk( 48 [&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); 49 if (printAsJSON) { 50 printSummaryInJSON(); 51 } else 52 printSummary(); 53 } 54 55 void PrintOpStatsPass::printSummary() { 56 os << "Operations encountered:\n"; 57 os << "-----------------------\n"; 58 SmallVector<StringRef, 64> sorted(opCount.keys()); 59 llvm::sort(sorted); 60 61 // Split an operation name from its dialect prefix. 62 auto splitOperationName = [](StringRef opName) { 63 auto splitName = opName.split('.'); 64 return splitName.second.empty() ? std::make_pair("", splitName.first) 65 : splitName; 66 }; 67 68 // Compute the largest dialect and operation name. 69 StringRef dialectName, opName; 70 size_t maxLenOpName = 0, maxLenDialect = 0; 71 for (const auto &key : sorted) { 72 std::tie(dialectName, opName) = splitOperationName(key); 73 maxLenDialect = std::max(maxLenDialect, dialectName.size()); 74 maxLenOpName = std::max(maxLenOpName, opName.size()); 75 } 76 77 for (const auto &key : sorted) { 78 std::tie(dialectName, opName) = splitOperationName(key); 79 80 // Left-align the names (aligning on the dialect) and right-align the count 81 // below. The alignment is for readability and does not affect CSV/FileCheck 82 // parsing. 83 if (dialectName.empty()) 84 os.indent(maxLenDialect + 3); 85 else 86 os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.'; 87 88 // Left justify the operation name. 89 os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key] 90 << '\n'; 91 } 92 } 93 94 void PrintOpStatsPass::printSummaryInJSON() { 95 SmallVector<StringRef, 64> sorted(opCount.keys()); 96 llvm::sort(sorted); 97 98 os << "{\n"; 99 100 for (unsigned i = 0, e = sorted.size(); i != e; ++i) { 101 const auto &key = sorted[i]; 102 os << " \"" << key << "\" : " << opCount[key]; 103 if (i != e - 1) 104 os << ",\n"; 105 else 106 os << "\n"; 107 } 108 os << "}\n"; 109 } 110 111 std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) { 112 return std::make_unique<PrintOpStatsPass>(os); 113 } 114 115 std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os, 116 bool printAsJSON) { 117 return std::make_unique<PrintOpStatsPass>(os, printAsJSON); 118 } 119