1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/ViewOpGraph.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/Pass/Pass.h"
22 #include "llvm/Support/CommandLine.h"
23 
24 // NOLINTNEXTLINE
25 static llvm::cl::opt<int> elideIfLarger(
26     "print-op-graph-elide-if-larger",
27     llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
28     llvm::cl::init(16));
29 
30 namespace llvm {
31 
32 // Specialize GraphTraits to treat Block as a graph of Operations as nodes and
33 // uses as edges.
34 template <> struct GraphTraits<mlir::Block *> {
35   using GraphType = mlir::Block *;
36   using NodeRef = mlir::Operation *;
37 
38   using ChildIteratorType = mlir::UseIterator;
39   static ChildIteratorType child_begin(NodeRef n) {
40     return ChildIteratorType(n);
41   }
42   static ChildIteratorType child_end(NodeRef n) {
43     return ChildIteratorType(n, /*end=*/true);
44   }
45 
46   // Operation's destructor is private so use Operation* instead and use
47   // mapped iterator.
48   static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; }
49   using nodes_iterator =
50       mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>;
51   static nodes_iterator nodes_begin(mlir::Block *b) {
52     return nodes_iterator(b->begin(), &AddressOf);
53   }
54   static nodes_iterator nodes_end(mlir::Block *b) {
55     return nodes_iterator(b->end(), &AddressOf);
56   }
57 };
58 
59 // Specialize DOTGraphTraits to produce more readable output.
60 template <>
61 struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits {
62   using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
63   static std::string getNodeLabel(mlir::Operation *op, mlir::Block *);
64 };
65 
66 std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op,
67                                                         mlir::Block *b) {
68   // Reuse the print output for the node labels.
69   std::string ostr;
70   raw_string_ostream os(ostr);
71   os << op->getName() << "\n";
72   for (auto attr : op->getAttrs()) {
73     os << '\n' << attr.first << ": ";
74     // Always emit splat attributes.
75     if (attr.second.isa<mlir::SplatElementsAttr>()) {
76       attr.second.print(os);
77       continue;
78     }
79 
80     // Elide "big" elements attributes.
81     auto elements = attr.second.dyn_cast<mlir::ElementsAttr>();
82     if (elements && elements.getNumElements() > elideIfLarger) {
83       os << "...";
84       continue;
85     }
86 
87     // Print all other attributes.
88     attr.second.print(os);
89   }
90   return os.str();
91 }
92 
93 } // end namespace llvm
94 
95 namespace {
96 // PrintOpPass is simple pass to write graph per function.
97 // Note: this is a module pass only to avoid interleaving on the same ostream
98 // due to multi-threading over functions.
99 struct PrintOpPass : public mlir::ModulePass<PrintOpPass> {
100   explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(),
101                        bool short_names = false, const llvm::Twine &title = "")
102       : os(os), title(title.str()), short_names(short_names) {}
103 
104   std::string getOpName(mlir::Operation &op) {
105     auto symbolAttr = op.getAttrOfType<mlir::StringAttr>(
106         mlir::SymbolTable::getSymbolAttrName());
107     if (symbolAttr)
108       return symbolAttr.getValue();
109     ++unnamedOpCtr;
110     return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
111   }
112 
113   // Print all the ops in a module.
114   void processModule(mlir::ModuleOp module) {
115     for (mlir::Operation &op : module) {
116       // Modules may actually be nested, recurse on nesting.
117       if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) {
118         processModule(nestedModule);
119         continue;
120       }
121       auto opName = getOpName(op);
122       for (mlir::Region &region : op.getRegions()) {
123         for (auto indexed_block : llvm::enumerate(region)) {
124           // Suffix block number if there are more than 1 block.
125           auto blockName = region.getBlocks().size() == 1
126                                ? ""
127                                : ("__" + llvm::utostr(indexed_block.index()));
128           llvm::WriteGraph(os, &indexed_block.value(), short_names,
129                            llvm::Twine(title) + opName + blockName);
130         }
131       }
132     }
133   }
134 
135   void runOnModule() override { processModule(getModule()); }
136 
137 private:
138   llvm::raw_ostream &os;
139   std::string title;
140   int unnamedOpCtr = 0;
141   bool short_names;
142 };
143 } // namespace
144 
145 void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name,
146                      bool shortNames, const llvm::Twine &title,
147                      llvm::GraphProgram::Name program) {
148   llvm::ViewGraph(&block, name, shortNames, title, program);
149 }
150 
151 llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block,
152                                     bool shortNames, const llvm::Twine &title) {
153   return llvm::WriteGraph(os, &block, shortNames, title);
154 }
155 
156 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
157 mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames,
158                              const llvm::Twine &title) {
159   return std::make_unique<PrintOpPass>(os, shortNames, title);
160 }
161 
162 static mlir::PassRegistration<PrintOpPass> pass("print-op-graph",
163                                                 "Print op graph per region");
164