1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/ViewOpGraph.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/StandardTypes.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/STLExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 
26 static llvm::cl::opt<int> elideIfLarger(
27     "print-op-graph-elide-if-larger",
28     llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
29     llvm::cl::init(16));
30 
31 namespace llvm {
32 
33 // Specialize GraphTraits to treat Block as a graph of Operations as nodes and
34 // uses as edges.
35 template <> struct GraphTraits<mlir::Block *> {
36   using GraphType = mlir::Block *;
37   using NodeRef = mlir::Operation *;
38 
39   using ChildIteratorType = mlir::UseIterator;
40   static ChildIteratorType child_begin(NodeRef n) {
41     return ChildIteratorType(n);
42   }
43   static ChildIteratorType child_end(NodeRef n) {
44     return ChildIteratorType(n, /*end=*/true);
45   }
46 
47   // Operation's destructor is private so use Operation* instead and use
48   // mapped iterator.
49   static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; }
50   using nodes_iterator =
51       mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>;
52   static nodes_iterator nodes_begin(mlir::Block *b) {
53     return nodes_iterator(b->begin(), &AddressOf);
54   }
55   static nodes_iterator nodes_end(mlir::Block *b) {
56     return nodes_iterator(b->end(), &AddressOf);
57   }
58 };
59 
60 // Specialize DOTGraphTraits to produce more readable output.
61 template <>
62 struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits {
63   using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
64   static std::string getNodeLabel(mlir::Operation *op, mlir::Block *);
65 };
66 
67 std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op,
68                                                         mlir::Block *b) {
69   // Reuse the print output for the node labels.
70   std::string ostr;
71   raw_string_ostream os(ostr);
72   os << op->getName() << "\n";
73 
74   // Print resultant types
75   mlir::interleaveComma(op->getResultTypes(), os);
76   os << "\n";
77 
78   for (auto attr : op->getAttrs()) {
79     os << '\n' << attr.first << ": ";
80     // Always emit splat attributes.
81     if (attr.second.isa<mlir::SplatElementsAttr>()) {
82       attr.second.print(os);
83       continue;
84     }
85 
86     // Elide "big" elements attributes.
87     auto elements = attr.second.dyn_cast<mlir::ElementsAttr>();
88     if (elements && elements.getNumElements() > elideIfLarger) {
89       os << std::string(elements.getType().getRank(), '[') << "..."
90          << std::string(elements.getType().getRank(), ']') << " : "
91          << elements.getType();
92       continue;
93     }
94 
95     // Print all other attributes.
96     attr.second.print(os);
97   }
98   return os.str();
99 }
100 
101 } // end namespace llvm
102 
103 namespace {
104 // PrintOpPass is simple pass to write graph per function.
105 // Note: this is a module pass only to avoid interleaving on the same ostream
106 // due to multi-threading over functions.
107 struct PrintOpPass : public mlir::ModulePass<PrintOpPass> {
108   explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(),
109                        bool short_names = false, const llvm::Twine &title = "")
110       : os(os), title(title.str()), short_names(short_names) {}
111 
112   std::string getOpName(mlir::Operation &op) {
113     auto symbolAttr = op.getAttrOfType<mlir::StringAttr>(
114         mlir::SymbolTable::getSymbolAttrName());
115     if (symbolAttr)
116       return symbolAttr.getValue();
117     ++unnamedOpCtr;
118     return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
119   }
120 
121   // Print all the ops in a module.
122   void processModule(mlir::ModuleOp module) {
123     for (mlir::Operation &op : module) {
124       // Modules may actually be nested, recurse on nesting.
125       if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) {
126         processModule(nestedModule);
127         continue;
128       }
129       auto opName = getOpName(op);
130       for (mlir::Region &region : op.getRegions()) {
131         for (auto indexed_block : llvm::enumerate(region)) {
132           // Suffix block number if there are more than 1 block.
133           auto blockName = region.getBlocks().size() == 1
134                                ? ""
135                                : ("__" + llvm::utostr(indexed_block.index()));
136           llvm::WriteGraph(os, &indexed_block.value(), short_names,
137                            llvm::Twine(title) + opName + blockName);
138         }
139       }
140     }
141   }
142 
143   void runOnModule() override { processModule(getModule()); }
144 
145 private:
146   llvm::raw_ostream &os;
147   std::string title;
148   int unnamedOpCtr = 0;
149   bool short_names;
150 };
151 } // namespace
152 
153 void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name,
154                      bool shortNames, const llvm::Twine &title,
155                      llvm::GraphProgram::Name program) {
156   llvm::ViewGraph(&block, name, shortNames, title, program);
157 }
158 
159 llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block,
160                                     bool shortNames, const llvm::Twine &title) {
161   return llvm::WriteGraph(os, &block, shortNames, title);
162 }
163 
164 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
165 mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames,
166                              const llvm::Twine &title) {
167   return std::make_unique<PrintOpPass>(os, shortNames, title);
168 }
169 
170 static mlir::PassRegistration<PrintOpPass> pass("print-op-graph",
171                                                 "Print op graph per region");
172