1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/ViewOpGraph.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/StandardTypes.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/STLExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 
26 static llvm::cl::opt<int> elideIfLarger(
27     "print-op-graph-elide-if-larger",
28     llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
29     llvm::cl::init(16));
30 
31 using namespace mlir;
32 
33 namespace llvm {
34 
35 // Specialize GraphTraits to treat Block as a graph of Operations as nodes and
36 // uses as edges.
37 template <> struct GraphTraits<Block *> {
38   using GraphType = Block *;
39   using NodeRef = Operation *;
40 
41   using ChildIteratorType = UseIterator;
42   static ChildIteratorType child_begin(NodeRef n) {
43     return ChildIteratorType(n);
44   }
45   static ChildIteratorType child_end(NodeRef n) {
46     return ChildIteratorType(n, /*end=*/true);
47   }
48 
49   // Operation's destructor is private so use Operation* instead and use
50   // mapped iterator.
51   static Operation *AddressOf(Operation &op) { return &op; }
52   using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
53   static nodes_iterator nodes_begin(Block *b) {
54     return nodes_iterator(b->begin(), &AddressOf);
55   }
56   static nodes_iterator nodes_end(Block *b) {
57     return nodes_iterator(b->end(), &AddressOf);
58   }
59 };
60 
61 // Specialize DOTGraphTraits to produce more readable output.
62 template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
63   using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
64   static std::string getNodeLabel(Operation *op, Block *);
65 };
66 
67 std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
68   // Reuse the print output for the node labels.
69   std::string ostr;
70   raw_string_ostream os(ostr);
71   os << op->getName() << "\n";
72 
73   if (!op->getLoc().isa<UnknownLoc>()) {
74     os << op->getLoc() << "\n";
75   }
76 
77   // Print resultant types
78   interleaveComma(op->getResultTypes(), os);
79   os << "\n";
80 
81   for (auto attr : op->getAttrs()) {
82     os << '\n' << attr.first << ": ";
83     // Always emit splat attributes.
84     if (attr.second.isa<SplatElementsAttr>()) {
85       attr.second.print(os);
86       continue;
87     }
88 
89     // Elide "big" elements attributes.
90     auto elements = attr.second.dyn_cast<ElementsAttr>();
91     if (elements && elements.getNumElements() > elideIfLarger) {
92       os << std::string(elements.getType().getRank(), '[') << "..."
93          << std::string(elements.getType().getRank(), ']') << " : "
94          << elements.getType();
95       continue;
96     }
97 
98     auto array = attr.second.dyn_cast<ArrayAttr>();
99     if (array && static_cast<int64_t>(array.size()) > elideIfLarger) {
100       os << "[...]";
101       continue;
102     }
103 
104     // Print all other attributes.
105     attr.second.print(os);
106   }
107   return os.str();
108 }
109 
110 } // end namespace llvm
111 
112 namespace {
113 // PrintOpPass is simple pass to write graph per function.
114 // Note: this is a module pass only to avoid interleaving on the same ostream
115 // due to multi-threading over functions.
116 struct PrintOpPass : public ModulePass<PrintOpPass> {
117   explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false,
118                        const Twine &title = "")
119       : os(os), title(title.str()), short_names(short_names) {}
120 
121   std::string getOpName(Operation &op) {
122     auto symbolAttr =
123         op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
124     if (symbolAttr)
125       return symbolAttr.getValue();
126     ++unnamedOpCtr;
127     return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
128   }
129 
130   // Print all the ops in a module.
131   void processModule(ModuleOp module) {
132     for (Operation &op : module) {
133       // Modules may actually be nested, recurse on nesting.
134       if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
135         processModule(nestedModule);
136         continue;
137       }
138       auto opName = getOpName(op);
139       for (Region &region : op.getRegions()) {
140         for (auto indexed_block : llvm::enumerate(region)) {
141           // Suffix block number if there are more than 1 block.
142           auto blockName = region.getBlocks().size() == 1
143                                ? ""
144                                : ("__" + llvm::utostr(indexed_block.index()));
145           llvm::WriteGraph(os, &indexed_block.value(), short_names,
146                            Twine(title) + opName + blockName);
147         }
148       }
149     }
150   }
151 
152   void runOnModule() override { processModule(getModule()); }
153 
154 private:
155   raw_ostream &os;
156   std::string title;
157   int unnamedOpCtr = 0;
158   bool short_names;
159 };
160 } // namespace
161 
162 void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
163                      const Twine &title, llvm::GraphProgram::Name program) {
164   llvm::ViewGraph(&block, name, shortNames, title, program);
165 }
166 
167 raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
168                               const Twine &title) {
169   return llvm::WriteGraph(os, &block, shortNames, title);
170 }
171 
172 std::unique_ptr<OpPassBase<ModuleOp>>
173 mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
174                              const Twine &title) {
175   return std::make_unique<PrintOpPass>(os, shortNames, title);
176 }
177 
178 static PassRegistration<PrintOpPass> pass("print-op-graph",
179                                           "Print op graph per region");
180