1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "mlir/IR/Verifier.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/ScopedHashTable.h"
35 #include "llvm/ADT/SetVector.h"
36 #include "llvm/ADT/SmallString.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/StringSet.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/Endian.h"
43 #include "llvm/Support/Regex.h"
44 #include "llvm/Support/SaveAndRestore.h"
45 #include "llvm/Support/Threading.h"
46 
47 #include <tuple>
48 
49 using namespace mlir;
50 using namespace mlir::detail;
51 
52 #define DEBUG_TYPE "mlir-asm-printer"
53 
print(raw_ostream & os) const54 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
55 
dump() const56 void OperationName::dump() const { print(llvm::errs()); }
57 
58 //===--------------------------------------------------------------------===//
59 // AsmParser
60 //===--------------------------------------------------------------------===//
61 
62 AsmParser::~AsmParser() = default;
63 DialectAsmParser::~DialectAsmParser() = default;
64 OpAsmParser::~OpAsmParser() = default;
65 
getContext() const66 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
67 
68 //===----------------------------------------------------------------------===//
69 // DialectAsmPrinter
70 //===----------------------------------------------------------------------===//
71 
72 DialectAsmPrinter::~DialectAsmPrinter() = default;
73 
74 //===----------------------------------------------------------------------===//
75 // OpAsmPrinter
76 //===----------------------------------------------------------------------===//
77 
78 OpAsmPrinter::~OpAsmPrinter() = default;
79 
printFunctionalType(Operation * op)80 void OpAsmPrinter::printFunctionalType(Operation *op) {
81   auto &os = getStream();
82   os << '(';
83   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
84     // Print the types of null values as <<NULL TYPE>>.
85     *this << (operand ? operand.getType() : Type());
86   });
87   os << ") -> ";
88 
89   // Print the result list.  We don't parenthesize single result types unless
90   // it is a function (avoiding a grammar ambiguity).
91   bool wrapped = op->getNumResults() != 1;
92   if (!wrapped && op->getResult(0).getType() &&
93       op->getResult(0).getType().isa<FunctionType>())
94     wrapped = true;
95 
96   if (wrapped)
97     os << '(';
98 
99   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
100     // Print the types of null values as <<NULL TYPE>>.
101     *this << (result ? result.getType() : Type());
102   });
103 
104   if (wrapped)
105     os << ')';
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Operation OpAsm interface.
110 //===----------------------------------------------------------------------===//
111 
112 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
113 #include "mlir/IR/OpAsmInterface.cpp.inc"
114 
115 LogicalResult
parseResource(AsmParsedResourceEntry & entry) const116 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
117   return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
118                            << "' for dialect '" << getDialect()->getNamespace()
119                            << "'";
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // OpPrintingFlags
124 //===----------------------------------------------------------------------===//
125 
126 namespace {
127 /// This struct contains command line options that can be used to initialize
128 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
129 /// for global command line options.
130 struct AsmPrinterOptions {
131   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
132       "mlir-print-elementsattrs-with-hex-if-larger",
133       llvm::cl::desc(
134           "Print DenseElementsAttrs with a hex string that have "
135           "more elements than the given upper limit (use -1 to disable)")};
136 
137   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
138       "mlir-elide-elementsattrs-if-larger",
139       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
140                      "more elements than the given upper limit")};
141 
142   llvm::cl::opt<bool> printDebugInfoOpt{
143       "mlir-print-debuginfo", llvm::cl::init(false),
144       llvm::cl::desc("Print debug info in MLIR output")};
145 
146   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
147       "mlir-pretty-debuginfo", llvm::cl::init(false),
148       llvm::cl::desc("Print pretty debug info in MLIR output")};
149 
150   // Use the generic op output form in the operation printer even if the custom
151   // form is defined.
152   llvm::cl::opt<bool> printGenericOpFormOpt{
153       "mlir-print-op-generic", llvm::cl::init(false),
154       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
155 
156   llvm::cl::opt<bool> assumeVerifiedOpt{
157       "mlir-print-assume-verified", llvm::cl::init(false),
158       llvm::cl::desc("Skip op verification when using custom printers"),
159       llvm::cl::Hidden};
160 
161   llvm::cl::opt<bool> printLocalScopeOpt{
162       "mlir-print-local-scope", llvm::cl::init(false),
163       llvm::cl::desc("Print with local scope and inline information (eliding "
164                      "aliases for attributes, types, and locations")};
165 
166   llvm::cl::opt<bool> printValueUsers{
167       "mlir-print-value-users", llvm::cl::init(false),
168       llvm::cl::desc(
169           "Print users of operation results and block arguments as a comment")};
170 };
171 } // namespace
172 
173 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
174 
175 /// Register a set of useful command-line options that can be used to configure
176 /// various flags within the AsmPrinter.
registerAsmPrinterCLOptions()177 void mlir::registerAsmPrinterCLOptions() {
178   // Make sure that the options struct has been initialized.
179   *clOptions;
180 }
181 
182 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()183 OpPrintingFlags::OpPrintingFlags()
184     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
185       printGenericOpFormFlag(false), assumeVerifiedFlag(false),
186       printLocalScope(false), printValueUsersFlag(false) {
187   // Initialize based upon command line options, if they are available.
188   if (!clOptions.isConstructed())
189     return;
190   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
191     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
192   printDebugInfoFlag = clOptions->printDebugInfoOpt;
193   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
194   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
195   assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
196   printLocalScope = clOptions->printLocalScopeOpt;
197   printValueUsersFlag = clOptions->printValueUsers;
198 }
199 
200 /// Enable the elision of large elements attributes, by printing a '...'
201 /// instead of the element data, when the number of elements is greater than
202 /// `largeElementLimit`. Note: The IR generated with this option is not
203 /// parsable.
204 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)205 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
206   elementsAttrElementLimit = largeElementLimit;
207   return *this;
208 }
209 
210 /// Enable printing of debug information. If 'prettyForm' is set to true,
211 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)212 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
213   printDebugInfoFlag = true;
214   printDebugInfoPrettyFormFlag = prettyForm;
215   return *this;
216 }
217 
218 /// Always print operations in the generic form.
printGenericOpForm()219 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
220   printGenericOpFormFlag = true;
221   return *this;
222 }
223 
224 /// Do not verify the operation when using custom operation printers.
assumeVerified()225 OpPrintingFlags &OpPrintingFlags::assumeVerified() {
226   assumeVerifiedFlag = true;
227   return *this;
228 }
229 
230 /// Use local scope when printing the operation. This allows for using the
231 /// printer in a more localized and thread-safe setting, but may not necessarily
232 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()233 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
234   printLocalScope = true;
235   return *this;
236 }
237 
238 /// Print users of values as comments.
printValueUsers()239 OpPrintingFlags &OpPrintingFlags::printValueUsers() {
240   printValueUsersFlag = true;
241   return *this;
242 }
243 
244 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const245 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
246   return elementsAttrElementLimit &&
247          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
248          !attr.isa<SplatElementsAttr>();
249 }
250 
251 /// Return the size limit for printing large ElementsAttr.
getLargeElementsAttrLimit() const252 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
253   return elementsAttrElementLimit;
254 }
255 
256 /// Return if debug information should be printed.
shouldPrintDebugInfo() const257 bool OpPrintingFlags::shouldPrintDebugInfo() const {
258   return printDebugInfoFlag;
259 }
260 
261 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const262 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
263   return printDebugInfoPrettyFormFlag;
264 }
265 
266 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const267 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
268   return printGenericOpFormFlag;
269 }
270 
271 /// Return if operation verification should be skipped.
shouldAssumeVerified() const272 bool OpPrintingFlags::shouldAssumeVerified() const {
273   return assumeVerifiedFlag;
274 }
275 
276 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const277 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
278 
279 /// Return if the printer should print users of values.
shouldPrintValueUsers() const280 bool OpPrintingFlags::shouldPrintValueUsers() const {
281   return printValueUsersFlag;
282 }
283 
284 /// Returns true if an ElementsAttr with the given number of elements should be
285 /// printed with hex.
shouldPrintElementsAttrWithHex(int64_t numElements)286 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
287   // Check to see if a command line option was provided for the limit.
288   if (clOptions.isConstructed()) {
289     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
290       // -1 is used to disable hex printing.
291       if (clOptions->printElementsAttrWithHexIfLarger == -1)
292         return false;
293       return numElements > clOptions->printElementsAttrWithHexIfLarger;
294     }
295   }
296 
297   // Otherwise, default to printing with hex if the number of elements is >100.
298   return numElements > 100;
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // NewLineCounter
303 //===----------------------------------------------------------------------===//
304 
305 namespace {
306 /// This class is a simple formatter that emits a new line when inputted into a
307 /// stream, that enables counting the number of newlines emitted. This class
308 /// should be used whenever emitting newlines in the printer.
309 struct NewLineCounter {
310   unsigned curLine = 1;
311 };
312 
operator <<(raw_ostream & os,NewLineCounter & newLine)313 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
314   ++newLine.curLine;
315   return os << '\n';
316 }
317 } // namespace
318 
319 //===----------------------------------------------------------------------===//
320 // AliasInitializer
321 //===----------------------------------------------------------------------===//
322 
323 namespace {
324 /// This class represents a specific instance of a symbol Alias.
325 class SymbolAlias {
326 public:
SymbolAlias(StringRef name,bool isDeferrable)327   SymbolAlias(StringRef name, bool isDeferrable)
328       : name(name), suffixIndex(0), hasSuffixIndex(false),
329         isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name,uint32_t suffixIndex,bool isDeferrable)330   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
331       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
332         isDeferrable(isDeferrable) {}
333 
334   /// Print this alias to the given stream.
print(raw_ostream & os) const335   void print(raw_ostream &os) const {
336     os << name;
337     if (hasSuffixIndex)
338       os << suffixIndex;
339   }
340 
341   /// Returns true if this alias supports deferred resolution when parsing.
canBeDeferred() const342   bool canBeDeferred() const { return isDeferrable; }
343 
344 private:
345   /// The main name of the alias.
346   StringRef name;
347   /// The optional suffix index of the alias, if multiple aliases had the same
348   /// name.
349   uint32_t suffixIndex : 30;
350   /// A flag indicating whether this alias has a suffix or not.
351   bool hasSuffixIndex : 1;
352   /// A flag indicating whether this alias may be deferred or not.
353   bool isDeferrable : 1;
354 };
355 
356 /// This class represents a utility that initializes the set of attribute and
357 /// type aliases, without the need to store the extra information within the
358 /// main AliasState class or pass it around via function arguments.
359 class AliasInitializer {
360 public:
AliasInitializer(DialectInterfaceCollection<OpAsmDialectInterface> & interfaces,llvm::BumpPtrAllocator & aliasAllocator)361   AliasInitializer(
362       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
363       llvm::BumpPtrAllocator &aliasAllocator)
364       : interfaces(interfaces), aliasAllocator(aliasAllocator),
365         aliasOS(aliasBuffer) {}
366 
367   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
368                   llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
369                   llvm::MapVector<Type, SymbolAlias> &typeToAlias);
370 
371   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
372   /// set to true if the originator of this attribute can resolve the alias
373   /// after parsing has completed (e.g. in the case of operation locations).
374   void visit(Attribute attr, bool canBeDeferred = false);
375 
376   /// Visit the given type to see if it has an alias.
377   void visit(Type type);
378 
379 private:
380   /// Try to generate an alias for the provided symbol. If an alias is
381   /// generated, the provided alias mapping and reverse mapping are updated.
382   /// Returns success if an alias was generated, failure otherwise.
383   template <typename T>
384   LogicalResult
385   generateAlias(T symbol,
386                 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
387 
388   /// The set of asm interfaces within the context.
389   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
390 
391   /// Mapping between an alias and the set of symbols mapped to it.
392   llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
393   llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
394 
395   /// An allocator used for alias names.
396   llvm::BumpPtrAllocator &aliasAllocator;
397 
398   /// The set of visited attributes.
399   DenseSet<Attribute> visitedAttributes;
400 
401   /// The set of attributes that have aliases *and* can be deferred.
402   DenseSet<Attribute> deferrableAttributes;
403 
404   /// The set of visited types.
405   DenseSet<Type> visitedTypes;
406 
407   /// Storage and stream used when generating an alias.
408   SmallString<32> aliasBuffer;
409   llvm::raw_svector_ostream aliasOS;
410 };
411 
412 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
413 /// and merely collects the attributes and types that *would* be printed in a
414 /// normal print invocation so that we can generate proper aliases. This allows
415 /// for us to generate aliases only for the attributes and types that would be
416 /// in the output, and trims down unnecessary output.
417 class DummyAliasOperationPrinter : private OpAsmPrinter {
418 public:
DummyAliasOperationPrinter(const OpPrintingFlags & printerFlags,AliasInitializer & initializer)419   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
420                                       AliasInitializer &initializer)
421       : printerFlags(printerFlags), initializer(initializer) {}
422 
423   /// Print the given operation.
print(Operation * op)424   void print(Operation *op) {
425     // Visit the operation location.
426     if (printerFlags.shouldPrintDebugInfo())
427       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
428 
429     // If requested, always print the generic form.
430     if (!printerFlags.shouldPrintGenericOpForm()) {
431       // Check to see if this is a known operation.  If so, use the registered
432       // custom printer hook.
433       if (auto opInfo = op->getRegisteredInfo()) {
434         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
435         return;
436       }
437     }
438 
439     // Otherwise print with the generic assembly form.
440     printGenericOp(op);
441   }
442 
443 private:
444   /// Print the given operation in the generic form.
printGenericOp(Operation * op,bool printOpName=true)445   void printGenericOp(Operation *op, bool printOpName = true) override {
446     // Consider nested operations for aliases.
447     if (op->getNumRegions() != 0) {
448       for (Region &region : op->getRegions())
449         printRegion(region, /*printEntryBlockArgs=*/true,
450                     /*printBlockTerminators=*/true);
451     }
452 
453     // Visit all the types used in the operation.
454     for (Type type : op->getOperandTypes())
455       printType(type);
456     for (Type type : op->getResultTypes())
457       printType(type);
458 
459     // Consider the attributes of the operation for aliases.
460     for (const NamedAttribute &attr : op->getAttrs())
461       printAttribute(attr.getValue());
462   }
463 
464   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
465   /// block are not printed. If 'printBlockTerminator' is false, the terminator
466   /// operation of the block is not printed.
print(Block * block,bool printBlockArgs=true,bool printBlockTerminator=true)467   void print(Block *block, bool printBlockArgs = true,
468              bool printBlockTerminator = true) {
469     // Consider the types of the block arguments for aliases if 'printBlockArgs'
470     // is set to true.
471     if (printBlockArgs) {
472       for (BlockArgument arg : block->getArguments()) {
473         printType(arg.getType());
474 
475         // Visit the argument location.
476         if (printerFlags.shouldPrintDebugInfo())
477           // TODO: Allow deferring argument locations.
478           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
479       }
480     }
481 
482     // Consider the operations within this block, ignoring the terminator if
483     // requested.
484     bool hasTerminator =
485         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
486     auto range = llvm::make_range(
487         block->begin(),
488         std::prev(block->end(),
489                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
490     for (Operation &op : range)
491       print(&op);
492   }
493 
494   /// Print the given region.
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock=false)495   void printRegion(Region &region, bool printEntryBlockArgs,
496                    bool printBlockTerminators,
497                    bool printEmptyBlock = false) override {
498     if (region.empty())
499       return;
500 
501     auto *entryBlock = &region.front();
502     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
503     for (Block &b : llvm::drop_begin(region, 1))
504       print(&b);
505   }
506 
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)507   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
508                            bool omitType) override {
509     printType(arg.getType());
510     // Visit the argument location.
511     if (printerFlags.shouldPrintDebugInfo())
512       // TODO: Allow deferring argument locations.
513       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
514   }
515 
516   /// Consider the given type to be printed for an alias.
printType(Type type)517   void printType(Type type) override { initializer.visit(type); }
518 
519   /// Consider the given attribute to be printed for an alias.
printAttribute(Attribute attr)520   void printAttribute(Attribute attr) override { initializer.visit(attr); }
printAttributeWithoutType(Attribute attr)521   void printAttributeWithoutType(Attribute attr) override {
522     printAttribute(attr);
523   }
printAlias(Attribute attr)524   LogicalResult printAlias(Attribute attr) override {
525     initializer.visit(attr);
526     return success();
527   }
printAlias(Type type)528   LogicalResult printAlias(Type type) override {
529     initializer.visit(type);
530     return success();
531   }
532 
533   /// Print the given set of attributes with names not included within
534   /// 'elidedAttrs'.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})535   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
536                              ArrayRef<StringRef> elidedAttrs = {}) override {
537     if (attrs.empty())
538       return;
539     if (elidedAttrs.empty()) {
540       for (const NamedAttribute &attr : attrs)
541         printAttribute(attr.getValue());
542       return;
543     }
544     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
545                                                   elidedAttrs.end());
546     for (const NamedAttribute &attr : attrs)
547       if (!elidedAttrsSet.contains(attr.getName().strref()))
548         printAttribute(attr.getValue());
549   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})550   void printOptionalAttrDictWithKeyword(
551       ArrayRef<NamedAttribute> attrs,
552       ArrayRef<StringRef> elidedAttrs = {}) override {
553     printOptionalAttrDict(attrs, elidedAttrs);
554   }
555 
556   /// Return a null stream as the output stream, this will ignore any data fed
557   /// to it.
getStream() const558   raw_ostream &getStream() const override { return os; }
559 
560   /// The following are hooks of `OpAsmPrinter` that are not necessary for
561   /// determining potential aliases.
printFloat(const APFloat & value)562   void printFloat(const APFloat &value) override {}
printAffineMapOfSSAIds(AffineMapAttr,ValueRange)563   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
printAffineExprOfSSAIds(AffineExpr,ValueRange,ValueRange)564   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
printNewline()565   void printNewline() override {}
printOperand(Value)566   void printOperand(Value) override {}
printOperand(Value,raw_ostream & os)567   void printOperand(Value, raw_ostream &os) override {
568     // Users expect the output string to have at least the prefixed % to signal
569     // a value name. To maintain this invariant, emit a name even if it is
570     // guaranteed to go unused.
571     os << "%";
572   }
printKeywordOrString(StringRef)573   void printKeywordOrString(StringRef) override {}
printSymbolName(StringRef)574   void printSymbolName(StringRef) override {}
printSuccessor(Block *)575   void printSuccessor(Block *) override {}
printSuccessorAndUseList(Block *,ValueRange)576   void printSuccessorAndUseList(Block *, ValueRange) override {}
shadowRegionArgs(Region &,ValueRange)577   void shadowRegionArgs(Region &, ValueRange) override {}
578 
579   /// The printer flags to use when determining potential aliases.
580   const OpPrintingFlags &printerFlags;
581 
582   /// The initializer to use when identifying aliases.
583   AliasInitializer &initializer;
584 
585   /// A dummy output stream.
586   mutable llvm::raw_null_ostream os;
587 };
588 } // namespace
589 
590 /// Sanitize the given name such that it can be used as a valid identifier. If
591 /// the string needs to be modified in any way, the provided buffer is used to
592 /// store the new copy,
sanitizeIdentifier(StringRef name,SmallString<16> & buffer,StringRef allowedPunctChars="$._-",bool allowTrailingDigit=true)593 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
594                                     StringRef allowedPunctChars = "$._-",
595                                     bool allowTrailingDigit = true) {
596   assert(!name.empty() && "Shouldn't have an empty name here");
597 
598   auto copyNameToBuffer = [&] {
599     for (char ch : name) {
600       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
601         buffer.push_back(ch);
602       else if (ch == ' ')
603         buffer.push_back('_');
604       else
605         buffer.append(llvm::utohexstr((unsigned char)ch));
606     }
607   };
608 
609   // Check to see if this name is valid. If it starts with a digit, then it
610   // could conflict with the autogenerated numeric ID's, so add an underscore
611   // prefix to avoid problems.
612   if (isdigit(name[0])) {
613     buffer.push_back('_');
614     copyNameToBuffer();
615     return buffer;
616   }
617 
618   // If the name ends with a trailing digit, add a '_' to avoid potential
619   // conflicts with autogenerated ID's.
620   if (!allowTrailingDigit && isdigit(name.back())) {
621     copyNameToBuffer();
622     buffer.push_back('_');
623     return buffer;
624   }
625 
626   // Check to see that the name consists of only valid identifier characters.
627   for (char ch : name) {
628     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
629       copyNameToBuffer();
630       return buffer;
631     }
632   }
633 
634   // If there are no invalid characters, return the original name.
635   return name;
636 }
637 
638 /// Given a collection of aliases and symbols, initialize a mapping from a
639 /// symbol to a given alias.
640 template <typename T>
641 static void
initializeAliases(llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol,llvm::MapVector<T,SymbolAlias> & symbolToAlias,DenseSet<T> * deferrableAliases=nullptr)642 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
643                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
644                   DenseSet<T> *deferrableAliases = nullptr) {
645   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
646       aliasToSymbol.takeVector();
647   llvm::array_pod_sort(aliases.begin(), aliases.end(),
648                        [](const auto *lhs, const auto *rhs) {
649                          return lhs->first.compare(rhs->first);
650                        });
651 
652   for (auto &it : aliases) {
653     // If there is only one instance for this alias, use the name directly.
654     if (it.second.size() == 1) {
655       T symbol = it.second.front();
656       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
657       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
658       continue;
659     }
660     // Otherwise, add the index to the name.
661     for (int i = 0, e = it.second.size(); i < e; ++i) {
662       T symbol = it.second[i];
663       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
664       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
665     }
666   }
667 }
668 
initialize(Operation * op,const OpPrintingFlags & printerFlags,llvm::MapVector<Attribute,SymbolAlias> & attrToAlias,llvm::MapVector<Type,SymbolAlias> & typeToAlias)669 void AliasInitializer::initialize(
670     Operation *op, const OpPrintingFlags &printerFlags,
671     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
672     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
673   // Use a dummy printer when walking the IR so that we can collect the
674   // attributes/types that will actually be used during printing when
675   // considering aliases.
676   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
677   aliasPrinter.print(op);
678 
679   // Initialize the aliases sorted by name.
680   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
681   initializeAliases(aliasToType, typeToAlias);
682 }
683 
visit(Attribute attr,bool canBeDeferred)684 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
685   if (!visitedAttributes.insert(attr).second) {
686     // If this attribute already has an alias and this instance can't be
687     // deferred, make sure that the alias isn't deferred.
688     if (!canBeDeferred)
689       deferrableAttributes.erase(attr);
690     return;
691   }
692 
693   // Try to generate an alias for this attribute.
694   if (succeeded(generateAlias(attr, aliasToAttr))) {
695     if (canBeDeferred)
696       deferrableAttributes.insert(attr);
697     return;
698   }
699 
700   // Check for any sub elements.
701   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
702     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
703                                         [&](Type type) { visit(type); });
704   }
705 }
706 
visit(Type type)707 void AliasInitializer::visit(Type type) {
708   if (!visitedTypes.insert(type).second)
709     return;
710 
711   // Try to generate an alias for this type.
712   if (succeeded(generateAlias(type, aliasToType)))
713     return;
714 
715   // Check for any sub elements.
716   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
717     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
718                                         [&](Type type) { visit(type); });
719   }
720 }
721 
/// Query every OpAsmDialectInterface in `interfaces` for an alias of `symbol`.
///
/// Each interface that does not answer `NoAlias` replaces the previously
/// collected candidate name; an interface answering `FinalAlias` ends the
/// search early. If a candidate was found, it is sanitized, copied into
/// `aliasAllocator`, and `symbol` is appended to `aliasToSymbol[name]` (several
/// symbols may end up under the same name at this point). Returns failure when
/// no interface provided an alias.
template <typename T>
LogicalResult AliasInitializer::generateAlias(
    T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
  SmallString<32> nameBuffer;
  for (const auto &interface : interfaces) {
    OpAsmDialectInterface::AliasResult result =
        interface.getAlias(symbol, aliasOS);
    if (result == OpAsmDialectInterface::AliasResult::NoAlias)
      continue;
    // Take the name the interface wrote. NOTE(review): this assumes `aliasOS`
    // streams into `aliasBuffer` and that moving out of the buffer leaves it
    // empty for the next query. Text written by an interface that then
    // answers NoAlias is not cleared here; confirm interfaces never do that.
    nameBuffer = std::move(aliasBuffer);
    assert(!nameBuffer.empty() && "expected valid alias name");
    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
      break;
  }

  if (nameBuffer.empty())
    return failure();

  // Sanitize with a restricted punctuation set ('$', '_', '-') and
  // `allowTrailingDigit=false`.
  SmallString<16> tempBuffer;
  StringRef name =
      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
                         /*allowTrailingDigit=*/false);
  // `name` may refer to the local buffers above, so copy it into storage that
  // outlives this call.
  name = name.copy(aliasAllocator);
  aliasToSymbol[name].push_back(symbol);
  return success();
}
748 
749 //===----------------------------------------------------------------------===//
750 // AliasState
751 //===----------------------------------------------------------------------===//
752 
753 namespace {
754 /// This class manages the state for type and attribute aliases.
755 class AliasState {
756 public:
757   // Initialize the internal aliases.
758   void
759   initialize(Operation *op, const OpPrintingFlags &printerFlags,
760              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
761 
762   /// Get an alias for the given attribute if it has one and print it in `os`.
763   /// Returns success if an alias was printed, failure otherwise.
764   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
765 
766   /// Get an alias for the given type if it has one and print it in `os`.
767   /// Returns success if an alias was printed, failure otherwise.
768   LogicalResult getAlias(Type ty, raw_ostream &os) const;
769 
770   /// Print all of the referenced aliases that can not be resolved in a deferred
771   /// manner.
printNonDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const772   void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
773     printAliases(os, newLine, /*isDeferred=*/false);
774   }
775 
776   /// Print all of the referenced aliases that support deferred resolution.
printDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const777   void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
778     printAliases(os, newLine, /*isDeferred=*/true);
779   }
780 
781 private:
782   /// Print all of the referenced aliases that support the provided resolution
783   /// behavior.
784   void printAliases(raw_ostream &os, NewLineCounter &newLine,
785                     bool isDeferred) const;
786 
787   /// Mapping between attribute and alias.
788   llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
789   /// Mapping between type and alias.
790   llvm::MapVector<Type, SymbolAlias> typeToAlias;
791 
792   /// An allocator used for alias names.
793   llvm::BumpPtrAllocator aliasAllocator;
794 };
795 } // namespace
796 
initialize(Operation * op,const OpPrintingFlags & printerFlags,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)797 void AliasState::initialize(
798     Operation *op, const OpPrintingFlags &printerFlags,
799     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
800   AliasInitializer initializer(interfaces, aliasAllocator);
801   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
802 }
803 
getAlias(Attribute attr,raw_ostream & os) const804 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
805   auto it = attrToAlias.find(attr);
806   if (it == attrToAlias.end())
807     return failure();
808   it->second.print(os << '#');
809   return success();
810 }
811 
getAlias(Type ty,raw_ostream & os) const812 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
813   auto it = typeToAlias.find(ty);
814   if (it == typeToAlias.end())
815     return failure();
816 
817   it->second.print(os << '!');
818   return success();
819 }
820 
printAliases(raw_ostream & os,NewLineCounter & newLine,bool isDeferred) const821 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
822                               bool isDeferred) const {
823   auto filterFn = [=](const auto &aliasIt) {
824     return aliasIt.second.canBeDeferred() == isDeferred;
825   };
826   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
827     it.second.print(os << '#');
828     os << " = " << it.first << newLine;
829   }
830   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
831     it.second.print(os << '!');
832     os << " = " << it.first << newLine;
833   }
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // SSANameState
838 //===----------------------------------------------------------------------===//
839 
namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print references to it, e.g. ^bb42.
/// An `ordering` of -1 means the block has not been ordered (or is unknown).
/// `name` includes the leading '^'. Field order matters: instances are
/// brace-initialized positionally as `{ordering, name}`.
struct BlockInfo {
  int ordering;
  StringRef name;
};

/// This class manages the state of SSA value names.
class SSANameState {
public:
  /// Value stored in `valueIDs` for values that have a string name in
  /// `valueNames` instead of a numeric ID.
  enum : unsigned { NameSentinel = ~0U };

  /// Number everything within `op` (including nested regions).
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also appends the result number ('#' number)
  /// of this value within its result group.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Print the operation identifier.
  void printOperationID(Operation *op, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the position of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name assigns the
  /// next numeric ID instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// When printing users of values, an operation without a result might
  /// be the user. This map holds ids for such operations.
  DenseMap<Operation *, unsigned> operationIDs;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the (sorted) result numbers that start a result
  /// group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Storage for the name strings referenced by `usedNames`, `valueNames` and
  /// `blockNames`.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
} // namespace
932 
/// Number all values, blocks and (when value users are printed) result-less
/// operations within `op`, including everything nested in its regions.
///
/// Instead of recursing, regions are processed from an explicit LIFO worklist,
/// which makes the traversal depth-first. Each entry records the counters and
/// the `usedNames` scope to resume from. All regions pushed together start from
/// the same snapshot. A nested region starts from the counters its enclosing
/// region ended with.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are put back to their initial values when the constructor
  // returns; they are only meaningful while numbering.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'd into it and
  // destroyed manually below; the allocator itself only reclaims the memory.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // The region contexts are captured before `op`'s own results are numbered
  // below, so numbering inside `op`'s regions starts from the same counter
  // values that `op`'s results are numbered from.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Restore the counters saved for this region.
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When switching to a region in a different subtree, destroy (and thereby
    // pop) the scopes that are no longer needed until `parentScope` is the
    // current scope again.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Nested regions continue from the counters reached after this region and
    // see its names through `curNamesScope`.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
991 
printValueID(Value value,bool printResultNo,raw_ostream & stream) const992 void SSANameState::printValueID(Value value, bool printResultNo,
993                                 raw_ostream &stream) const {
994   if (!value) {
995     stream << "<<NULL VALUE>>";
996     return;
997   }
998 
999   Optional<int> resultNo;
1000   auto lookupValue = value;
1001 
1002   // If this is an operation result, collect the head lookup value of the result
1003   // group and the result number of 'result' within that group.
1004   if (OpResult result = value.dyn_cast<OpResult>())
1005     getResultIDAndNumber(result, lookupValue, resultNo);
1006 
1007   auto it = valueIDs.find(lookupValue);
1008   if (it == valueIDs.end()) {
1009     stream << "<<UNKNOWN SSA VALUE>>";
1010     return;
1011   }
1012 
1013   stream << '%';
1014   if (it->second != NameSentinel) {
1015     stream << it->second;
1016   } else {
1017     auto nameIt = valueNames.find(lookupValue);
1018     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1019     stream << nameIt->second;
1020   }
1021 
1022   if (resultNo && printResultNo)
1023     stream << '#' << resultNo;
1024 }
1025 
printOperationID(Operation * op,raw_ostream & stream) const1026 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1027   auto it = operationIDs.find(op);
1028   if (it == operationIDs.end()) {
1029     stream << "<<UNKOWN OPERATION>>";
1030   } else {
1031     stream << '%' << it->second;
1032   }
1033 }
1034 
getOpResultGroups(Operation * op)1035 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1036   auto it = opResultGroups.find(op);
1037   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1038 }
1039 
getBlockInfo(Block * block)1040 BlockInfo SSANameState::getBlockInfo(Block *block) {
1041   auto it = blockNames.find(block);
1042   BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1043   return it != blockNames.end() ? it->second : invalidBlock;
1044 }
1045 
shadowRegionArgs(Region & region,ValueRange namesToUse)1046 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1047   assert(!region.empty() && "cannot shadow arguments of an empty region");
1048   assert(region.getNumArguments() == namesToUse.size() &&
1049          "incorrect number of names passed in");
1050   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1051          "only KnownIsolatedFromAbove ops can shadow names");
1052 
1053   SmallVector<char, 16> nameStr;
1054   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1055     auto nameToUse = namesToUse[i];
1056     if (nameToUse == nullptr)
1057       continue;
1058     auto nameToReplace = region.getArgument(i);
1059 
1060     nameStr.clear();
1061     llvm::raw_svector_ostream nameStream(nameStr);
1062     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1063 
1064     // Entry block arguments should already have a pretty "arg" name.
1065     assert(valueIDs[nameToReplace] == NameSentinel);
1066 
1067     // Use the name without the leading %.
1068     auto name = StringRef(nameStream.str()).drop_front();
1069 
1070     // Overwrite the name.
1071     valueNames[nameToReplace] = name.copy(usedNameAllocator);
1072   }
1073 }
1074 
numberValuesInRegion(Region & region)1075 void SSANameState::numberValuesInRegion(Region &region) {
1076   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1077     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1078     assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
1079            "arg not defined in current region");
1080     setValueName(arg, name);
1081   };
1082 
1083   if (!printerFlags.shouldPrintGenericOpForm()) {
1084     if (Operation *op = region.getParentOp()) {
1085       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1086         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1087     }
1088   }
1089 
1090   // Number the values within this region in a breadth-first order.
1091   unsigned nextBlockID = 0;
1092   for (auto &block : region) {
1093     // Each block gets a unique ID, and all of the operations within it get
1094     // numbered as well.
1095     auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1096     if (blockInfoIt.second) {
1097       // This block hasn't been named through `getAsmBlockArgumentNames`, use
1098       // default `^bbNNN` format.
1099       std::string name;
1100       llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1101       blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1102     }
1103     blockInfoIt.first->second.ordering = nextBlockID++;
1104 
1105     numberValuesInBlock(block);
1106   }
1107 }
1108 
numberValuesInBlock(Block & block)1109 void SSANameState::numberValuesInBlock(Block &block) {
1110   // Number the block arguments. We give entry block arguments a special name
1111   // 'arg'.
1112   bool isEntryBlock = block.isEntryBlock();
1113   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1114   llvm::raw_svector_ostream specialName(specialNameBuffer);
1115   for (auto arg : block.getArguments()) {
1116     if (valueIDs.count(arg))
1117       continue;
1118     if (isEntryBlock) {
1119       specialNameBuffer.resize(strlen("arg"));
1120       specialName << nextArgumentID++;
1121     }
1122     setValueName(arg, specialName.str());
1123   }
1124 
1125   // Number the operations in this block.
1126   for (auto &op : block)
1127     numberValuesInOp(op);
1128 }
1129 
numberValuesInOp(Operation & op)1130 void SSANameState::numberValuesInOp(Operation &op) {
1131   // Function used to set the special result names for the operation.
1132   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1133   auto setResultNameFn = [&](Value result, StringRef name) {
1134     assert(!valueIDs.count(result) && "result numbered multiple times");
1135     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1136     setValueName(result, name);
1137 
1138     // Record the result number for groups not anchored at 0.
1139     if (int resultNo = result.cast<OpResult>().getResultNumber())
1140       resultGroups.push_back(resultNo);
1141   };
1142   // Operations can customize the printing of block names in OpAsmOpInterface.
1143   auto setBlockNameFn = [&](Block *block, StringRef name) {
1144     assert(block->getParentOp() == &op &&
1145            "getAsmBlockArgumentNames callback invoked on a block not directly "
1146            "nested under the current operation");
1147     assert(!blockNames.count(block) && "block numbered multiple times");
1148     SmallString<16> tmpBuffer{"^"};
1149     name = sanitizeIdentifier(name, tmpBuffer);
1150     if (name.data() != tmpBuffer.data()) {
1151       tmpBuffer.append(name);
1152       name = tmpBuffer.str();
1153     }
1154     name = name.copy(usedNameAllocator);
1155     blockNames[block] = {-1, name};
1156   };
1157 
1158   if (!printerFlags.shouldPrintGenericOpForm()) {
1159     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1160       asmInterface.getAsmBlockNames(setBlockNameFn);
1161       asmInterface.getAsmResultNames(setResultNameFn);
1162     }
1163   }
1164 
1165   unsigned numResults = op.getNumResults();
1166   if (numResults == 0) {
1167     // If value users should be printed, operations with no result need an id.
1168     if (printerFlags.shouldPrintValueUsers()) {
1169       if (operationIDs.try_emplace(&op, nextValueID).second)
1170         ++nextValueID;
1171     }
1172     return;
1173   }
1174   Value resultBegin = op.getResult(0);
1175 
1176   // If the first result wasn't numbered, give it a default number.
1177   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1178     ++nextValueID;
1179 
1180   // If this operation has multiple result groups, mark it.
1181   if (resultGroups.size() != 1) {
1182     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1183     opResultGroups.try_emplace(&op, std::move(resultGroups));
1184   }
1185 }
1186 
/// Find the value whose ID is used to print `result`, and `result`'s position
/// within that value's result group.
///
/// On return `lookupValue` is the first result of the group containing
/// `result`. For single-result ops it is left untouched (printValueID
/// initializes it to the value being printed). `lookupResultNo` is set to the
/// position within the group only when that group has more than one result.
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
                                        Optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  // A lone result is its own group; nothing to adjust.
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // If this operation has multiple result groups, we will need to find the
  // one corresponding to this result.
  auto resultGroupIt = opResultGroups.find(owner);
  if (resultGroupIt == opResultGroups.end()) {
    // If not, all results form one group headed by result 0.
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // Find the correct index using a binary search, as the groups are ordered.
  // `it` is the first group start greater than `resultNo`. The group list
  // always begins with 0 (see numberValuesInOp), so `it` is never `begin()`.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  const auto *it = llvm::upper_bound(resultGroups, resultNo);
  int groupResultNo = 0, groupSize = 0;

  // If no group starts after `resultNo`, the result is in the last group,
  // which extends to the final result.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    // Otherwise the result is in the group starting at the previous element,
    // which ends where the group at `it` begins.
    groupResultNo = *std::prev(it);
    groupSize = *it - groupResultNo;
  }

  // We only record the result number for a group of size greater than 1.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(groupResultNo);
}
1224 
setValueName(Value value,StringRef name)1225 void SSANameState::setValueName(Value value, StringRef name) {
1226   // If the name is empty, the value uses the default numbering.
1227   if (name.empty()) {
1228     valueIDs[value] = nextValueID++;
1229     return;
1230   }
1231 
1232   valueIDs[value] = NameSentinel;
1233   valueNames[value] = uniqueValueName(name);
1234 }
1235 
uniqueValueName(StringRef name)1236 StringRef SSANameState::uniqueValueName(StringRef name) {
1237   SmallString<16> tmpBuffer;
1238   name = sanitizeIdentifier(name, tmpBuffer);
1239 
1240   // Check to see if this name is already unique.
1241   if (!usedNames.count(name)) {
1242     name = name.copy(usedNameAllocator);
1243   } else {
1244     // Otherwise, we had a conflict - probe until we find a unique name. This
1245     // is guaranteed to terminate (and usually in a single iteration) because it
1246     // generates new names by incrementing nextConflictID.
1247     SmallString<64> probeName(name);
1248     probeName.push_back('_');
1249     while (true) {
1250       probeName += llvm::utostr(nextConflictID++);
1251       if (!usedNames.count(probeName)) {
1252         name = probeName.str().copy(usedNameAllocator);
1253         break;
1254       }
1255       probeName.resize(name.size() + 1);
1256     }
1257   }
1258 
1259   usedNames.insert(name, char());
1260   return name;
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // Resources
1265 //===----------------------------------------------------------------------===//
1266 
// Out-of-line (defaulted) destructors for the resource entry/builder/parser/
// printer classes. NOTE(review): presumably defined here to anchor their
// vtables in this translation unit; confirm against their declarations.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
1271 
1272 //===----------------------------------------------------------------------===//
1273 // AsmState
1274 //===----------------------------------------------------------------------===//
1275 
namespace mlir {
namespace detail {
/// Backing state for AsmState. It bundles the context's OpAsm dialect
/// interfaces, the alias and SSA-name state, the printing flags, any
/// externally attached resource printers, and an optional location map to
/// populate.
class AsmStateImpl {
public:
  /// SSA names for `op` are computed eagerly here; aliases are only computed
  /// when initializeAliases() is called. `nameState` is built from the
  /// `printerFlags` argument, which shadows the member of the same name.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Return the dialects within the context that implement
  /// OpAsmDialectInterface.
  DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
    return interfaces;
  }

  /// Return the non-dialect resource printers, as a range that dereferences
  /// the owning pointers.
  auto getResourcePrinters() {
    return llvm::make_pointee_range(externalResourcePrinters);
  }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. Does nothing when no location map was provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// A collection of non-dialect resource printers (appended to by
  /// AsmState::attachResourcePrinter).
  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated; may be null.
  AsmState::LocationMap *locationMap;

  // Allow direct access to the impl fields.
  friend AsmState;
};
} // namespace detail
} // namespace mlir
1341 
1342 /// Verifies the operation and switches to generic op printing if verification
1343 /// fails. We need to do this because custom print functions may fail for
1344 /// invalid ops.
verifyOpAndAdjustFlags(Operation * op,OpPrintingFlags printerFlags)1345 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1346                                               OpPrintingFlags printerFlags) {
1347   if (printerFlags.shouldPrintGenericOpForm() ||
1348       printerFlags.shouldAssumeVerified())
1349     return printerFlags;
1350 
1351   LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
1352                           << op->getName() << "\n");
1353 
1354   // Ignore errors emitted by the verifier. We check the thread id to avoid
1355   // consuming other threads' errors.
1356   auto parentThreadId = llvm::get_threadid();
1357   ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1358     if (parentThreadId == llvm::get_threadid()) {
1359       LLVM_DEBUG({
1360         diag.print(llvm::dbgs());
1361         llvm::dbgs() << "\n";
1362       });
1363       return success();
1364     }
1365     return failure();
1366   });
1367   if (failed(verify(op))) {
1368     LLVM_DEBUG(llvm::dbgs()
1369                << DEBUG_TYPE << ": '" << op->getName()
1370                << "' failed to verify and will be printed in generic form\n");
1371     printerFlags.printGenericOpForm();
1372   }
1373 
1374   return printerFlags;
1375 }
1376 
/// Build the printing state for `op`. The flags are first passed through
/// verifyOpAndAdjustFlags, which may switch to generic printing when `op`
/// fails to verify.
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap)
    : impl(std::make_unique<AsmStateImpl>(
          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
/// Defined here, where AsmStateImpl is a complete type. NOTE(review):
/// presumably required because `impl` owns it via a unique_ptr.
AsmState::~AsmState() = default;
1382 
/// Return the printing flags in effect. These are the adjusted flags
/// (possibly switched to the generic form), not necessarily the ones passed to
/// the constructor.
const OpPrintingFlags &AsmState::getPrinterFlags() const {
  return impl->getPrinterFlags();
}
1386 
attachResourcePrinter(std::unique_ptr<AsmResourcePrinter> printer)1387 void AsmState::attachResourcePrinter(
1388     std::unique_ptr<AsmResourcePrinter> printer) {
1389   impl->externalResourcePrinters.emplace_back(std::move(printer));
1390 }
1391 
1392 //===----------------------------------------------------------------------===//
1393 // AsmPrinter::Impl
1394 //===----------------------------------------------------------------------===//
1395 
1396 namespace mlir {
1397 class AsmPrinter::Impl {
1398 public:
Impl(raw_ostream & os,OpPrintingFlags flags=llvm::None,AsmStateImpl * state=nullptr)1399   Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
1400        AsmStateImpl *state = nullptr)
1401       : os(os), printerFlags(flags), state(state) {}
Impl(Impl & other)1402   explicit Impl(Impl &other)
1403       : Impl(other.os, other.printerFlags, other.state) {}
1404 
1405   /// Returns the output stream of the printer.
getStream()1406   raw_ostream &getStream() { return os; }
1407 
1408   template <typename Container, typename UnaryFunctor>
interleaveComma(const Container & c,UnaryFunctor eachFn) const1409   inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
1410     llvm::interleaveComma(c, os, eachFn);
1411   }
1412 
1413   /// This enum describes the different kinds of elision for the type of an
1414   /// attribute when printing it.
1415   enum class AttrTypeElision {
1416     /// The type must not be elided,
1417     Never,
1418     /// The type may be elided when it matches the default used in the parser
1419     /// (for example i64 is the default for integer attributes).
1420     May,
1421     /// The type must be elided.
1422     Must
1423   };
1424 
1425   /// Print the given attribute.
1426   void printAttribute(Attribute attr,
1427                       AttrTypeElision typeElision = AttrTypeElision::Never);
1428 
1429   /// Print the alias for the given attribute, return failure if no alias could
1430   /// be printed.
1431   LogicalResult printAlias(Attribute attr);
1432 
1433   void printType(Type type);
1434 
1435   /// Print the alias for the given type, return failure if no alias could
1436   /// be printed.
1437   LogicalResult printAlias(Type type);
1438 
1439   /// Print the given location to the stream. If `allowAlias` is true, this
1440   /// allows for the internal location to use an attribute alias.
1441   void printLocation(LocationAttr loc, bool allowAlias = false);
1442 
1443   /// Print a reference to the given resource that is owned by the given
1444   /// dialect.
printResourceHandle(const AsmDialectResourceHandle & resource)1445   void printResourceHandle(const AsmDialectResourceHandle &resource) {
1446     auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
1447     os << interface->getResourceKey(resource);
1448     dialectResources[resource.getDialect()].insert(resource);
1449   }
1450 
1451   void printAffineMap(AffineMap map);
1452   void
1453   printAffineExpr(AffineExpr expr,
1454                   function_ref<void(unsigned, bool)> printValueName = nullptr);
1455   void printAffineConstraint(AffineExpr expr, bool isEq);
1456   void printIntegerSet(IntegerSet set);
1457 
1458 protected:
1459   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1460                              ArrayRef<StringRef> elidedAttrs = {},
1461                              bool withKeyword = false);
1462   void printNamedAttribute(NamedAttribute attr);
1463   void printTrailingLocation(Location loc, bool allowAlias = true);
1464   void printLocationInternal(LocationAttr loc, bool pretty = false);
1465 
1466   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1467   /// used instead of individual elements when the elements attr is large.
1468   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
1469 
1470   /// Print a dense string elements attribute.
1471   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
1472 
1473   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1474   /// used instead of individual elements when the elements attr is large.
1475   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1476                                      bool allowHex);
1477 
1478   void printDialectAttribute(Attribute attr);
1479   void printDialectType(Type type);
1480 
1481   /// Print an escaped string, wrapped with "".
1482   void printEscapedString(StringRef str);
1483 
1484   /// Print a hex string, wrapped with "".
1485   void printHexString(StringRef str);
1486   void printHexString(ArrayRef<char> data);
1487 
1488   /// This enum is used to represent the binding strength of the enclosing
1489   /// context that an AffineExprStorage is being printed in, so we can
1490   /// intelligently produce parens.
1491   enum class BindingStrength {
1492     Weak,   // + and -
1493     Strong, // All other binary operators.
1494   };
1495   void printAffineExprInternal(
1496       AffineExpr expr, BindingStrength enclosingTightness,
1497       function_ref<void(unsigned, bool)> printValueName = nullptr);
1498 
1499   /// The output stream for the printer.
1500   raw_ostream &os;
1501 
1502   /// A set of flags to control the printer's behavior.
1503   OpPrintingFlags printerFlags;
1504 
1505   /// An optional printer state for the module.
1506   AsmStateImpl *state;
1507 
1508   /// A tracker for the number of new lines emitted during printing.
1509   NewLineCounter newLine;
1510 
1511   /// A set of dialect resources that were referenced during printing.
1512   DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
1513 };
1514 } // namespace mlir
1515 
printTrailingLocation(Location loc,bool allowAlias)1516 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1517   // Check to see if we are printing debug information.
1518   if (!printerFlags.shouldPrintDebugInfo())
1519     return;
1520 
1521   os << " ";
1522   printLocation(loc, /*allowAlias=*/allowAlias);
1523 }
1524 
printLocationInternal(LocationAttr loc,bool pretty)1525 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1526   TypeSwitch<LocationAttr>(loc)
1527       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1528         printLocationInternal(loc.getFallbackLocation(), pretty);
1529       })
1530       .Case<UnknownLoc>([&](UnknownLoc loc) {
1531         if (pretty)
1532           os << "[unknown]";
1533         else
1534           os << "unknown";
1535       })
1536       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1537         if (pretty)
1538           os << loc.getFilename().getValue();
1539         else
1540           printEscapedString(loc.getFilename());
1541         os << ':' << loc.getLine() << ':' << loc.getColumn();
1542       })
1543       .Case<NameLoc>([&](NameLoc loc) {
1544         printEscapedString(loc.getName());
1545 
1546         // Print the child if it isn't unknown.
1547         auto childLoc = loc.getChildLoc();
1548         if (!childLoc.isa<UnknownLoc>()) {
1549           os << '(';
1550           printLocationInternal(childLoc, pretty);
1551           os << ')';
1552         }
1553       })
1554       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1555         Location caller = loc.getCaller();
1556         Location callee = loc.getCallee();
1557         if (!pretty)
1558           os << "callsite(";
1559         printLocationInternal(callee, pretty);
1560         if (pretty) {
1561           if (callee.isa<NameLoc>()) {
1562             if (caller.isa<FileLineColLoc>()) {
1563               os << " at ";
1564             } else {
1565               os << newLine << " at ";
1566             }
1567           } else {
1568             os << newLine << " at ";
1569           }
1570         } else {
1571           os << " at ";
1572         }
1573         printLocationInternal(caller, pretty);
1574         if (!pretty)
1575           os << ")";
1576       })
1577       .Case<FusedLoc>([&](FusedLoc loc) {
1578         if (!pretty)
1579           os << "fused";
1580         if (Attribute metadata = loc.getMetadata())
1581           os << '<' << metadata << '>';
1582         os << '[';
1583         interleave(
1584             loc.getLocations(),
1585             [&](Location loc) { printLocationInternal(loc, pretty); },
1586             [&]() { os << ", "; });
1587         os << ']';
1588       });
1589 }
1590 
1591 /// Print a floating point value in a way that the parser will be able to
1592 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)1593 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1594   // We would like to output the FP constant value in exponential notation,
1595   // but we cannot do this if doing so will lose precision.  Check here to
1596   // make sure that we only output it in exponential format if we can parse
1597   // the value back and get the same value.
1598   bool isInf = apValue.isInfinity();
1599   bool isNaN = apValue.isNaN();
1600   if (!isInf && !isNaN) {
1601     SmallString<128> strValue;
1602     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1603                      /*TruncateZero=*/false);
1604 
1605     // Check to make sure that the stringized number is not some string like
1606     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1607     // that the string matches the "[-+]?[0-9]" regex.
1608     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1609             ((strValue[0] == '-' || strValue[0] == '+') &&
1610              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1611            "[-+]?[0-9] regex does not match!");
1612 
1613     // Parse back the stringized version and check that the value is equal
1614     // (i.e., there is no precision loss).
1615     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1616       os << strValue;
1617       return;
1618     }
1619 
1620     // If it is not, use the default format of APFloat instead of the
1621     // exponential notation.
1622     strValue.clear();
1623     apValue.toString(strValue);
1624 
1625     // Make sure that we can parse the default form as a float.
1626     if (strValue.str().contains('.')) {
1627       os << strValue;
1628       return;
1629     }
1630   }
1631 
1632   // Print special values in hexadecimal format. The sign bit should be included
1633   // in the literal.
1634   SmallVector<char, 16> str;
1635   APInt apInt = apValue.bitcastToAPInt();
1636   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1637                  /*formatAsCLiteral=*/true);
1638   os << str;
1639 }
1640 
printLocation(LocationAttr loc,bool allowAlias)1641 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1642   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1643     return printLocationInternal(loc, /*pretty=*/true);
1644 
1645   os << "loc(";
1646   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1647     printLocationInternal(loc);
1648   os << ')';
1649 }
1650 
1651 /// Returns true if the given dialect symbol data is simple enough to print in
1652 /// the pretty form. This is essentially when the symbol takes the form:
1653 ///   identifier (`<` body `>`)?
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1654 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1655   // The name must start with an identifier.
1656   if (symName.empty() || !isalpha(symName.front()))
1657     return false;
1658 
1659   // Ignore all the characters that are valid in an identifier in the symbol
1660   // name.
1661   symName = symName.drop_while(
1662       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1663   if (symName.empty())
1664     return true;
1665 
1666   // If we got to an unexpected character, then it must be a <>. Check that the
1667   // rest of the symbol is wrapped within <>.
1668   return symName.front() == '<' && symName.back() == '>';
1669 }
1670 
1671 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1672 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1673                                StringRef dialectName, StringRef symString) {
1674   os << symPrefix << dialectName;
1675 
1676   // If this symbol name is simple enough, print it directly in pretty form,
1677   // otherwise, we print it as an escaped string.
1678   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1679     os << '.' << symString;
1680     return;
1681   }
1682 
1683   os << '<' << symString << '>';
1684 }
1685 
1686 /// Returns true if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1687 static bool isBareIdentifier(StringRef name) {
1688   // By making this unsigned, the value passed in to isalnum will always be
1689   // in the range 0-255. This is important when building with MSVC because
1690   // its implementation will assert. This situation can arise when dealing
1691   // with UTF-8 multibyte characters.
1692   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1693     return false;
1694   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1695     return isalnum(c) || c == '_' || c == '$' || c == '.';
1696   });
1697 }
1698 
1699 /// Print the given string as a keyword, or a quoted and escaped string if it
1700 /// has any special or non-printable characters in it.
printKeywordOrString(StringRef keyword,raw_ostream & os)1701 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1702   // If it can be represented as a bare identifier, write it directly.
1703   if (isBareIdentifier(keyword)) {
1704     os << keyword;
1705     return;
1706   }
1707 
1708   // Otherwise, output the keyword wrapped in quotes with proper escaping.
1709   os << "\"";
1710   printEscapedString(keyword, os);
1711   os << '"';
1712 }
1713 
1714 /// Print the given string as a symbol reference. A symbol reference is
1715 /// represented as a string prefixed with '@'. The reference is surrounded with
1716 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1717 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1718   assert(!symbolRef.empty() && "expected valid symbol reference");
1719   os << '@';
1720   printKeywordOrString(symbolRef, os);
1721 }
1722 
1723 // Print out a valid ElementsAttr that is succinct and can represent any
1724 // potential shape/type, for use when eliding a large ElementsAttr.
1725 //
1726 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1727 // hopefully alert readers to the fact that this has been elided.
1728 //
1729 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1730 // accept the string "elided". The first string must be a registered dialect
1731 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1732 static void printElidedElementsAttr(raw_ostream &os) {
1733   os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
1734 }
1735 
printAlias(Attribute attr)1736 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
1737   return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
1738 }
1739 
printAlias(Type type)1740 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
1741   return success(state && succeeded(state->getAliasState().getAlias(type, os)));
1742 }
1743 
/// Print `attr` to the stream.
///
/// A null attribute prints as `<<NULL ATTRIBUTE>>`. Otherwise the attribute's
/// alias is printed when one is available; failing that, its own syntax is
/// printed (delegated to the owning dialect for non-builtin attributes).
///
/// Several branches `return` early because their syntax never carries a type.
/// All other branches fall through to the `: type` suffix at the end. That
/// suffix is printed unless `typeElision` is `Must` or the type is `NoneType`.
/// With `typeElision == May`, signless i64 integers and f64 floats also drop
/// their type.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  auto attrType = attr.getType();
  if (!isa<BuiltinDialect>(attr.getDialect())) {
    // Non-builtin attributes are printed by their dialect. Control still
    // reaches the trailing type printing below.
    printDialectAttribute(attr);
  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elides the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE(review): signless i1 values already returned above, so the
    // `isSignlessInteger(1)` clause below can never be true here.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    printEscapedString(strAttr.getValue());

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Array elements may elide their default types.
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Root reference followed by any `::`-separated nested references.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", ";
      printHexString(opaqueAttr.getValue());
      os << ">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide the whole attribute if either the indices or the values are large.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, the body is left empty (`sparse<>`).
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }
  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
    // The element type is spelled inside the literal (`[:i32 ...]`), so the
    // trailing type is always suppressed.
    typeElision = AttrTypeElision::Must;
    switch (denseArrayAttr.getElementType()) {
    case DenseArrayBaseAttr::EltType::I8:
      os << "[:i8";
      break;
    case DenseArrayBaseAttr::EltType::I16:
      os << "[:i16";
      break;
    case DenseArrayBaseAttr::EltType::I32:
      os << "[:i32";
      break;
    case DenseArrayBaseAttr::EltType::I64:
      os << "[:i64";
      break;
    case DenseArrayBaseAttr::EltType::F32:
      os << "[:f32";
      break;
    case DenseArrayBaseAttr::EltType::F64:
      os << "[:f64";
      break;
    }
    // NOTE(review): the separating space is emitted whenever the array's
    // shaped type has non-zero rank. Confirm this is the intended condition
    // rather than "the array is non-empty".
    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
      os << " ";
    denseArrayAttr.printWithoutBraces(os);
    os << "]";
  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    // Printed with the default `allowAlias = false`.
    printLocation(locAttr);
  } else {
    llvm::report_fatal_error("Unknown builtin attribute");
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}
1910 
1911 /// Print the integer element of a DenseElementsAttr.
printDenseIntElement(const APInt & value,raw_ostream & os,bool isSigned)1912 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1913                                  bool isSigned) {
1914   if (value.getBitWidth() == 1)
1915     os << (value.getBoolValue() ? "true" : "false");
1916   else
1917     value.print(os, isSigned);
1918 }
1919 
1920 static void
printDenseElementsAttrImpl(bool isSplat,ShapedType type,raw_ostream & os,function_ref<void (unsigned)> printEltFn)1921 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1922                            function_ref<void(unsigned)> printEltFn) {
1923   // Special case for 0-d and splat tensors.
1924   if (isSplat)
1925     return printEltFn(0);
1926 
1927   // Special case for degenerate tensors.
1928   auto numElements = type.getNumElements();
1929   if (numElements == 0)
1930     return;
1931 
1932   // We use a mixed-radix counter to iterate through the shape. When we bump a
1933   // non-least-significant digit, we emit a close bracket. When we next emit an
1934   // element we re-open all closed brackets.
1935 
1936   // The mixed-radix counter, with radices in 'shape'.
1937   int64_t rank = type.getRank();
1938   SmallVector<unsigned, 4> counter(rank, 0);
1939   // The number of brackets that have been opened and not closed.
1940   unsigned openBrackets = 0;
1941 
1942   auto shape = type.getShape();
1943   auto bumpCounter = [&] {
1944     // Bump the least significant digit.
1945     ++counter[rank - 1];
1946     // Iterate backwards bubbling back the increment.
1947     for (unsigned i = rank - 1; i > 0; --i)
1948       if (counter[i] >= shape[i]) {
1949         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1950         counter[i] = 0;
1951         ++counter[i - 1];
1952         --openBrackets;
1953         os << ']';
1954       }
1955   };
1956 
1957   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1958     if (idx != 0)
1959       os << ", ";
1960     while (openBrackets++ < rank)
1961       os << '[';
1962     openBrackets = rank;
1963     printEltFn(idx);
1964     bumpCounter();
1965   }
1966   while (openBrackets-- > 0)
1967     os << ']';
1968 }
1969 
printDenseElementsAttr(DenseElementsAttr attr,bool allowHex)1970 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1971                                               bool allowHex) {
1972   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1973     return printDenseStringElementsAttr(stringAttr);
1974 
1975   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1976                                 allowHex);
1977 }
1978 
/// Print the elements of an integer/floating-point dense elements attribute,
/// without the surrounding `dense<...>`.
///
/// When `allowHex` is set, the attribute is not a splat, and
/// `shouldPrintElementsAttrWithHex` accepts the element count, the raw data
/// is printed as a single hex string, always in little-endian byte order.
/// Otherwise elements are printed individually: complex values as
/// `(real,imag)`, integers/indices via printDenseIntElement, and floats via
/// printFloatValue.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness in big-endian(BE) machines. `rawData` is BE in BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(convRawData);
    } else {
      printHexString(rawData);
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced. Keep the two lambdas separate.
    if (complexElementType.isa<IntegerType>()) {
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Only explicitly unsigned integers are printed as unsigned.
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
2046 
printDenseStringElementsAttr(DenseStringElementsAttr attr)2047 void AsmPrinter::Impl::printDenseStringElementsAttr(
2048     DenseStringElementsAttr attr) {
2049   ArrayRef<StringRef> data = attr.getRawStringData();
2050   auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
2051   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2052 }
2053 
printType(Type type)2054 void AsmPrinter::Impl::printType(Type type) {
2055   if (!type) {
2056     os << "<<NULL TYPE>>";
2057     return;
2058   }
2059 
2060   // Try to print an alias for this type.
2061   if (state && succeeded(state->getAliasState().getAlias(type, os)))
2062     return;
2063 
2064   TypeSwitch<Type>(type)
2065       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
2066         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
2067                            opaqueTy.getTypeData());
2068       })
2069       .Case<IndexType>([&](Type) { os << "index"; })
2070       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
2071       .Case<Float16Type>([&](Type) { os << "f16"; })
2072       .Case<Float32Type>([&](Type) { os << "f32"; })
2073       .Case<Float64Type>([&](Type) { os << "f64"; })
2074       .Case<Float80Type>([&](Type) { os << "f80"; })
2075       .Case<Float128Type>([&](Type) { os << "f128"; })
2076       .Case<IntegerType>([&](IntegerType integerTy) {
2077         if (integerTy.isSigned())
2078           os << 's';
2079         else if (integerTy.isUnsigned())
2080           os << 'u';
2081         os << 'i' << integerTy.getWidth();
2082       })
2083       .Case<FunctionType>([&](FunctionType funcTy) {
2084         os << '(';
2085         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
2086         os << ") -> ";
2087         ArrayRef<Type> results = funcTy.getResults();
2088         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
2089           printType(results[0]);
2090         } else {
2091           os << '(';
2092           interleaveComma(results, [&](Type ty) { printType(ty); });
2093           os << ')';
2094         }
2095       })
2096       .Case<VectorType>([&](VectorType vectorTy) {
2097         os << "vector<";
2098         auto vShape = vectorTy.getShape();
2099         unsigned lastDim = vShape.size();
2100         unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
2101         unsigned dimIdx = 0;
2102         for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
2103           os << vShape[dimIdx] << 'x';
2104         if (vectorTy.isScalable()) {
2105           os << '[';
2106           unsigned secondToLastDim = lastDim - 1;
2107           for (; dimIdx < secondToLastDim; dimIdx++)
2108             os << vShape[dimIdx] << 'x';
2109           os << vShape[dimIdx] << "]x";
2110         }
2111         printType(vectorTy.getElementType());
2112         os << '>';
2113       })
2114       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
2115         os << "tensor<";
2116         for (int64_t dim : tensorTy.getShape()) {
2117           if (ShapedType::isDynamic(dim))
2118             os << '?';
2119           else
2120             os << dim;
2121           os << 'x';
2122         }
2123         printType(tensorTy.getElementType());
2124         // Only print the encoding attribute value if set.
2125         if (tensorTy.getEncoding()) {
2126           os << ", ";
2127           printAttribute(tensorTy.getEncoding());
2128         }
2129         os << '>';
2130       })
2131       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
2132         os << "tensor<*x";
2133         printType(tensorTy.getElementType());
2134         os << '>';
2135       })
2136       .Case<MemRefType>([&](MemRefType memrefTy) {
2137         os << "memref<";
2138         for (int64_t dim : memrefTy.getShape()) {
2139           if (ShapedType::isDynamic(dim))
2140             os << '?';
2141           else
2142             os << dim;
2143           os << 'x';
2144         }
2145         printType(memrefTy.getElementType());
2146         if (!memrefTy.getLayout().isIdentity()) {
2147           os << ", ";
2148           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2149         }
2150         // Only print the memory space if it is the non-default one.
2151         if (memrefTy.getMemorySpace()) {
2152           os << ", ";
2153           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2154         }
2155         os << '>';
2156       })
2157       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2158         os << "memref<*x";
2159         printType(memrefTy.getElementType());
2160         // Only print the memory space if it is the non-default one.
2161         if (memrefTy.getMemorySpace()) {
2162           os << ", ";
2163           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2164         }
2165         os << '>';
2166       })
2167       .Case<ComplexType>([&](ComplexType complexTy) {
2168         os << "complex<";
2169         printType(complexTy.getElementType());
2170         os << '>';
2171       })
2172       .Case<TupleType>([&](TupleType tupleTy) {
2173         os << "tuple<";
2174         interleaveComma(tupleTy.getTypes(),
2175                         [&](Type type) { printType(type); });
2176         os << '>';
2177       })
2178       .Case<NoneType>([&](Type) { os << "none"; })
2179       .Default([&](Type type) { return printDialectType(type); });
2180 }
2181 
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)2182 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2183                                              ArrayRef<StringRef> elidedAttrs,
2184                                              bool withKeyword) {
2185   // If there are no attributes, then there is nothing to be done.
2186   if (attrs.empty())
2187     return;
2188 
2189   // Functor used to print a filtered attribute list.
2190   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2191     // Print the 'attributes' keyword if necessary.
2192     if (withKeyword)
2193       os << " attributes";
2194 
2195     // Otherwise, print them all out in braces.
2196     os << " {";
2197     interleaveComma(filteredAttrs,
2198                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2199     os << '}';
2200   };
2201 
2202   // If no attributes are elided, we can directly print with no filtering.
2203   if (elidedAttrs.empty())
2204     return printFilteredAttributesFn(attrs);
2205 
2206   // Otherwise, filter out any attributes that shouldn't be included.
2207   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2208                                                 elidedAttrs.end());
2209   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2210     return !elidedAttrsSet.contains(attr.getName().strref());
2211   });
2212   if (!filteredAttrs.empty())
2213     printFilteredAttributesFn(filteredAttrs);
2214 }
2215 
printNamedAttribute(NamedAttribute attr)2216 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2217   // Print the name without quotes if possible.
2218   ::printKeywordOrString(attr.getName().strref(), os);
2219 
2220   // Pretty printing elides the attribute value for unit attributes.
2221   if (attr.getValue().isa<UnitAttr>())
2222     return;
2223 
2224   os << " = ";
2225   printAttribute(attr.getValue());
2226 }
2227 
printDialectAttribute(Attribute attr)2228 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2229   auto &dialect = attr.getDialect();
2230 
2231   // Ask the dialect to serialize the attribute to a string.
2232   std::string attrName;
2233   {
2234     llvm::raw_string_ostream attrNameStr(attrName);
2235     Impl subPrinter(attrNameStr, printerFlags, state);
2236     DialectAsmPrinter printer(subPrinter);
2237     dialect.printAttribute(attr, printer);
2238 
2239     // FIXME: Delete this when we no longer require a nested printer.
2240     for (auto &it : subPrinter.dialectResources)
2241       for (const auto &resource : it.second)
2242         dialectResources[it.first].insert(resource);
2243   }
2244   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2245 }
2246 
printDialectType(Type type)2247 void AsmPrinter::Impl::printDialectType(Type type) {
2248   auto &dialect = type.getDialect();
2249 
2250   // Ask the dialect to serialize the type to a string.
2251   std::string typeName;
2252   {
2253     llvm::raw_string_ostream typeNameStr(typeName);
2254     Impl subPrinter(typeNameStr, printerFlags, state);
2255     DialectAsmPrinter printer(subPrinter);
2256     dialect.printType(type, printer);
2257 
2258     // FIXME: Delete this when we no longer require a nested printer.
2259     for (auto &it : subPrinter.dialectResources)
2260       for (const auto &resource : it.second)
2261         dialectResources[it.first].insert(resource);
2262   }
2263   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2264 }
2265 
printEscapedString(StringRef str)2266 void AsmPrinter::Impl::printEscapedString(StringRef str) {
2267   os << "\"";
2268   llvm::printEscapedString(str, os);
2269   os << "\"";
2270 }
2271 
printHexString(StringRef str)2272 void AsmPrinter::Impl::printHexString(StringRef str) {
2273   os << "\"0x" << llvm::toHex(str) << "\"";
2274 }
printHexString(ArrayRef<char> data)2275 void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2276   printHexString(StringRef(data.data(), data.size()));
2277 }
2278 
2279 //===--------------------------------------------------------------------===//
2280 // AsmPrinter
2281 //===--------------------------------------------------------------------===//
2282 
/// Out-of-line defaulted destructor.
AsmPrinter::~AsmPrinter() = default;
2284 
getStream() const2285 raw_ostream &AsmPrinter::getStream() const {
2286   assert(impl && "expected AsmPrinter::getStream to be overriden");
2287   return impl->getStream();
2288 }
2289 
2290 /// Print the given floating point value in a stablized form.
printFloat(const APFloat & value)2291 void AsmPrinter::printFloat(const APFloat &value) {
2292   assert(impl && "expected AsmPrinter::printFloat to be overriden");
2293   printFloatValue(value, impl->getStream());
2294 }
2295 
printType(Type type)2296 void AsmPrinter::printType(Type type) {
2297   assert(impl && "expected AsmPrinter::printType to be overriden");
2298   impl->printType(type);
2299 }
2300 
printAttribute(Attribute attr)2301 void AsmPrinter::printAttribute(Attribute attr) {
2302   assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2303   impl->printAttribute(attr);
2304 }
2305 
printAlias(Attribute attr)2306 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2307   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2308   return impl->printAlias(attr);
2309 }
2310 
printAlias(Type type)2311 LogicalResult AsmPrinter::printAlias(Type type) {
2312   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2313   return impl->printAlias(type);
2314 }
2315 
printAttributeWithoutType(Attribute attr)2316 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2317   assert(impl &&
2318          "expected AsmPrinter::printAttributeWithoutType to be overriden");
2319   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2320 }
2321 
printKeywordOrString(StringRef keyword)2322 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2323   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2324   ::printKeywordOrString(keyword, impl->getStream());
2325 }
2326 
printSymbolName(StringRef symbolRef)2327 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2328   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2329   ::printSymbolReference(symbolRef, impl->getStream());
2330 }
2331 
printResourceHandle(const AsmDialectResourceHandle & resource)2332 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2333   assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2334   impl->printResourceHandle(resource);
2335 }
2336 
2337 //===----------------------------------------------------------------------===//
2338 // Affine expressions and maps
2339 //===----------------------------------------------------------------------===//
2340 
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)2341 void AsmPrinter::Impl::printAffineExpr(
2342     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2343   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2344 }
2345 
/// Recursively print `expr`. `enclosingTightness` is the binding strength of
/// the context being printed into: binary expressions are wrapped in
/// parentheses when it is BindingStrength::Strong. If `printValueName` is
/// provided it is called to print dimension and symbol identifiers (with
/// `isSymbol` telling them apart); otherwise they print as `d<pos>`/`s<pos>`.
void AsmPrinter::Impl::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  // Operator spelling for the binary kinds; leaf kinds return early below.
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  // Every kind that reaches this point is a binary operation.
  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators (everything except add). Their
  // operands are printed with strong binding, so nested binary operands get
  // parenthesized.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1 as a negation: `-lhs`.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        // `lhs + x * -1` prints as `lhs - x`. If `x` is itself an add, it is
        // parenthesized so the subtraction applies to the whole sum.
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        // `lhs + x * c` with c < -1 prints as `lhs - x * (-c)`.
        // NOTE(review): if getValue() is a signed 64-bit integer, `-c`
        // overflows when c is the minimum value — TODO confirm/handle.
        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction:
  // `lhs + -c` prints as `lhs - c`.
  // NOTE(review): same minimum-value negation concern as above.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition: both operands bind weakly since `+` is the loosest
  // operator.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2478 
printAffineConstraint(AffineExpr expr,bool isEq)2479 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2480   printAffineExprInternal(expr, BindingStrength::Weak);
2481   isEq ? os << " == 0" : os << " >= 0";
2482 }
2483 
printAffineMap(AffineMap map)2484 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2485   // Dimension identifiers.
2486   os << '(';
2487   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2488     os << 'd' << i << ", ";
2489   if (map.getNumDims() >= 1)
2490     os << 'd' << map.getNumDims() - 1;
2491   os << ')';
2492 
2493   // Symbolic identifiers.
2494   if (map.getNumSymbols() != 0) {
2495     os << '[';
2496     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2497       os << 's' << i << ", ";
2498     if (map.getNumSymbols() >= 1)
2499       os << 's' << map.getNumSymbols() - 1;
2500     os << ']';
2501   }
2502 
2503   // Result affine expressions.
2504   os << " -> (";
2505   interleaveComma(map.getResults(),
2506                   [&](AffineExpr expr) { printAffineExpr(expr); });
2507   os << ')';
2508 }
2509 
printIntegerSet(IntegerSet set)2510 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2511   // Dimension identifiers.
2512   os << '(';
2513   for (unsigned i = 1; i < set.getNumDims(); ++i)
2514     os << 'd' << i - 1 << ", ";
2515   if (set.getNumDims() >= 1)
2516     os << 'd' << set.getNumDims() - 1;
2517   os << ')';
2518 
2519   // Symbolic identifiers.
2520   if (set.getNumSymbols() != 0) {
2521     os << '[';
2522     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2523       os << 's' << i << ", ";
2524     if (set.getNumSymbols() >= 1)
2525       os << 's' << set.getNumSymbols() - 1;
2526     os << ']';
2527   }
2528 
2529   // Print constraints.
2530   os << " : (";
2531   int numConstraints = set.getNumConstraints();
2532   for (int i = 1; i < numConstraints; ++i) {
2533     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2534     os << ", ";
2535   }
2536   if (numConstraints >= 1)
2537     printAffineConstraint(set.getConstraint(numConstraints - 1),
2538                           set.isEq(numConstraints - 1));
2539   os << ')';
2540 }
2541 
2542 //===----------------------------------------------------------------------===//
2543 // OperationPrinter
2544 //===----------------------------------------------------------------------===//
2545 
2546 namespace {
2547 /// This class contains the logic for printing operations, regions, and blocks.
2548 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
2549 public:
2550   using Impl = AsmPrinter::Impl;
2551   using Impl::printType;
2552 
OperationPrinter(raw_ostream & os,AsmStateImpl & state)2553   explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
2554       : Impl(os, state.getPrinterFlags(), &state),
2555         OpAsmPrinter(static_cast<Impl &>(*this)) {}
2556 
2557   /// Print the given top-level operation.
2558   void printTopLevelOperation(Operation *op);
2559 
2560   /// Print the given operation with its indent and location.
2561   void print(Operation *op);
2562   /// Print the bare location, not including indentation/location/etc.
2563   void printOperation(Operation *op);
2564   /// Print the given operation in the generic form.
2565   void printGenericOp(Operation *op, bool printOpName) override;
2566 
2567   /// Print the name of the given block.
2568   void printBlockName(Block *block);
2569 
2570   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2571   /// block are not printed. If 'printBlockTerminator' is false, the terminator
2572   /// operation of the block is not printed.
2573   void print(Block *block, bool printBlockArgs = true,
2574              bool printBlockTerminator = true);
2575 
2576   /// Print the ID of the given value, optionally with its result number.
2577   void printValueID(Value value, bool printResultNo = true,
2578                     raw_ostream *streamOverride = nullptr) const;
2579 
2580   /// Print the ID of the given operation.
2581   void printOperationID(Operation *op,
2582                         raw_ostream *streamOverride = nullptr) const;
2583 
2584   //===--------------------------------------------------------------------===//
2585   // OpAsmPrinter methods
2586   //===--------------------------------------------------------------------===//
2587 
2588   /// Print a newline and indent the printer to the start of the current
2589   /// operation.
printNewline()2590   void printNewline() override {
2591     os << newLine;
2592     os.indent(currentIndent);
2593   }
2594 
2595   /// Print a block argument in the usual format of:
2596   ///   %ssaName : type {attr1=42} loc("here")
2597   /// where location printing is controlled by the standard internal option.
2598   /// You may pass omitType=true to not print a type, and pass an empty
2599   /// attribute list if you don't care for attributes.
2600   void printRegionArgument(BlockArgument arg,
2601                            ArrayRef<NamedAttribute> argAttrs = {},
2602                            bool omitType = false) override;
2603 
2604   /// Print the ID for the given value.
printOperand(Value value)2605   void printOperand(Value value) override { printValueID(value); }
printOperand(Value value,raw_ostream & os)2606   void printOperand(Value value, raw_ostream &os) override {
2607     printValueID(value, /*printResultNo=*/true, &os);
2608   }
2609 
2610   /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2611   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2612                              ArrayRef<StringRef> elidedAttrs = {}) override {
2613     Impl::printOptionalAttrDict(attrs, elidedAttrs);
2614   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2615   void printOptionalAttrDictWithKeyword(
2616       ArrayRef<NamedAttribute> attrs,
2617       ArrayRef<StringRef> elidedAttrs = {}) override {
2618     Impl::printOptionalAttrDict(attrs, elidedAttrs,
2619                                 /*withKeyword=*/true);
2620   }
2621 
2622   /// Print the given successor.
2623   void printSuccessor(Block *successor) override;
2624 
2625   /// Print an operation successor with the operands used for the block
2626   /// arguments.
2627   void printSuccessorAndUseList(Block *successor,
2628                                 ValueRange succOperands) override;
2629 
2630   /// Print the given region.
2631   void printRegion(Region &region, bool printEntryBlockArgs,
2632                    bool printBlockTerminators, bool printEmptyBlock) override;
2633 
2634   /// Renumber the arguments for the specified region to the same names as the
2635   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2636   /// operations. If any entry in namesToUse is null, the corresponding
2637   /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)2638   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
2639     state->getSSANameState().shadowRegionArgs(region, namesToUse);
2640   }
2641 
2642   /// Print the given affine map with the symbol and dimension operands printed
2643   /// inline with the map.
2644   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2645                               ValueRange operands) override;
2646 
2647   /// Print the given affine expression with the symbol and dimension operands
2648   /// printed inline with the expression.
2649   void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
2650                                ValueRange symOperands) override;
2651 
2652   /// Print users of this operation or id of this operation if it has no result.
2653   void printUsersComment(Operation *op);
2654 
2655   /// Print users of this block arg.
2656   void printUsersComment(BlockArgument arg);
2657 
2658   /// Print the users of a value.
2659   void printValueUsers(Value value);
2660 
2661   /// Print either the ids of the result values or the id of the operation if
2662   /// the operation has no results.
2663   void printUserIDs(Operation *user, bool prefixComma = false);
2664 
2665 private:
2666   /// This class represents a resource builder implementation for the MLIR
2667   /// textual assembly format.
2668   class ResourceBuilder : public AsmResourceBuilder {
2669   public:
2670     using ValueFn = function_ref<void(raw_ostream &)>;
2671     using PrintFn = function_ref<void(StringRef, ValueFn)>;
2672 
ResourceBuilder(OperationPrinter & p,PrintFn printFn)2673     ResourceBuilder(OperationPrinter &p, PrintFn printFn)
2674         : p(p), printFn(printFn) {}
2675     ~ResourceBuilder() override = default;
2676 
buildBool(StringRef key,bool data)2677     void buildBool(StringRef key, bool data) final {
2678       printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); });
2679     }
2680 
buildString(StringRef key,StringRef data)2681     void buildString(StringRef key, StringRef data) final {
2682       printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); });
2683     }
2684 
buildBlob(StringRef key,ArrayRef<char> data,uint32_t dataAlignment)2685     void buildBlob(StringRef key, ArrayRef<char> data,
2686                    uint32_t dataAlignment) final {
2687       printFn(key, [&](raw_ostream &os) {
2688         // Store the blob in a hex string containing the alignment and the data.
2689         llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
2690         os << "\"0x"
2691            << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
2692                                     sizeof(dataAlignment)))
2693            << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
2694       });
2695     }
2696 
2697   private:
2698     OperationPrinter &p;
2699     PrintFn printFn;
2700   };
2701 
2702   /// Print the metadata dictionary for the file, eliding it if it is empty.
2703   void printFileMetadataDictionary(Operation *op);
2704 
2705   /// Print the resource sections for the file metadata dictionary.
2706   /// `checkAddMetadataDict` is used to indicate that metadata is going to be
2707   /// added, and the file metadata dictionary should be started if it hasn't
2708   /// yet.
2709   void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
2710                                  Operation *op);
2711 
2712   // Contains the stack of default dialects to use when printing regions.
2713   // A new dialect is pushed to the stack before parsing regions nested under an
2714   // operation implementing `OpAsmOpInterface`, and popped when done. At the
2715   // top-level we start with "builtin" as the default, so that the top-level
2716   // `module` operation prints as-is.
2717   SmallVector<StringRef> defaultDialectStack{"builtin"};
2718 
2719   /// The number of spaces used for indenting nested operations.
2720   const static unsigned indentWidth = 2;
2721 
2722   // This is the current indentation level for nested structures.
2723   unsigned currentIndent = 0;
2724 };
2725 } // namespace
2726 
printTopLevelOperation(Operation * op)2727 void OperationPrinter::printTopLevelOperation(Operation *op) {
2728   // Output the aliases at the top level that can't be deferred.
2729   state->getAliasState().printNonDeferredAliases(os, newLine);
2730 
2731   // Print the module.
2732   print(op);
2733   os << newLine;
2734 
2735   // Output the aliases at the top level that can be deferred.
2736   state->getAliasState().printDeferredAliases(os, newLine);
2737 
2738   // Output any file level metadata.
2739   printFileMetadataDictionary(op);
2740 }
2741 
printFileMetadataDictionary(Operation * op)2742 void OperationPrinter::printFileMetadataDictionary(Operation *op) {
2743   bool sawMetadataEntry = false;
2744   auto checkAddMetadataDict = [&] {
2745     if (!std::exchange(sawMetadataEntry, true))
2746       os << newLine << "{-#" << newLine;
2747   };
2748 
2749   // Add the various types of metadata.
2750   printResourceFileMetadata(checkAddMetadataDict, op);
2751 
2752   // If the file dictionary exists, close it.
2753   if (sawMetadataEntry)
2754     os << newLine << "#-}" << newLine;
2755 }
2756 
printResourceFileMetadata(function_ref<void ()> checkAddMetadataDict,Operation * op)2757 void OperationPrinter::printResourceFileMetadata(
2758     function_ref<void()> checkAddMetadataDict, Operation *op) {
2759   // Functor used to add data entries to the file metadata dictionary.
2760   bool hadResource = false;
2761   auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
2762                              auto &&...providerArgs) {
2763     bool hadEntry = false;
2764     auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
2765       checkAddMetadataDict();
2766 
2767       // Emit the top-level resource entry if we haven't yet.
2768       if (!std::exchange(hadResource, true))
2769         os << "  " << dictName << "_resources: {" << newLine;
2770       // Emit the parent resource entry if we haven't yet.
2771       if (!std::exchange(hadEntry, true))
2772         os << "    " << name << ": {" << newLine;
2773       else
2774         os << "," << newLine;
2775 
2776       os << "      " << key << ": ";
2777       valueFn(os);
2778     };
2779     ResourceBuilder entryBuilder(*this, printFn);
2780     provider.buildResources(op, providerArgs..., entryBuilder);
2781 
2782     if (hadEntry)
2783       os << newLine << "    }";
2784   };
2785 
2786   // Print the `dialect_resources` section if we have any dialects with
2787   // resources.
2788   for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
2789     StringRef name = interface.getDialect()->getNamespace();
2790     auto it = dialectResources.find(interface.getDialect());
2791     if (it != dialectResources.end())
2792       processProvider("dialect", name, interface, it->second);
2793     else
2794       processProvider("dialect", name, interface,
2795                       SetVector<AsmDialectResourceHandle>());
2796   }
2797   if (hadResource)
2798     os << newLine << "  }";
2799 
2800   // Print the `external_resources` section if we have any external clients with
2801   // resources.
2802   hadResource = false;
2803   for (const auto &printer : state->getResourcePrinters())
2804     processProvider("external", printer.getName(), printer);
2805   if (hadResource)
2806     os << newLine << "  }";
2807 }
2808 
2809 /// Print a block argument in the usual format of:
2810 ///   %ssaName : type {attr1=42} loc("here")
2811 /// where location printing is controlled by the standard internal option.
2812 /// You may pass omitType=true to not print a type, and pass an empty
2813 /// attribute list if you don't care for attributes.
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)2814 void OperationPrinter::printRegionArgument(BlockArgument arg,
2815                                            ArrayRef<NamedAttribute> argAttrs,
2816                                            bool omitType) {
2817   printOperand(arg);
2818   if (!omitType) {
2819     os << ": ";
2820     printType(arg.getType());
2821   }
2822   printOptionalAttrDict(argAttrs);
2823   // TODO: We should allow location aliases on block arguments.
2824   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2825 }
2826 
print(Operation * op)2827 void OperationPrinter::print(Operation *op) {
2828   // Track the location of this operation.
2829   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2830 
2831   os.indent(currentIndent);
2832   printOperation(op);
2833   printTrailingLocation(op->getLoc());
2834   if (printerFlags.shouldPrintValueUsers())
2835     printUsersComment(op);
2836 }
2837 
printOperation(Operation * op)2838 void OperationPrinter::printOperation(Operation *op) {
2839   if (size_t numResults = op->getNumResults()) {
2840     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2841       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2842       if (resultCount > 1)
2843         os << ':' << resultCount;
2844     };
2845 
2846     // Check to see if this operation has multiple result groups.
2847     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2848     if (!resultGroups.empty()) {
2849       // Interleave the groups excluding the last one, this one will be handled
2850       // separately.
2851       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2852         printResultGroup(resultGroups[i],
2853                          resultGroups[i + 1] - resultGroups[i]);
2854       });
2855       os << ", ";
2856       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2857 
2858     } else {
2859       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2860     }
2861 
2862     os << " = ";
2863   }
2864 
2865   // If requested, always print the generic form.
2866   if (!printerFlags.shouldPrintGenericOpForm()) {
2867     // Check to see if this is a known operation. If so, use the registered
2868     // custom printer hook.
2869     if (auto opInfo = op->getRegisteredInfo()) {
2870       opInfo->printAssembly(op, *this, defaultDialectStack.back());
2871       return;
2872     }
2873     // Otherwise try to dispatch to the dialect, if available.
2874     if (Dialect *dialect = op->getDialect()) {
2875       if (auto opPrinter = dialect->getOperationPrinter(op)) {
2876         // Print the op name first.
2877         StringRef name = op->getName().getStringRef();
2878         // Only drop the default dialect prefix when it cannot lead to
2879         // ambiguities.
2880         if (name.count('.') == 1)
2881           name.consume_front((defaultDialectStack.back() + ".").str());
2882         os << name;
2883 
2884         // Print the rest of the op now.
2885         opPrinter(op, *this);
2886         return;
2887       }
2888     }
2889   }
2890 
2891   // Otherwise print with the generic assembly form.
2892   printGenericOp(op, /*printOpName=*/true);
2893 }
2894 
printUsersComment(Operation * op)2895 void OperationPrinter::printUsersComment(Operation *op) {
2896   unsigned numResults = op->getNumResults();
2897   if (!numResults && op->getNumOperands()) {
2898     os << " // id: ";
2899     printOperationID(op);
2900   } else if (numResults && op->use_empty()) {
2901     os << " // unused";
2902   } else if (numResults && !op->use_empty()) {
2903     // Print "user" if the operation has one result used to compute one other
2904     // result, or is used in one operation with no result.
2905     unsigned usedInNResults = 0;
2906     unsigned usedInNOperations = 0;
2907     SmallPtrSet<Operation *, 1> userSet;
2908     for (Operation *user : op->getUsers()) {
2909       if (userSet.insert(user).second) {
2910         ++usedInNOperations;
2911         usedInNResults += user->getNumResults();
2912       }
2913     }
2914 
2915     // We already know that users is not empty.
2916     bool exactlyOneUniqueUse =
2917         usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
2918     os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
2919     bool shouldPrintBrackets = numResults > 1;
2920     auto printOpResult = [&](OpResult opResult) {
2921       if (shouldPrintBrackets)
2922         os << "(";
2923       printValueUsers(opResult);
2924       if (shouldPrintBrackets)
2925         os << ")";
2926     };
2927 
2928     interleaveComma(op->getResults(), printOpResult);
2929   }
2930 }
2931 
printUsersComment(BlockArgument arg)2932 void OperationPrinter::printUsersComment(BlockArgument arg) {
2933   os << "// ";
2934   printValueID(arg);
2935   if (arg.use_empty()) {
2936     os << " is unused";
2937   } else {
2938     os << " is used by ";
2939     printValueUsers(arg);
2940   }
2941   os << newLine;
2942 }
2943 
printValueUsers(Value value)2944 void OperationPrinter::printValueUsers(Value value) {
2945   if (value.use_empty())
2946     os << "unused";
2947 
2948   // One value might be used as the operand of an operation more than once.
2949   // Only print the operations results once in that case.
2950   SmallPtrSet<Operation *, 1> userSet;
2951   for (auto &indexedUser : enumerate(value.getUsers())) {
2952     if (userSet.insert(indexedUser.value()).second)
2953       printUserIDs(indexedUser.value(), indexedUser.index());
2954   }
2955 }
2956 
printUserIDs(Operation * user,bool prefixComma)2957 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
2958   if (prefixComma)
2959     os << ", ";
2960 
2961   if (!user->getNumResults()) {
2962     printOperationID(user);
2963   } else {
2964     interleaveComma(user->getResults(),
2965                     [this](Value result) { printValueID(result); });
2966   }
2967 }
2968 
printGenericOp(Operation * op,bool printOpName)2969 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2970   if (printOpName)
2971     printEscapedString(op->getName().getStringRef());
2972   os << '(';
2973   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2974   os << ')';
2975 
2976   // For terminators, print the list of successors and their operands.
2977   if (op->getNumSuccessors() != 0) {
2978     os << '[';
2979     interleaveComma(op->getSuccessors(),
2980                     [&](Block *successor) { printBlockName(successor); });
2981     os << ']';
2982   }
2983 
2984   // Print regions.
2985   if (op->getNumRegions() != 0) {
2986     os << " (";
2987     interleaveComma(op->getRegions(), [&](Region &region) {
2988       printRegion(region, /*printEntryBlockArgs=*/true,
2989                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2990     });
2991     os << ')';
2992   }
2993 
2994   auto attrs = op->getAttrs();
2995   printOptionalAttrDict(attrs);
2996 
2997   // Print the type signature of the operation.
2998   os << " : ";
2999   printFunctionalType(op);
3000 }
3001 
printBlockName(Block * block)3002 void OperationPrinter::printBlockName(Block *block) {
3003   os << state->getSSANameState().getBlockInfo(block).name;
3004 }
3005 
/// Print `block` at the current indentation.
///
/// When `printBlockArgs` is set, the block header is printed first: the block
/// name, the parenthesized argument list (if any) and a trailing `:`, followed
/// by a comment describing the block's predecessors. The block's operations
/// are then printed one per line, indented by one extra `indentWidth`. If
/// `printBlockTerminator` is false and the last operation has the
/// IsTerminator trait, that last operation is skipped.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      // Each argument is printed as `<id>: <type>` plus its trailing location.
      interleaveComma(block->getArguments(), [&](BlockArgument arg) {
        printValueID(arg);
        os << ": ";
        printType(arg.getType());
        // TODO: We should allow location aliases on block arguments.
        printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    // An entry block without predecessors gets no comment at all.
    if (!block->getParent()) {
      os << "  // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      if (!block->isEntryBlock())
        os << "  // no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << "  // pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in a stable order, not in
      // whatever order the use-list is in, so gather and sort them.
      SmallVector<BlockInfo, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
      llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
        return lhs.ordering < rhs.ordering;
      });

      os << "  // " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
    }
    os << newLine;
  }

  // The block body is nested one level deeper than the header.
  currentIndent += indentWidth;

  // Optionally emit one "users" comment line per block argument before the
  // operations.
  if (printerFlags.shouldPrintValueUsers()) {
    for (BlockArgument arg : block->getArguments()) {
      os.indent(currentIndent);
      printUsersComment(arg);
    }
  }

  // Drop the last operation from the printed range only when it is a
  // terminator and terminators were not requested.
  bool hasTerminator =
      !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
  auto range = llvm::make_range(
      block->begin(),
      std::prev(block->end(),
                (!hasTerminator || printBlockTerminator) ? 0 : 1));
  for (auto &op : range) {
    print(&op);
    os << newLine;
  }
  // Restore the indentation that was in effect on entry.
  currentIndent -= indentWidth;
}
3074 
printValueID(Value value,bool printResultNo,raw_ostream * streamOverride) const3075 void OperationPrinter::printValueID(Value value, bool printResultNo,
3076                                     raw_ostream *streamOverride) const {
3077   state->getSSANameState().printValueID(value, printResultNo,
3078                                         streamOverride ? *streamOverride : os);
3079 }
3080 
printOperationID(Operation * op,raw_ostream * streamOverride) const3081 void OperationPrinter::printOperationID(Operation *op,
3082                                         raw_ostream *streamOverride) const {
3083   state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride
3084                                                                : os);
3085 }
3086 
printSuccessor(Block * successor)3087 void OperationPrinter::printSuccessor(Block *successor) {
3088   printBlockName(successor);
3089 }
3090 
printSuccessorAndUseList(Block * successor,ValueRange succOperands)3091 void OperationPrinter::printSuccessorAndUseList(Block *successor,
3092                                                 ValueRange succOperands) {
3093   printBlockName(successor);
3094   if (succOperands.empty())
3095     return;
3096 
3097   os << '(';
3098   interleaveComma(succOperands,
3099                   [this](Value operand) { printValueID(operand); });
3100   os << " : ";
3101   interleaveComma(succOperands,
3102                   [this](Value operand) { printType(operand.getType()); });
3103   os << ')';
3104 }
3105 
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock)3106 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3107                                    bool printBlockTerminators,
3108                                    bool printEmptyBlock) {
3109   os << "{" << newLine;
3110   if (!region.empty()) {
3111     auto restoreDefaultDialect =
3112         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
3113     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3114       defaultDialectStack.push_back(iface.getDefaultDialect());
3115     else
3116       defaultDialectStack.push_back("");
3117 
3118     auto *entryBlock = &region.front();
3119     // Force printing the block header if printEmptyBlock is set and the block
3120     // is empty or if printEntryBlockArgs is set and there are arguments to
3121     // print.
3122     bool shouldAlwaysPrintBlockHeader =
3123         (printEmptyBlock && entryBlock->empty()) ||
3124         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3125     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
3126     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
3127       print(&b);
3128   }
3129   os.indent(currentIndent) << "}";
3130 }
3131 
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)3132 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3133                                               ValueRange operands) {
3134   AffineMap map = mapAttr.getValue();
3135   unsigned numDims = map.getNumDims();
3136   auto printValueName = [&](unsigned pos, bool isSymbol) {
3137     unsigned index = isSymbol ? numDims + pos : pos;
3138     assert(index < operands.size());
3139     if (isSymbol)
3140       os << "symbol(";
3141     printValueID(operands[index]);
3142     if (isSymbol)
3143       os << ')';
3144   };
3145 
3146   interleaveComma(map.getResults(), [&](AffineExpr expr) {
3147     printAffineExpr(expr, printValueName);
3148   });
3149 }
3150 
printAffineExprOfSSAIds(AffineExpr expr,ValueRange dimOperands,ValueRange symOperands)3151 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3152                                                ValueRange dimOperands,
3153                                                ValueRange symOperands) {
3154   auto printValueName = [&](unsigned pos, bool isSymbol) {
3155     if (!isSymbol)
3156       return printValueID(dimOperands[pos]);
3157     os << "symbol(";
3158     printValueID(symOperands[pos]);
3159     os << ')';
3160   };
3161   printAffineExpr(expr, printValueName);
3162 }
3163 
3164 //===----------------------------------------------------------------------===//
3165 // print and dump methods
3166 //===----------------------------------------------------------------------===//
3167 
print(raw_ostream & os) const3168 void Attribute::print(raw_ostream &os) const {
3169   AsmPrinter::Impl(os).printAttribute(*this);
3170 }
3171 
dump() const3172 void Attribute::dump() const {
3173   print(llvm::errs());
3174   llvm::errs() << "\n";
3175 }
3176 
print(raw_ostream & os) const3177 void Type::print(raw_ostream &os) const {
3178   AsmPrinter::Impl(os).printType(*this);
3179 }
3180 
dump() const3181 void Type::dump() const { print(llvm::errs()); }
3182 
dump() const3183 void AffineMap::dump() const {
3184   print(llvm::errs());
3185   llvm::errs() << "\n";
3186 }
3187 
dump() const3188 void IntegerSet::dump() const {
3189   print(llvm::errs());
3190   llvm::errs() << "\n";
3191 }
3192 
print(raw_ostream & os) const3193 void AffineExpr::print(raw_ostream &os) const {
3194   if (!expr) {
3195     os << "<<NULL AFFINE EXPR>>";
3196     return;
3197   }
3198   AsmPrinter::Impl(os).printAffineExpr(*this);
3199 }
3200 
dump() const3201 void AffineExpr::dump() const {
3202   print(llvm::errs());
3203   llvm::errs() << "\n";
3204 }
3205 
print(raw_ostream & os) const3206 void AffineMap::print(raw_ostream &os) const {
3207   if (!map) {
3208     os << "<<NULL AFFINE MAP>>";
3209     return;
3210   }
3211   AsmPrinter::Impl(os).printAffineMap(*this);
3212 }
3213 
print(raw_ostream & os) const3214 void IntegerSet::print(raw_ostream &os) const {
3215   AsmPrinter::Impl(os).printIntegerSet(*this);
3216 }
3217 
print(raw_ostream & os)3218 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
print(raw_ostream & os,const OpPrintingFlags & flags)3219 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
3220   if (!impl) {
3221     os << "<<NULL VALUE>>";
3222     return;
3223   }
3224 
3225   if (auto *op = getDefiningOp())
3226     return op->print(os, flags);
3227   // TODO: Improve BlockArgument print'ing.
3228   BlockArgument arg = this->cast<BlockArgument>();
3229   os << "<block argument> of type '" << arg.getType()
3230      << "' at index: " << arg.getArgNumber();
3231 }
print(raw_ostream & os,AsmState & state)3232 void Value::print(raw_ostream &os, AsmState &state) {
3233   if (!impl) {
3234     os << "<<NULL VALUE>>";
3235     return;
3236   }
3237 
3238   if (auto *op = getDefiningOp())
3239     return op->print(os, state);
3240 
3241   // TODO: Improve BlockArgument print'ing.
3242   BlockArgument arg = this->cast<BlockArgument>();
3243   os << "<block argument> of type '" << arg.getType()
3244      << "' at index: " << arg.getArgNumber();
3245 }
3246 
dump()3247 void Value::dump() {
3248   print(llvm::errs());
3249   llvm::errs() << "\n";
3250 }
3251 
printAsOperand(raw_ostream & os,AsmState & state)3252 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
3253   // TODO: This doesn't necessarily capture all potential cases.
3254   // Currently, region arguments can be shadowed when printing the main
3255   // operation. If the IR hasn't been printed, this will produce the old SSA
3256   // name and not the shadowed name.
3257   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
3258                                                  os);
3259 }
3260 
print(raw_ostream & os,const OpPrintingFlags & printerFlags)3261 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
3262   // If this is a top level operation, we also print aliases.
3263   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
3264     AsmState state(this, printerFlags);
3265     state.getImpl().initializeAliases(this);
3266     print(os, state);
3267     return;
3268   }
3269 
3270   // Find the operation to number from based upon the provided flags.
3271   Operation *op = this;
3272   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
3273   do {
3274     // If we are printing local scope, stop at the first operation that is
3275     // isolated from above.
3276     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3277       break;
3278 
3279     // Otherwise, traverse up to the next parent.
3280     Operation *parentOp = op->getParentOp();
3281     if (!parentOp)
3282       break;
3283     op = parentOp;
3284   } while (true);
3285 
3286   AsmState state(op, printerFlags);
3287   print(os, state);
3288 }
print(raw_ostream & os,AsmState & state)3289 void Operation::print(raw_ostream &os, AsmState &state) {
3290   OperationPrinter printer(os, state.getImpl());
3291   if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
3292     printer.printTopLevelOperation(this);
3293   else
3294     printer.print(this);
3295 }
3296 
dump()3297 void Operation::dump() {
3298   print(llvm::errs(), OpPrintingFlags().useLocalScope());
3299   llvm::errs() << "\n";
3300 }
3301 
print(raw_ostream & os)3302 void Block::print(raw_ostream &os) {
3303   Operation *parentOp = getParentOp();
3304   if (!parentOp) {
3305     os << "<<UNLINKED BLOCK>>\n";
3306     return;
3307   }
3308   // Get the top-level op.
3309   while (auto *nextOp = parentOp->getParentOp())
3310     parentOp = nextOp;
3311 
3312   AsmState state(parentOp);
3313   print(os, state);
3314 }
print(raw_ostream & os,AsmState & state)3315 void Block::print(raw_ostream &os, AsmState &state) {
3316   OperationPrinter(os, state.getImpl()).print(this);
3317 }
3318 
dump()3319 void Block::dump() { print(llvm::errs()); }
3320 
3321 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)3322 void Block::printAsOperand(raw_ostream &os, bool printType) {
3323   Operation *parentOp = getParentOp();
3324   if (!parentOp) {
3325     os << "<<UNLINKED BLOCK>>\n";
3326     return;
3327   }
3328   AsmState state(parentOp);
3329   printAsOperand(os, state);
3330 }
printAsOperand(raw_ostream & os,AsmState & state)3331 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3332   OperationPrinter printer(os, state.getImpl());
3333   printer.printBlockName(this);
3334 }
3335