1 //===- OpStats.cpp - Prints stats of operations in module -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 
namespace {
// Pass that counts how many times each operation name occurs under the
// operation it is run on and writes a summary of those counts to a stream.
// The base class PrintOpStatsBase is declared outside this file (presumably
// generated from the pass definition and pulled in via PassDetail.h); it is
// expected to provide the `printAsJSON` option read in runOnOperation().
struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
  // Binds the pass to the stream that will receive the summary. Only a
  // reference is stored, so `os` must remain valid for as long as the pass
  // may run.
  explicit PrintOpStatsPass(raw_ostream &os) : os(os) {}

  // Recomputes the per-operation counts by walking the current operation and
  // then prints them, either as an aligned table or as JSON.
  void runOnOperation() override;

  // Print summary of op stats as a human-readable, column-aligned table.
  void printSummary();

  // Print summary of op stats in JSON.
  void printSummaryInJSON();

private:
  // Number of occurrences of each operation, keyed by the operation's name.
  llvm::StringMap<int64_t> opCount;
  // Destination stream for the printed summary (not owned).
  raw_ostream &os;
};
} // namespace
38 
39 void PrintOpStatsPass::runOnOperation() {
40   opCount.clear();
41 
42   // Compute the operation statistics for the currently visited operation.
43   getOperation()->walk(
44       [&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
45   if (printAsJSON) {
46     printSummaryInJSON();
47   } else
48     printSummary();
49 }
50 
51 void PrintOpStatsPass::printSummary() {
52   os << "Operations encountered:\n";
53   os << "-----------------------\n";
54   SmallVector<StringRef, 64> sorted(opCount.keys());
55   llvm::sort(sorted);
56 
57   // Split an operation name from its dialect prefix.
58   auto splitOperationName = [](StringRef opName) {
59     auto splitName = opName.split('.');
60     return splitName.second.empty() ? std::make_pair("", splitName.first)
61                                     : splitName;
62   };
63 
64   // Compute the largest dialect and operation name.
65   StringRef dialectName, opName;
66   size_t maxLenOpName = 0, maxLenDialect = 0;
67   for (const auto &key : sorted) {
68     std::tie(dialectName, opName) = splitOperationName(key);
69     maxLenDialect = std::max(maxLenDialect, dialectName.size());
70     maxLenOpName = std::max(maxLenOpName, opName.size());
71   }
72 
73   for (const auto &key : sorted) {
74     std::tie(dialectName, opName) = splitOperationName(key);
75 
76     // Left-align the names (aligning on the dialect) and right-align the count
77     // below. The alignment is for readability and does not affect CSV/FileCheck
78     // parsing.
79     if (dialectName.empty())
80       os.indent(maxLenDialect + 3);
81     else
82       os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
83 
84     // Left justify the operation name.
85     os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
86        << '\n';
87   }
88 }
89 
90 void PrintOpStatsPass::printSummaryInJSON() {
91   SmallVector<StringRef, 64> sorted(opCount.keys());
92   llvm::sort(sorted);
93 
94   os << "{\n";
95 
96   for (unsigned i = 0, e = sorted.size(); i != e; ++i) {
97     const auto &key = sorted[i];
98     os << "  \"" << key << "\" : " << opCount[key];
99     if (i != e - 1)
100       os << ",\n";
101     else
102       os << "\n";
103   }
104   os << "}\n";
105 }
106 
107 std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) {
108   return std::make_unique<PrintOpStatsPass>(os);
109 }
110