12660623aSJacques Pienaar //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
22660623aSJacques Pienaar //
32660623aSJacques Pienaar // Copyright 2019 The MLIR Authors.
42660623aSJacques Pienaar //
52660623aSJacques Pienaar // Licensed under the Apache License, Version 2.0 (the "License");
62660623aSJacques Pienaar // you may not use this file except in compliance with the License.
72660623aSJacques Pienaar // You may obtain a copy of the License at
82660623aSJacques Pienaar //
92660623aSJacques Pienaar //   http://www.apache.org/licenses/LICENSE-2.0
102660623aSJacques Pienaar //
112660623aSJacques Pienaar // Unless required by applicable law or agreed to in writing, software
122660623aSJacques Pienaar // distributed under the License is distributed on an "AS IS" BASIS,
132660623aSJacques Pienaar // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
142660623aSJacques Pienaar // See the License for the specific language governing permissions and
152660623aSJacques Pienaar // limitations under the License.
162660623aSJacques Pienaar // =============================================================================
172660623aSJacques Pienaar 
182660623aSJacques Pienaar #include "mlir/Transforms/ViewOpGraph.h"
192660623aSJacques Pienaar #include "mlir/IR/Block.h"
202660623aSJacques Pienaar #include "mlir/IR/Operation.h"
21*2b86e27dSJacques Pienaar #include "mlir/IR/StandardTypes.h"
222660623aSJacques Pienaar #include "mlir/Pass/Pass.h"
232660623aSJacques Pienaar #include "llvm/Support/CommandLine.h"
242660623aSJacques Pienaar 
252660623aSJacques Pienaar static llvm::cl::opt<int> elideIfLarger(
262660623aSJacques Pienaar     "print-op-graph-elide-if-larger",
272660623aSJacques Pienaar     llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
282660623aSJacques Pienaar     llvm::cl::init(16));
292660623aSJacques Pienaar 
302660623aSJacques Pienaar namespace llvm {
312660623aSJacques Pienaar 
322660623aSJacques Pienaar // Specialize GraphTraits to treat Block as a graph of Operations as nodes and
332660623aSJacques Pienaar // uses as edges.
34a23f69a3SJacques Pienaar template <> struct GraphTraits<mlir::Block *> {
352660623aSJacques Pienaar   using GraphType = mlir::Block *;
362660623aSJacques Pienaar   using NodeRef = mlir::Operation *;
372660623aSJacques Pienaar 
382660623aSJacques Pienaar   using ChildIteratorType = mlir::UseIterator;
392660623aSJacques Pienaar   static ChildIteratorType child_begin(NodeRef n) {
402660623aSJacques Pienaar     return ChildIteratorType(n);
412660623aSJacques Pienaar   }
422660623aSJacques Pienaar   static ChildIteratorType child_end(NodeRef n) {
432660623aSJacques Pienaar     return ChildIteratorType(n, /*end=*/true);
442660623aSJacques Pienaar   }
452660623aSJacques Pienaar 
462660623aSJacques Pienaar   // Operation's destructor is private so use Operation* instead and use
472660623aSJacques Pienaar   // mapped iterator.
482660623aSJacques Pienaar   static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; }
492660623aSJacques Pienaar   using nodes_iterator =
502660623aSJacques Pienaar       mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>;
512660623aSJacques Pienaar   static nodes_iterator nodes_begin(mlir::Block *b) {
522660623aSJacques Pienaar     return nodes_iterator(b->begin(), &AddressOf);
532660623aSJacques Pienaar   }
542660623aSJacques Pienaar   static nodes_iterator nodes_end(mlir::Block *b) {
552660623aSJacques Pienaar     return nodes_iterator(b->end(), &AddressOf);
562660623aSJacques Pienaar   }
572660623aSJacques Pienaar };
582660623aSJacques Pienaar 
592660623aSJacques Pienaar // Specialize DOTGraphTraits to produce more readable output.
602660623aSJacques Pienaar template <>
612660623aSJacques Pienaar struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits {
622660623aSJacques Pienaar   using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
632660623aSJacques Pienaar   static std::string getNodeLabel(mlir::Operation *op, mlir::Block *);
642660623aSJacques Pienaar };
652660623aSJacques Pienaar 
662660623aSJacques Pienaar std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op,
672660623aSJacques Pienaar                                                         mlir::Block *b) {
682660623aSJacques Pienaar   // Reuse the print output for the node labels.
692660623aSJacques Pienaar   std::string ostr;
702660623aSJacques Pienaar   raw_string_ostream os(ostr);
712660623aSJacques Pienaar   os << op->getName() << "\n";
722660623aSJacques Pienaar   for (auto attr : op->getAttrs()) {
732660623aSJacques Pienaar     os << '\n' << attr.first << ": ";
742660623aSJacques Pienaar     // Always emit splat attributes.
752660623aSJacques Pienaar     if (attr.second.isa<mlir::SplatElementsAttr>()) {
762660623aSJacques Pienaar       attr.second.print(os);
772660623aSJacques Pienaar       continue;
782660623aSJacques Pienaar     }
792660623aSJacques Pienaar 
802660623aSJacques Pienaar     // Elide "big" elements attributes.
812660623aSJacques Pienaar     auto elements = attr.second.dyn_cast<mlir::ElementsAttr>();
822660623aSJacques Pienaar     if (elements && elements.getNumElements() > elideIfLarger) {
83*2b86e27dSJacques Pienaar       os << std::string(elements.getType().getRank(), '[') << "..."
84*2b86e27dSJacques Pienaar          << std::string(elements.getType().getRank(), ']') << " : "
85*2b86e27dSJacques Pienaar          << elements.getType();
862660623aSJacques Pienaar       continue;
872660623aSJacques Pienaar     }
882660623aSJacques Pienaar 
892660623aSJacques Pienaar     // Print all other attributes.
902660623aSJacques Pienaar     attr.second.print(os);
912660623aSJacques Pienaar   }
922660623aSJacques Pienaar   return os.str();
932660623aSJacques Pienaar }
942660623aSJacques Pienaar 
952660623aSJacques Pienaar } // end namespace llvm
962660623aSJacques Pienaar 
972660623aSJacques Pienaar namespace {
982660623aSJacques Pienaar // PrintOpPass is simple pass to write graph per function.
992660623aSJacques Pienaar // Note: this is a module pass only to avoid interleaving on the same ostream
1002660623aSJacques Pienaar // due to multi-threading over functions.
1012660623aSJacques Pienaar struct PrintOpPass : public mlir::ModulePass<PrintOpPass> {
1022660623aSJacques Pienaar   explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(),
1032660623aSJacques Pienaar                        bool short_names = false, const llvm::Twine &title = "")
1042660623aSJacques Pienaar       : os(os), title(title.str()), short_names(short_names) {}
1052660623aSJacques Pienaar 
1062660623aSJacques Pienaar   std::string getOpName(mlir::Operation &op) {
1072660623aSJacques Pienaar     auto symbolAttr = op.getAttrOfType<mlir::StringAttr>(
1082660623aSJacques Pienaar         mlir::SymbolTable::getSymbolAttrName());
1092660623aSJacques Pienaar     if (symbolAttr)
1102660623aSJacques Pienaar       return symbolAttr.getValue();
1112660623aSJacques Pienaar     ++unnamedOpCtr;
1122660623aSJacques Pienaar     return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
1132660623aSJacques Pienaar   }
1142660623aSJacques Pienaar 
1152660623aSJacques Pienaar   // Print all the ops in a module.
1162660623aSJacques Pienaar   void processModule(mlir::ModuleOp module) {
1172660623aSJacques Pienaar     for (mlir::Operation &op : module) {
1182660623aSJacques Pienaar       // Modules may actually be nested, recurse on nesting.
1192660623aSJacques Pienaar       if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) {
1202660623aSJacques Pienaar         processModule(nestedModule);
1212660623aSJacques Pienaar         continue;
1222660623aSJacques Pienaar       }
1232660623aSJacques Pienaar       auto opName = getOpName(op);
1242660623aSJacques Pienaar       for (mlir::Region &region : op.getRegions()) {
1252660623aSJacques Pienaar         for (auto indexed_block : llvm::enumerate(region)) {
1262660623aSJacques Pienaar           // Suffix block number if there are more than 1 block.
1272660623aSJacques Pienaar           auto blockName = region.getBlocks().size() == 1
1282660623aSJacques Pienaar                                ? ""
1292660623aSJacques Pienaar                                : ("__" + llvm::utostr(indexed_block.index()));
1302660623aSJacques Pienaar           llvm::WriteGraph(os, &indexed_block.value(), short_names,
1312660623aSJacques Pienaar                            llvm::Twine(title) + opName + blockName);
1322660623aSJacques Pienaar         }
1332660623aSJacques Pienaar       }
1342660623aSJacques Pienaar     }
1352660623aSJacques Pienaar   }
1362660623aSJacques Pienaar 
1372660623aSJacques Pienaar   void runOnModule() override { processModule(getModule()); }
1382660623aSJacques Pienaar 
1392660623aSJacques Pienaar private:
1402660623aSJacques Pienaar   llvm::raw_ostream &os;
1412660623aSJacques Pienaar   std::string title;
1422660623aSJacques Pienaar   int unnamedOpCtr = 0;
1432660623aSJacques Pienaar   bool short_names;
1442660623aSJacques Pienaar };
1452660623aSJacques Pienaar } // namespace
1462660623aSJacques Pienaar 
1472660623aSJacques Pienaar void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name,
1482660623aSJacques Pienaar                      bool shortNames, const llvm::Twine &title,
1492660623aSJacques Pienaar                      llvm::GraphProgram::Name program) {
1502660623aSJacques Pienaar   llvm::ViewGraph(&block, name, shortNames, title, program);
1512660623aSJacques Pienaar }
1522660623aSJacques Pienaar 
1532660623aSJacques Pienaar llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block,
1542660623aSJacques Pienaar                                     bool shortNames, const llvm::Twine &title) {
1552660623aSJacques Pienaar   return llvm::WriteGraph(os, &block, shortNames, title);
1562660623aSJacques Pienaar }
1572660623aSJacques Pienaar 
158f1b100c7SRiver Riddle std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
1592660623aSJacques Pienaar mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames,
1602660623aSJacques Pienaar                              const llvm::Twine &title) {
1612660623aSJacques Pienaar   return std::make_unique<PrintOpPass>(os, shortNames, title);
1622660623aSJacques Pienaar }
1632660623aSJacques Pienaar 
1642660623aSJacques Pienaar static mlir::PassRegistration<PrintOpPass> pass("print-op-graph",
1652660623aSJacques Pienaar                                                 "Print op graph per region");
166