1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/ViewOpGraph.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/StandardTypes.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/STLExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 
26 static llvm::cl::opt<int> elideIfLarger(
27     "print-op-graph-elide-if-larger",
28     llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
29     llvm::cl::init(16));
30 
31 namespace llvm {
32 
33 // Specialize GraphTraits to treat Block as a graph of Operations as nodes and
34 // uses as edges.
35 template <> struct GraphTraits<mlir::Block *> {
36   using GraphType = mlir::Block *;
37   using NodeRef = mlir::Operation *;
38 
39   using ChildIteratorType = mlir::UseIterator;
40   static ChildIteratorType child_begin(NodeRef n) {
41     return ChildIteratorType(n);
42   }
43   static ChildIteratorType child_end(NodeRef n) {
44     return ChildIteratorType(n, /*end=*/true);
45   }
46 
47   // Operation's destructor is private so use Operation* instead and use
48   // mapped iterator.
49   static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; }
50   using nodes_iterator =
51       mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>;
52   static nodes_iterator nodes_begin(mlir::Block *b) {
53     return nodes_iterator(b->begin(), &AddressOf);
54   }
55   static nodes_iterator nodes_end(mlir::Block *b) {
56     return nodes_iterator(b->end(), &AddressOf);
57   }
58 };
59 
// Specialize DOTGraphTraits to produce more readable output: each node's label
// is produced by the custom getNodeLabel declared here instead of the default.
template <>
struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits {
  // Inherit DefaultDOTGraphTraits' constructors unchanged.
  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
  // Returns the text label for the node of `op`; the second argument is the
  // Block graph being rendered.
  static std::string getNodeLabel(mlir::Operation *op, mlir::Block *);
};
66 
67 std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op,
68                                                         mlir::Block *b) {
69   // Reuse the print output for the node labels.
70   std::string ostr;
71   raw_string_ostream os(ostr);
72   os << op->getName() << "\n";
73 
74   if (!op->getLoc().isa<mlir::UnknownLoc>()) {
75     os << op->getLoc() << "\n";
76   }
77 
78   // Print resultant types
79   mlir::interleaveComma(op->getResultTypes(), os);
80   os << "\n";
81 
82   for (auto attr : op->getAttrs()) {
83     os << '\n' << attr.first << ": ";
84     // Always emit splat attributes.
85     if (attr.second.isa<mlir::SplatElementsAttr>()) {
86       attr.second.print(os);
87       continue;
88     }
89 
90     // Elide "big" elements attributes.
91     auto elements = attr.second.dyn_cast<mlir::ElementsAttr>();
92     if (elements && elements.getNumElements() > elideIfLarger) {
93       os << std::string(elements.getType().getRank(), '[') << "..."
94          << std::string(elements.getType().getRank(), ']') << " : "
95          << elements.getType();
96       continue;
97     }
98 
99     auto array = attr.second.dyn_cast<mlir::ArrayAttr>();
100     if (array && static_cast<int64_t>(array.size()) > elideIfLarger) {
101       os << "[...]";
102       continue;
103     }
104 
105     // Print all other attributes.
106     attr.second.print(os);
107   }
108   return os.str();
109 }
110 
111 } // end namespace llvm
112 
113 namespace {
114 // PrintOpPass is simple pass to write graph per function.
115 // Note: this is a module pass only to avoid interleaving on the same ostream
116 // due to multi-threading over functions.
117 struct PrintOpPass : public mlir::ModulePass<PrintOpPass> {
118   explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(),
119                        bool short_names = false, const llvm::Twine &title = "")
120       : os(os), title(title.str()), short_names(short_names) {}
121 
122   std::string getOpName(mlir::Operation &op) {
123     auto symbolAttr = op.getAttrOfType<mlir::StringAttr>(
124         mlir::SymbolTable::getSymbolAttrName());
125     if (symbolAttr)
126       return symbolAttr.getValue();
127     ++unnamedOpCtr;
128     return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
129   }
130 
131   // Print all the ops in a module.
132   void processModule(mlir::ModuleOp module) {
133     for (mlir::Operation &op : module) {
134       // Modules may actually be nested, recurse on nesting.
135       if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) {
136         processModule(nestedModule);
137         continue;
138       }
139       auto opName = getOpName(op);
140       for (mlir::Region &region : op.getRegions()) {
141         for (auto indexed_block : llvm::enumerate(region)) {
142           // Suffix block number if there are more than 1 block.
143           auto blockName = region.getBlocks().size() == 1
144                                ? ""
145                                : ("__" + llvm::utostr(indexed_block.index()));
146           llvm::WriteGraph(os, &indexed_block.value(), short_names,
147                            llvm::Twine(title) + opName + blockName);
148         }
149       }
150     }
151   }
152 
153   void runOnModule() override { processModule(getModule()); }
154 
155 private:
156   llvm::raw_ostream &os;
157   std::string title;
158   int unnamedOpCtr = 0;
159   bool short_names;
160 };
161 } // namespace
162 
163 void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name,
164                      bool shortNames, const llvm::Twine &title,
165                      llvm::GraphProgram::Name program) {
166   llvm::ViewGraph(&block, name, shortNames, title, program);
167 }
168 
169 llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block,
170                                     bool shortNames, const llvm::Twine &title) {
171   return llvm::WriteGraph(os, &block, shortNames, title);
172 }
173 
174 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
175 mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames,
176                              const llvm::Twine &title) {
177   return std::make_unique<PrintOpPass>(os, shortNames, title);
178 }
179 
180 static mlir::PassRegistration<PrintOpPass> pass("print-op-graph",
181                                                 "Print op graph per region");
182