1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/SubElementInterfaces.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/ScopedHashTable.h"
31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/StringSet.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Endian.h"
38 #include "llvm/Support/Regex.h"
39 #include "llvm/Support/SaveAndRestore.h"
40 
41 #include <tuple>
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 void Identifier::print(raw_ostream &os) const { os << str(); }
47 
48 void Identifier::dump() const { print(llvm::errs()); }
49 
50 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
51 
52 void OperationName::dump() const { print(llvm::errs()); }
53 
54 DialectAsmPrinter::~DialectAsmPrinter() {}
55 
56 //===--------------------------------------------------------------------===//
57 // OpAsmPrinter
58 //===--------------------------------------------------------------------===//
59 
60 OpAsmPrinter::~OpAsmPrinter() {}
61 
62 void OpAsmPrinter::printFunctionalType(Operation *op) {
63   auto &os = getStream();
64   os << '(';
65   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
66     // Print the types of null values as <<NULL TYPE>>.
67     *this << (operand ? operand.getType() : Type());
68   });
69   os << ") -> ";
70 
71   // Print the result list.  We don't parenthesize single result types unless
72   // it is a function (avoiding a grammar ambiguity).
73   bool wrapped = op->getNumResults() != 1;
74   if (!wrapped && op->getResult(0).getType() &&
75       op->getResult(0).getType().isa<FunctionType>())
76     wrapped = true;
77 
78   if (wrapped)
79     os << '(';
80 
81   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
82     // Print the types of null values as <<NULL TYPE>>.
83     *this << (result ? result.getType() : Type());
84   });
85 
86   if (wrapped)
87     os << ')';
88 }
89 
90 //===--------------------------------------------------------------------===//
91 // Operation OpAsm interface.
92 //===--------------------------------------------------------------------===//
93 
94 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
95 #include "mlir/IR/OpAsmInterface.cpp.inc"
96 
97 //===----------------------------------------------------------------------===//
98 // OpPrintingFlags
99 //===----------------------------------------------------------------------===//
100 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
/// for global command line options.
///
/// An instance is created lazily through the `clOptions` ManagedStatic; code
/// reading these options first checks `clOptions.isConstructed()`.
struct AsmPrinterOptions {
  /// Element-count threshold above which DenseElementsAttrs are printed as a
  /// hex string; -1 disables hex printing. Only consulted when the option was
  /// explicitly given (see shouldPrintElementsAttrWithHex).
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  /// Element-count threshold above which ElementsAttrs are elided. Only
  /// applied to OpPrintingFlags when the option was explicitly given.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  /// Enables printing of debug info (OpPrintingFlags::printDebugInfoFlag).
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  /// Selects the pretty debug-info form
  /// (OpPrintingFlags::printDebugInfoPrettyFormFlag).
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  // Use the generic op output form in the operation printer even if the custom
  // form is defined.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  /// Default for OpPrintingFlags::printLocalScope.
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print assuming in local scope by default"),
      llvm::cl::Hidden};
};
} // end anonymous namespace
137 
138 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
139 
140 /// Register a set of useful command-line options that can be used to configure
141 /// various flags within the AsmPrinter.
142 void mlir::registerAsmPrinterCLOptions() {
143   // Make sure that the options struct has been initialized.
144   *clOptions;
145 }
146 
147 /// Initialize the printing flags with default supplied by the cl::opts above.
148 OpPrintingFlags::OpPrintingFlags()
149     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
150       printGenericOpFormFlag(false), printLocalScope(false) {
151   // Initialize based upon command line options, if they are available.
152   if (!clOptions.isConstructed())
153     return;
154   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
155     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
156   printDebugInfoFlag = clOptions->printDebugInfoOpt;
157   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
158   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
159   printLocalScope = clOptions->printLocalScopeOpt;
160 }
161 
162 /// Enable the elision of large elements attributes, by printing a '...'
163 /// instead of the element data, when the number of elements is greater than
164 /// `largeElementLimit`. Note: The IR generated with this option is not
165 /// parsable.
166 OpPrintingFlags &
167 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
168   elementsAttrElementLimit = largeElementLimit;
169   return *this;
170 }
171 
172 /// Enable printing of debug information. If 'prettyForm' is set to true,
173 /// debug information is printed in a more readable 'pretty' form.
174 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
175   printDebugInfoFlag = true;
176   printDebugInfoPrettyFormFlag = prettyForm;
177   return *this;
178 }
179 
180 /// Always print operations in the generic form.
181 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
182   printGenericOpFormFlag = true;
183   return *this;
184 }
185 
186 /// Use local scope when printing the operation. This allows for using the
187 /// printer in a more localized and thread-safe setting, but may not necessarily
188 /// be identical of what the IR will look like when dumping the full module.
189 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
190   printLocalScope = true;
191   return *this;
192 }
193 
194 /// Return if the given ElementsAttr should be elided.
195 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
196   return elementsAttrElementLimit.hasValue() &&
197          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
198          !attr.isa<SplatElementsAttr>();
199 }
200 
201 /// Return the size limit for printing large ElementsAttr.
202 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
203   return elementsAttrElementLimit;
204 }
205 
206 /// Return if debug information should be printed.
207 bool OpPrintingFlags::shouldPrintDebugInfo() const {
208   return printDebugInfoFlag;
209 }
210 
211 /// Return if debug information should be printed in the pretty form.
212 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
213   return printDebugInfoPrettyFormFlag;
214 }
215 
216 /// Return if operations should be printed in the generic form.
217 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
218   return printGenericOpFormFlag;
219 }
220 
221 /// Return if the printer should use local scope when dumping the IR.
222 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
223 
224 /// Returns true if an ElementsAttr with the given number of elements should be
225 /// printed with hex.
226 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
227   // Check to see if a command line option was provided for the limit.
228   if (clOptions.isConstructed()) {
229     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
230       // -1 is used to disable hex printing.
231       if (clOptions->printElementsAttrWithHexIfLarger == -1)
232         return false;
233       return numElements > clOptions->printElementsAttrWithHexIfLarger;
234     }
235   }
236 
237   // Otherwise, default to printing with hex if the number of elements is >100.
238   return numElements > 100;
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // NewLineCounter
243 //===----------------------------------------------------------------------===//
244 
245 namespace {
246 /// This class is a simple formatter that emits a new line when inputted into a
247 /// stream, that enables counting the number of newlines emitted. This class
248 /// should be used whenever emitting newlines in the printer.
249 struct NewLineCounter {
250   unsigned curLine = 1;
251 };
252 } // end anonymous namespace
253 
254 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
255   ++newLine.curLine;
256   return os << '\n';
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // AliasInitializer
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents a specific instance of a symbol Alias.
265 class SymbolAlias {
266 public:
267   SymbolAlias(StringRef name, bool isDeferrable)
268       : name(name), suffixIndex(0), hasSuffixIndex(false),
269         isDeferrable(isDeferrable) {}
270   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
271       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
272         isDeferrable(isDeferrable) {}
273 
274   /// Print this alias to the given stream.
275   void print(raw_ostream &os) const {
276     os << name;
277     if (hasSuffixIndex)
278       os << suffixIndex;
279   }
280 
281   /// Returns true if this alias supports deferred resolution when parsing.
282   bool canBeDeferred() const { return isDeferrable; }
283 
284 private:
285   /// The main name of the alias.
286   StringRef name;
287   /// The optional suffix index of the alias, if multiple aliases had the same
288   /// name.
289   uint32_t suffixIndex : 30;
290   /// A flag indicating whether this alias has a suffix or not.
291   bool hasSuffixIndex : 1;
292   /// A flag indicating whether this alias may be deferred or not.
293   bool isDeferrable : 1;
294 };
295 
296 /// This class represents a utility that initializes the set of attribute and
297 /// type aliases, without the need to store the extra information within the
298 /// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// `aliasAllocator` must outlive the produced aliases: alias names are
  /// copied into it (see generateAlias).
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` with the given printer flags and fill `attrToAlias` and
  /// `typeToAlias` with the aliases for every attribute/type that would be
  /// printed.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
  /// NOTE: `aliasBuffer` must stay declared before `aliasOS`, because the
  /// constructor initializes `aliasOS` from it.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
351 
352 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
353 /// and merely collects the attributes and types that *would* be printed in a
354 /// normal print invocation so that we can generate proper aliases. This allows
355 /// for us to generate aliases only for the attributes and types that would be
356 /// in the output, and trims down unnecessary output.
357 class DummyAliasOperationPrinter : private OpAsmPrinter {
358 public:
359   explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags,
360                                       AliasInitializer &initializer)
361       : printerFlags(flags), initializer(initializer) {}
362 
363   /// Print the given operation.
364   void print(Operation *op) {
365     // Visit the operation location.
366     if (printerFlags.shouldPrintDebugInfo())
367       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
368 
369     // If requested, always print the generic form.
370     if (!printerFlags.shouldPrintGenericOpForm()) {
371       // Check to see if this is a known operation.  If so, use the registered
372       // custom printer hook.
373       if (auto *opInfo = op->getAbstractOperation()) {
374         opInfo->printAssembly(op, *this);
375         return;
376       }
377     }
378 
379     // Otherwise print with the generic assembly form.
380     printGenericOp(op);
381   }
382 
383 private:
384   /// Print the given operation in the generic form.
385   void printGenericOp(Operation *op) override {
386     // Consider nested operations for aliases.
387     if (op->getNumRegions() != 0) {
388       for (Region &region : op->getRegions())
389         printRegion(region, /*printEntryBlockArgs=*/true,
390                     /*printBlockTerminators=*/true);
391     }
392 
393     // Visit all the types used in the operation.
394     for (Type type : op->getOperandTypes())
395       printType(type);
396     for (Type type : op->getResultTypes())
397       printType(type);
398 
399     // Consider the attributes of the operation for aliases.
400     for (const NamedAttribute &attr : op->getAttrs())
401       printAttribute(attr.second);
402   }
403 
404   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
405   /// block are not printed. If 'printBlockTerminator' is false, the terminator
406   /// operation of the block is not printed.
407   void print(Block *block, bool printBlockArgs = true,
408              bool printBlockTerminator = true) {
409     // Consider the types of the block arguments for aliases if 'printBlockArgs'
410     // is set to true.
411     if (printBlockArgs) {
412       for (BlockArgument arg : block->getArguments()) {
413         printType(arg.getType());
414 
415         // Visit the argument location.
416         if (printerFlags.shouldPrintDebugInfo())
417           // TODO: Allow deferring argument locations.
418           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
419       }
420     }
421 
422     // Consider the operations within this block, ignoring the terminator if
423     // requested.
424     bool hasTerminator =
425         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
426     auto range = llvm::make_range(
427         block->begin(),
428         std::prev(block->end(),
429                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
430     for (Operation &op : range)
431       print(&op);
432   }
433 
434   /// Print the given region.
435   void printRegion(Region &region, bool printEntryBlockArgs,
436                    bool printBlockTerminators,
437                    bool printEmptyBlock = false) override {
438     if (region.empty())
439       return;
440 
441     auto *entryBlock = &region.front();
442     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
443     for (Block &b : llvm::drop_begin(region, 1))
444       print(&b);
445   }
446 
447   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
448                            bool omitType) override {
449     printType(arg.getType());
450     // Visit the argument location.
451     if (printerFlags.shouldPrintDebugInfo())
452       // TODO: Allow deferring argument locations.
453       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
454   }
455 
456   /// Consider the given type to be printed for an alias.
457   void printType(Type type) override { initializer.visit(type); }
458 
459   /// Consider the given attribute to be printed for an alias.
460   void printAttribute(Attribute attr) override { initializer.visit(attr); }
461   void printAttributeWithoutType(Attribute attr) override {
462     printAttribute(attr);
463   }
464 
465   /// Print the given set of attributes with names not included within
466   /// 'elidedAttrs'.
467   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
468                              ArrayRef<StringRef> elidedAttrs = {}) override {
469     if (attrs.empty())
470       return;
471     if (elidedAttrs.empty()) {
472       for (const NamedAttribute &attr : attrs)
473         printAttribute(attr.second);
474       return;
475     }
476     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
477                                                   elidedAttrs.end());
478     for (const NamedAttribute &attr : attrs)
479       if (!elidedAttrsSet.contains(attr.first.strref()))
480         printAttribute(attr.second);
481   }
482   void printOptionalAttrDictWithKeyword(
483       ArrayRef<NamedAttribute> attrs,
484       ArrayRef<StringRef> elidedAttrs = {}) override {
485     printOptionalAttrDict(attrs, elidedAttrs);
486   }
487 
488   /// Return a null stream as the output stream, this will ignore any data fed
489   /// to it.
490   raw_ostream &getStream() const override { return os; }
491 
492   /// The following are hooks of `OpAsmPrinter` that are not necessary for
493   /// determining potential aliases.
494   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
495   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
496   void printNewline() override {}
497   void printOperand(Value) override {}
498   void printOperand(Value, raw_ostream &os) override {
499     // Users expect the output string to have at least the prefixed % to signal
500     // a value name. To maintain this invariant, emit a name even if it is
501     // guaranteed to go unused.
502     os << "%";
503   }
504   void printSymbolName(StringRef) override {}
505   void printSuccessor(Block *) override {}
506   void printSuccessorAndUseList(Block *, ValueRange) override {}
507   void shadowRegionArgs(Region &, ValueRange) override {}
508 
509   /// The printer flags to use when determining potential aliases.
510   const OpPrintingFlags &printerFlags;
511 
512   /// The initializer to use when identifying aliases.
513   AliasInitializer &initializer;
514 
515   /// A dummy output stream.
516   mutable llvm::raw_null_ostream os;
517 };
518 } // end anonymous namespace
519 
520 /// Sanitize the given name such that it can be used as a valid identifier. If
521 /// the string needs to be modified in any way, the provided buffer is used to
522 /// store the new copy,
523 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
524                                     StringRef allowedPunctChars = "$._-",
525                                     bool allowTrailingDigit = true) {
526   assert(!name.empty() && "Shouldn't have an empty name here");
527 
528   auto copyNameToBuffer = [&] {
529     for (char ch : name) {
530       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
531         buffer.push_back(ch);
532       else if (ch == ' ')
533         buffer.push_back('_');
534       else
535         buffer.append(llvm::utohexstr((unsigned char)ch));
536     }
537   };
538 
539   // Check to see if this name is valid. If it starts with a digit, then it
540   // could conflict with the autogenerated numeric ID's, so add an underscore
541   // prefix to avoid problems.
542   if (isdigit(name[0])) {
543     buffer.push_back('_');
544     copyNameToBuffer();
545     return buffer;
546   }
547 
548   // If the name ends with a trailing digit, add a '_' to avoid potential
549   // conflicts with autogenerated ID's.
550   if (!allowTrailingDigit && isdigit(name.back())) {
551     copyNameToBuffer();
552     buffer.push_back('_');
553     return buffer;
554   }
555 
556   // Check to see that the name consists of only valid identifier characters.
557   for (char ch : name) {
558     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
559       copyNameToBuffer();
560       return buffer;
561     }
562   }
563 
564   // If there are no invalid characters, return the original name.
565   return name;
566 }
567 
568 /// Given a collection of aliases and symbols, initialize a mapping from a
569 /// symbol to a given alias.
570 template <typename T>
571 static void
572 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
573                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
574                   DenseSet<T> *deferrableAliases = nullptr) {
575   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
576       aliasToSymbol.takeVector();
577   llvm::array_pod_sort(aliases.begin(), aliases.end(),
578                        [](const auto *lhs, const auto *rhs) {
579                          return lhs->first.compare(rhs->first);
580                        });
581 
582   for (auto &it : aliases) {
583     // If there is only one instance for this alias, use the name directly.
584     if (it.second.size() == 1) {
585       T symbol = it.second.front();
586       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
587       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
588       continue;
589     }
590     // Otherwise, add the index to the name.
591     for (int i = 0, e = it.second.size(); i < e; ++i) {
592       T symbol = it.second[i];
593       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
594       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
595     }
596   }
597 }
598 
599 void AliasInitializer::initialize(
600     Operation *op, const OpPrintingFlags &printerFlags,
601     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
602     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
603   // Use a dummy printer when walking the IR so that we can collect the
604   // attributes/types that will actually be used during printing when
605   // considering aliases.
606   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
607   aliasPrinter.print(op);
608 
609   // Initialize the aliases sorted by name.
610   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
611   initializeAliases(aliasToType, typeToAlias);
612 }
613 
614 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
615   if (!visitedAttributes.insert(attr).second) {
616     // If this attribute already has an alias and this instance can't be
617     // deferred, make sure that the alias isn't deferred.
618     if (!canBeDeferred)
619       deferrableAttributes.erase(attr);
620     return;
621   }
622 
623   // Try to generate an alias for this attribute.
624   if (succeeded(generateAlias(attr, aliasToAttr))) {
625     if (canBeDeferred)
626       deferrableAttributes.insert(attr);
627     return;
628   }
629 
630   // Check for any sub elements.
631   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
632     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
633                                         [&](Type type) { visit(type); });
634   }
635 }
636 
637 void AliasInitializer::visit(Type type) {
638   if (!visitedTypes.insert(type).second)
639     return;
640 
641   // Try to generate an alias for this type.
642   if (succeeded(generateAlias(type, aliasToType)))
643     return;
644 
645   // Check for any sub elements.
646   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
647     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
648                                         [&](Type type) { visit(type); });
649   }
650 }
651 
652 template <typename T>
653 LogicalResult AliasInitializer::generateAlias(
654     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
655   SmallString<16> tempBuffer;
656   for (const auto &interface : interfaces) {
657     if (failed(interface.getAlias(symbol, aliasOS)))
658       continue;
659     StringRef name = aliasOS.str();
660     assert(!name.empty() && "expected valid alias name");
661     name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
662                               /*allowTrailingDigit=*/false);
663     name = name.copy(aliasAllocator);
664 
665     aliasToSymbol[name].push_back(symbol);
666     aliasBuffer.clear();
667     return success();
668   }
669   return failure();
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // AliasState
674 //===----------------------------------------------------------------------===//
675 
namespace {
/// This class manages the state for type and attribute aliases.
/// Attribute aliases print with a '#' prefix and type aliases with a '!'
/// prefix.
class AliasState {
public:
  // Initialize the internal aliases.
  // Collects aliases for everything `op` would print under `printerFlags`,
  // using `interfaces` to name them.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias.
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names.
  /// The SymbolAlias entries above hold StringRefs that point into this
  /// allocator (the names are copied here by AliasInitializer::generateAlias).
  llvm::BumpPtrAllocator aliasAllocator;
};
} // end anonymous namespace
719 
720 void AliasState::initialize(
721     Operation *op, const OpPrintingFlags &printerFlags,
722     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
723   AliasInitializer initializer(interfaces, aliasAllocator);
724   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
725 }
726 
727 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
728   auto it = attrToAlias.find(attr);
729   if (it == attrToAlias.end())
730     return failure();
731   it->second.print(os << '#');
732   return success();
733 }
734 
735 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
736   auto it = typeToAlias.find(ty);
737   if (it == typeToAlias.end())
738     return failure();
739 
740   it->second.print(os << '!');
741   return success();
742 }
743 
744 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
745                               bool isDeferred) const {
746   auto filterFn = [=](const auto &aliasIt) {
747     return aliasIt.second.canBeDeferred() == isDeferred;
748   };
749   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
750     it.second.print(os << '#');
751     os << " = " << it.first << newLine;
752   }
753   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
754     it.second.print(os << '!');
755     os << " = type " << it.first << newLine;
756   }
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // SSANameState
761 //===----------------------------------------------------------------------===//
762 
namespace {
/// This class manages the state of SSA value names.
///
/// All numbering happens eagerly in the constructor. Afterwards, each value
/// with an entry in `valueIDs` has either a numeric ID, or the `NameSentinel`
/// marker together with a textual name in `valueNames`. Operation results that
/// do not start a result group get no entry of their own; they are printed
/// relative to the head of their group (see getResultIDAndNumber).
class SSANameState {
public:
  /// A sentinel value used for values with names set. getBlockID also returns
  /// it for blocks that were never numbered.
  enum : unsigned { NameSentinel = ~0U };

  /// Number the results of `op` and all values and blocks nested in its
  /// regions, consulting the OpAsm interfaces for custom names.
  SSANameState(Operation *op,
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value. The '#' suffix only appears for results inside a result
  /// group with more than one result. Null and unnumbered values print
  /// placeholder text instead.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the ID for the given block. IDs restart at 0 in each region, so they
  /// are only unique among the blocks of one region. Returns NameSentinel for
  /// blocks that were never numbered.
  unsigned getBlockID(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit. These helpers do not
  /// descend into nested regions; the constructor walks those with a worklist.
  void numberValuesInRegion(
      Region &region,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
  void numberValuesInBlock(
      Block &block,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
  void numberValuesInOp(
      Operation &op,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the index of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name assigns the
  /// next numeric ID instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed. The name is sanitized first, and
  /// the returned string is stored in usedNameAllocator.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// values of this map are the result numbers that start a result group,
  /// sorted in ascending order and always beginning with 0.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This is the block ID for each block. IDs are assigned per region,
  /// starting at 0 in each region.
  DenseMap<Block *, unsigned> blockIDs;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Backing storage for the name strings referenced by valueNames and
  /// usedNames.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;
};
} // end anonymous namespace
843 
/// Number every SSA value and block under `op`.
///
/// `op`'s own results are numbered first. Its regions, and the regions nested
/// inside them, are then processed through an explicit worklist. Each worklist
/// entry captures the counters and the `usedNames` scope that were current
/// when it was pushed:
///   - for `op`'s own regions, that is before `op`'s results are numbered;
///   - for nested regions, it is right after their parent region is numbered.
/// As a result, sibling regions start from the same naming context, and a
/// name introduced in one region is not visible in its siblings.
SSANameState::SSANameState(
    Operation *op,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  // These counters are reset to their starting values when the constructor
  // returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. The scopes are placement-new'ed here and
  // destroyed explicitly below; the allocator only provides their memory.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op, interfaces);

  // The worklist is used as a stack, so regions are visited depth-first.
  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes(needless)
    // until the parent scope. Destroying a scope pops it (and the names it
    // inserted) from `usedNames`, which makes its enclosing scope current.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      // Guard against popping past the root scope without finding
      // `parentScope`.
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region, interfaces);

    // Queue nested regions, capturing the state right after this region.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes, including topLevelNamesScope, so that
  // `usedNames` is left without an active scope.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
903 
904 void SSANameState::printValueID(Value value, bool printResultNo,
905                                 raw_ostream &stream) const {
906   if (!value) {
907     stream << "<<NULL>>";
908     return;
909   }
910 
911   Optional<int> resultNo;
912   auto lookupValue = value;
913 
914   // If this is an operation result, collect the head lookup value of the result
915   // group and the result number of 'result' within that group.
916   if (OpResult result = value.dyn_cast<OpResult>())
917     getResultIDAndNumber(result, lookupValue, resultNo);
918 
919   auto it = valueIDs.find(lookupValue);
920   if (it == valueIDs.end()) {
921     stream << "<<UNKNOWN SSA VALUE>>";
922     return;
923   }
924 
925   stream << '%';
926   if (it->second != NameSentinel) {
927     stream << it->second;
928   } else {
929     auto nameIt = valueNames.find(lookupValue);
930     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
931     stream << nameIt->second;
932   }
933 
934   if (resultNo.hasValue() && printResultNo)
935     stream << '#' << resultNo;
936 }
937 
938 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
939   auto it = opResultGroups.find(op);
940   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
941 }
942 
943 unsigned SSANameState::getBlockID(Block *block) {
944   auto it = blockIDs.find(block);
945   return it != blockIDs.end() ? it->second : NameSentinel;
946 }
947 
948 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
949   assert(!region.empty() && "cannot shadow arguments of an empty region");
950   assert(region.getNumArguments() == namesToUse.size() &&
951          "incorrect number of names passed in");
952   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
953          "only KnownIsolatedFromAbove ops can shadow names");
954 
955   SmallVector<char, 16> nameStr;
956   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
957     auto nameToUse = namesToUse[i];
958     if (nameToUse == nullptr)
959       continue;
960     auto nameToReplace = region.getArgument(i);
961 
962     nameStr.clear();
963     llvm::raw_svector_ostream nameStream(nameStr);
964     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
965 
966     // Entry block arguments should already have a pretty "arg" name.
967     assert(valueIDs[nameToReplace] == NameSentinel);
968 
969     // Use the name without the leading %.
970     auto name = StringRef(nameStream.str()).drop_front();
971 
972     // Overwrite the name.
973     valueNames[nameToReplace] = name.copy(usedNameAllocator);
974   }
975 }
976 
977 void SSANameState::numberValuesInRegion(
978     Region &region,
979     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
980   // Number the values within this region in a breadth-first order.
981   unsigned nextBlockID = 0;
982   for (auto &block : region) {
983     // Each block gets a unique ID, and all of the operations within it get
984     // numbered as well.
985     blockIDs[&block] = nextBlockID++;
986     numberValuesInBlock(block, interfaces);
987   }
988 }
989 
/// Number the arguments of `block`, then the results of every operation in it.
/// Nested regions are not visited here.
///
/// For an entry block, the parent op's dialect interface first gets a chance
/// to name the arguments. Every argument still unnumbered afterwards requests
/// the name "arg<N>", where N comes from the shared `nextArgumentID` counter.
/// Arguments of non-entry blocks receive plain numeric IDs.
void SSANameState::numberValuesInBlock(
    Block &block,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  // Callback handed to the dialect interface to assign custom argument names.
  auto setArgNameFn = [&](Value arg, StringRef name) {
    assert(!valueIDs.count(arg) && "arg numbered multiple times");
    assert(arg.cast<BlockArgument>().getOwner() == &block &&
           "arg not defined in 'block'");
    setValueName(arg, name);
  };

  bool isEntryBlock = block.isEntryBlock();
  if (isEntryBlock) {
    if (auto *op = block.getParentOp()) {
      if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
        asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
    }
  }

  // Number the block arguments. We give entry block arguments a special name
  // 'arg'.
  // For entry blocks, the buffer is truncated back to "arg" before each new
  // counter value is appended through the `specialName` stream. For other
  // blocks it stays empty, which setValueName treats as "use a numeric ID".
  SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  for (auto arg : block.getArguments()) {
    // Skip arguments the dialect interface already named.
    if (valueIDs.count(arg))
      continue;
    if (isEntryBlock) {
      specialNameBuffer.resize(strlen("arg"));
      specialName << nextArgumentID++;
    }
    setValueName(arg, specialName.str());
  }

  // Number the operations in this block.
  for (auto &op : block)
    numberValuesInOp(op, interfaces);
}
1026 
/// Assign names or IDs to the results of `op`. Nested regions are not visited
/// here.
///
/// Custom names come from the op's OpAsmOpInterface when it implements one,
/// and otherwise from its dialect's OpAsmDialectInterface. Every result passed
/// to the naming callback starts a result group, and result 0 always starts
/// one. If result 0 was not named, it receives the next numeric ID. Results
/// that start no group get no entry of their own and are printed relative to
/// their group head.
void SSANameState::numberValuesInOp(
    Operation &op,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  unsigned numResults = op.getNumResults();
  if (numResults == 0)
    return;
  Value resultBegin = op.getResult(0);

  // Function used to set the special result names for the operation.
  // Result 0 always anchors the first group.
  SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
  auto setResultNameFn = [&](Value result, StringRef name) {
    assert(!valueIDs.count(result) && "result numbered multiple times");
    assert(result.getDefiningOp() == &op && "result not defined by 'op'");
    setValueName(result, name);

    // Record the result number for groups not anchored at 0.
    if (int resultNo = result.cast<OpResult>().getResultNumber())
      resultGroups.push_back(resultNo);
  };
  // An op-level interface takes precedence over the dialect-level one.
  if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
    asmInterface.getAsmResultNames(setResultNameFn);
  else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
    asmInterface->getAsmResultNames(&op, setResultNameFn);

  // If the first result wasn't numbered, give it a default number.
  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
    ++nextValueID;

  // If this operation has multiple result groups, mark it. The groups are
  // sorted so getResultIDAndNumber can binary-search them.
  if (resultGroups.size() != 1) {
    llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
    opResultGroups.try_emplace(&op, std::move(resultGroups));
  }
}
1061 
/// Map `result` to the value under which its name is recorded.
///
/// On return, `lookupValue` is the first result of the group containing
/// `result`, and `lookupResultNo` is the index of `result` within that group.
/// Both outputs are left untouched when the owner has a single result.
/// `lookupResultNo` is also left untouched when the containing group has
/// exactly one result.
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
                                        Optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // If this operation has multiple result groups, we will need to find the
  // one corresponding to this result.
  auto resultGroupIt = opResultGroups.find(owner);
  if (resultGroupIt == opResultGroups.end()) {
    // No groups were registered, so all results form a single group anchored
    // at result 0.
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // Find the correct index using a binary search, as the groups are ordered.
  // Registered groups always start with 0, so `it` is never the first element
  // and `std::prev(it)` below is valid.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  auto it = llvm::upper_bound(resultGroups, resultNo);
  int groupResultNo = 0, groupSize = 0;

  // If no group starts after `resultNo`, the result belongs to the last group,
  // which extends to the final result of the operation.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    // Otherwise, the previous element is the lookup.
    groupResultNo = *std::prev(it);
    groupSize = *it - groupResultNo;
  }

  // We only record the result number for a group of size greater than 1.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(groupResultNo);
}
1099 
1100 void SSANameState::setValueName(Value value, StringRef name) {
1101   // If the name is empty, the value uses the default numbering.
1102   if (name.empty()) {
1103     valueIDs[value] = nextValueID++;
1104     return;
1105   }
1106 
1107   valueIDs[value] = NameSentinel;
1108   valueNames[value] = uniqueValueName(name);
1109 }
1110 
1111 StringRef SSANameState::uniqueValueName(StringRef name) {
1112   SmallString<16> tmpBuffer;
1113   name = sanitizeIdentifier(name, tmpBuffer);
1114 
1115   // Check to see if this name is already unique.
1116   if (!usedNames.count(name)) {
1117     name = name.copy(usedNameAllocator);
1118   } else {
1119     // Otherwise, we had a conflict - probe until we find a unique name. This
1120     // is guaranteed to terminate (and usually in a single iteration) because it
1121     // generates new names by incrementing nextConflictID.
1122     SmallString<64> probeName(name);
1123     probeName.push_back('_');
1124     while (true) {
1125       probeName += llvm::utostr(nextConflictID++);
1126       if (!usedNames.count(probeName)) {
1127         name = StringRef(probeName).copy(usedNameAllocator);
1128         break;
1129       }
1130       probeName.resize(name.size() + 1);
1131     }
1132   }
1133 
1134   usedNames.insert(name, char());
1135   return name;
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // AsmState
1140 //===----------------------------------------------------------------------===//
1141 
namespace mlir {
namespace detail {
/// Implementation storage for AsmState. It holds the OpAsm dialect interface
/// collection, the alias state and the SSA name state for one operation, plus
/// an optional location map pointer that this class never frees.
class AsmStateImpl {
public:
  /// SSA names for `op` are computed here, during construction of
  /// `nameState`. This relies on `interfaces` being declared (and therefore
  /// constructed) before `nameState`.
  explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, interfaces),
        locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. Does nothing when no location map was given.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  // NOTE: member order matters; `nameState` is constructed from `interfaces`.

  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases. It is filled in by
  /// initializeAliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// An optional location map to be populated.
  AsmState::LocationMap *locationMap;
};
} // end namespace detail
} // end namespace mlir
1189 
1190 AsmState::AsmState(Operation *op, LocationMap *locationMap)
1191     : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}
1192 AsmState::~AsmState() {}
1193 
1194 //===----------------------------------------------------------------------===//
1195 // ModulePrinter
1196 //===----------------------------------------------------------------------===//
1197 
namespace {
/// Prints attributes, types, locations, affine maps and integer sets to a
/// stream. When an AsmStateImpl is supplied, the attribute and location
/// aliases it records are used where available.
class ModulePrinter {
public:
  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
                AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer that shares the stream, flags and state of `printer`.
  /// The new-line counter is not copied; the new printer starts with a
  /// default-constructed one.
  explicit ModulePrinter(ModulePrinter &printer)
      : os(printer.os), printerFlags(printer.printerFlags),
        state(printer.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print each element of `c` with `each_fn`, separated by ", ".
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
    llvm::interleaveComma(c, os, each_fn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  void printType(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  void printTrailingLocation(Location loc, bool allowAlias = true);
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense integer or floating-point elements attribute. If 'allowHex'
  /// is true, a hex string is used instead of individual elements when the
  /// elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module.
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // end anonymous namespace
1292 
1293 void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
1294   // Check to see if we are printing debug information.
1295   if (!printerFlags.shouldPrintDebugInfo())
1296     return;
1297 
1298   os << " ";
1299   printLocation(loc, /*allowAlias=*/allowAlias);
1300 }
1301 
1302 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
1303   TypeSwitch<LocationAttr>(loc)
1304       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1305         printLocationInternal(loc.getFallbackLocation(), pretty);
1306       })
1307       .Case<UnknownLoc>([&](UnknownLoc loc) {
1308         if (pretty)
1309           os << "[unknown]";
1310         else
1311           os << "unknown";
1312       })
1313       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1314         if (pretty) {
1315           os << loc.getFilename();
1316         } else {
1317           os << "\"";
1318           printEscapedString(loc.getFilename(), os);
1319           os << "\"";
1320         }
1321         os << ':' << loc.getLine() << ':' << loc.getColumn();
1322       })
1323       .Case<NameLoc>([&](NameLoc loc) {
1324         os << '\"';
1325         printEscapedString(loc.getName(), os);
1326         os << '\"';
1327 
1328         // Print the child if it isn't unknown.
1329         auto childLoc = loc.getChildLoc();
1330         if (!childLoc.isa<UnknownLoc>()) {
1331           os << '(';
1332           printLocationInternal(childLoc, pretty);
1333           os << ')';
1334         }
1335       })
1336       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1337         Location caller = loc.getCaller();
1338         Location callee = loc.getCallee();
1339         if (!pretty)
1340           os << "callsite(";
1341         printLocationInternal(callee, pretty);
1342         if (pretty) {
1343           if (callee.isa<NameLoc>()) {
1344             if (caller.isa<FileLineColLoc>()) {
1345               os << " at ";
1346             } else {
1347               os << newLine << " at ";
1348             }
1349           } else {
1350             os << newLine << " at ";
1351           }
1352         } else {
1353           os << " at ";
1354         }
1355         printLocationInternal(caller, pretty);
1356         if (!pretty)
1357           os << ")";
1358       })
1359       .Case<FusedLoc>([&](FusedLoc loc) {
1360         if (!pretty)
1361           os << "fused";
1362         if (Attribute metadata = loc.getMetadata())
1363           os << '<' << metadata << '>';
1364         os << '[';
1365         interleave(
1366             loc.getLocations(),
1367             [&](Location loc) { printLocationInternal(loc, pretty); },
1368             [&]() { os << ", "; });
1369         os << ']';
1370       });
1371 }
1372 
/// Print a floating point value in a way that the parser will be able to
/// round-trip losslessly.
///
/// Three forms are tried, in order:
///   1. A short form produced with a format precision of 6. It is used only
///      if parsing it back gives a bitwise-identical value.
///   2. APFloat's default string form, used if it contains a '.'.
///   3. The raw bit pattern as a hexadecimal literal, which includes the sign
///      bit. This form is always used for infinity and NaN.
/// NOTE(review): form 2 is not re-parsed to check equality; it is assumed to
/// be exact. Confirm this against APFloat::toString's default-precision
/// behavior.
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
  // We would like to output the FP constant value in exponential notation,
  // but we cannot do this if doing so will lose precision.  Check here to
  // make sure that we only output it in exponential format if we can parse
  // the value back and get the same value.
  bool isInf = apValue.isInfinity();
  bool isNaN = apValue.isNaN();
  if (!isInf && !isNaN) {
    SmallString<128> strValue;
    apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
                     /*TruncateZero=*/false);

    // Check to make sure that the stringized number is not some string like
    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
    // that the string matches the "[-+]?[0-9]" regex.
    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
            ((strValue[0] == '-' || strValue[0] == '+') &&
             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
           "[-+]?[0-9] regex does not match!");

    // Parse back the stringized version and check that the value is equal
    // (i.e., there is no precision loss).
    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
      os << strValue;
      return;
    }

    // If it is not, use the default format of APFloat instead of the
    // exponential notation.
    strValue.clear();
    apValue.toString(strValue);

    // Make sure that we can parse the default form as a float. Without a '.'
    // the text would read back as an integer, so fall through to hex.
    if (StringRef(strValue).contains('.')) {
      os << strValue;
      return;
    }
  }

  // Print special values in hexadecimal format. The sign bit should be included
  // in the literal.
  SmallVector<char, 16> str;
  APInt apInt = apValue.bitcastToAPInt();
  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
                 /*formatAsCLiteral=*/true);
  os << str;
}
1422 
1423 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
1424   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1425     return printLocationInternal(loc, /*pretty=*/true);
1426 
1427   os << "loc(";
1428   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1429     printLocationInternal(loc);
1430   os << ')';
1431 }
1432 
1433 /// Returns true if the given dialect symbol data is simple enough to print in
1434 /// the pretty form, i.e. without the enclosing "".
1435 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1436   // The name must start with an identifier.
1437   if (symName.empty() || !isalpha(symName.front()))
1438     return false;
1439 
1440   // Ignore all the characters that are valid in an identifier in the symbol
1441   // name.
1442   symName = symName.drop_while(
1443       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1444   if (symName.empty())
1445     return true;
1446 
1447   // If we got to an unexpected character, then it must be a <>.  Check those
1448   // recursively.
1449   if (symName.front() != '<' || symName.back() != '>')
1450     return false;
1451 
1452   SmallVector<char, 8> nestedPunctuation;
1453   do {
1454     // If we ran out of characters, then we had a punctuation mismatch.
1455     if (symName.empty())
1456       return false;
1457 
1458     auto c = symName.front();
1459     symName = symName.drop_front();
1460 
1461     switch (c) {
1462     // We never allow null characters. This is an EOF indicator for the lexer
1463     // which we could handle, but isn't important for any known dialect.
1464     case '\0':
1465       return false;
1466     case '<':
1467     case '[':
1468     case '(':
1469     case '{':
1470       nestedPunctuation.push_back(c);
1471       continue;
1472     case '-':
1473       // Treat `->` as a special token.
1474       if (!symName.empty() && symName.front() == '>') {
1475         symName = symName.drop_front();
1476         continue;
1477       }
1478       break;
1479     // Reject types with mismatched brackets.
1480     case '>':
1481       if (nestedPunctuation.pop_back_val() != '<')
1482         return false;
1483       break;
1484     case ']':
1485       if (nestedPunctuation.pop_back_val() != '[')
1486         return false;
1487       break;
1488     case ')':
1489       if (nestedPunctuation.pop_back_val() != '(')
1490         return false;
1491       break;
1492     case '}':
1493       if (nestedPunctuation.pop_back_val() != '{')
1494         return false;
1495       break;
1496     default:
1497       continue;
1498     }
1499 
1500     // We're done when the punctuation is fully matched.
1501   } while (!nestedPunctuation.empty());
1502 
1503   // If there were extra characters, then we failed.
1504   return symName.empty();
1505 }
1506 
1507 /// Print the given dialect symbol to the stream.
1508 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1509                                StringRef dialectName, StringRef symString) {
1510   os << symPrefix << dialectName;
1511 
1512   // If this symbol name is simple enough, print it directly in pretty form,
1513   // otherwise, we print it as an escaped string.
1514   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1515     os << '.' << symString;
1516     return;
1517   }
1518 
1519   // TODO: escape the symbol name, it could contain " characters.
1520   os << "<\"" << symString << "\">";
1521 }
1522 
1523 /// Returns true if the given string can be represented as a bare identifier.
1524 static bool isBareIdentifier(StringRef name) {
1525   assert(!name.empty() && "invalid name");
1526 
1527   // By making this unsigned, the value passed in to isalnum will always be
1528   // in the range 0-255. This is important when building with MSVC because
1529   // its implementation will assert. This situation can arise when dealing
1530   // with UTF-8 multibyte characters.
1531   unsigned char firstChar = static_cast<unsigned char>(name[0]);
1532   if (!isalpha(firstChar) && firstChar != '_')
1533     return false;
1534   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1535     return isalnum(c) || c == '_' || c == '$' || c == '.';
1536   });
1537 }
1538 
1539 /// Print the given string as a symbol reference. A symbol reference is
1540 /// represented as a string prefixed with '@'. The reference is surrounded with
1541 /// ""'s and escaped if it has any special or non-printable characters in it.
1542 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1543   assert(!symbolRef.empty() && "expected valid symbol reference");
1544 
1545   // If the symbol can be represented as a bare identifier, write it directly.
1546   if (isBareIdentifier(symbolRef)) {
1547     os << '@' << symbolRef;
1548     return;
1549   }
1550 
1551   // Otherwise, output the reference wrapped in quotes with proper escaping.
1552   os << "@\"";
1553   printEscapedString(symbolRef, os);
1554   os << '"';
1555 }
1556 
1557 // Print out a valid ElementsAttr that is succinct and can represent any
1558 // potential shape/type, for use when eliding a large ElementsAttr.
1559 //
1560 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1561 // hopefully alert readers to the fact that this has been elided.
1562 //
1563 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1564 // accept the string "elided". The first string must be a registered dialect
1565 // name and the latter must be a hex constant.
1566 static void printElidedElementsAttr(raw_ostream &os) {
1567   os << R"(opaque<"_", "0xDEADBEEF">)";
1568 }
1569 
1570 void ModulePrinter::printAttribute(Attribute attr,
1571                                    AttrTypeElision typeElision) {
1572   if (!attr) {
1573     os << "<<NULL ATTRIBUTE>>";
1574     return;
1575   }
1576 
1577   // Try to print an alias for this attribute.
1578   if (state && succeeded(state->getAliasState().getAlias(attr, os)))
1579     return;
1580 
1581   auto attrType = attr.getType();
1582   if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1583     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1584                        opaqueAttr.getAttrData());
1585   } else if (attr.isa<UnitAttr>()) {
1586     os << "unit";
1587     return;
1588   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1589     os << '{';
1590     interleaveComma(dictAttr.getValue(),
1591                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1592     os << '}';
1593 
1594   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1595     if (attrType.isSignlessInteger(1)) {
1596       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1597 
1598       // Boolean integer attributes always elides the type.
1599       return;
1600     }
1601 
1602     // Only print attributes as unsigned if they are explicitly unsigned or are
1603     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1604     // values print as signed.
1605     bool isUnsigned =
1606         attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1607     intAttr.getValue().print(os, !isUnsigned);
1608 
1609     // IntegerAttr elides the type if I64.
1610     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1611       return;
1612 
1613   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1614     printFloatValue(floatAttr.getValue(), os);
1615 
1616     // FloatAttr elides the type if F64.
1617     if (typeElision == AttrTypeElision::May && attrType.isF64())
1618       return;
1619 
1620   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1621     os << '"';
1622     printEscapedString(strAttr.getValue(), os);
1623     os << '"';
1624 
1625   } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1626     os << '[';
1627     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1628       printAttribute(attr, AttrTypeElision::May);
1629     });
1630     os << ']';
1631 
1632   } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1633     os << "affine_map<";
1634     affineMapAttr.getValue().print(os);
1635     os << '>';
1636 
1637     // AffineMap always elides the type.
1638     return;
1639 
1640   } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1641     os << "affine_set<";
1642     integerSetAttr.getValue().print(os);
1643     os << '>';
1644 
1645     // IntegerSet always elides the type.
1646     return;
1647 
1648   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1649     printType(typeAttr.getValue());
1650 
1651   } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1652     printSymbolReference(refAttr.getRootReference(), os);
1653     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1654       os << "::";
1655       printSymbolReference(nestedRef.getValue(), os);
1656     }
1657 
1658   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1659     if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1660       printElidedElementsAttr(os);
1661     } else {
1662       os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
1663          << llvm::toHex(opaqueAttr.getValue()) << "\">";
1664     }
1665 
1666   } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1667     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1668       printElidedElementsAttr(os);
1669     } else {
1670       os << "dense<";
1671       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1672       os << '>';
1673     }
1674 
1675   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1676     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1677       printElidedElementsAttr(os);
1678     } else {
1679       os << "dense<";
1680       printDenseStringElementsAttr(strEltAttr);
1681       os << '>';
1682     }
1683 
1684   } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1685     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1686         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1687       printElidedElementsAttr(os);
1688     } else {
1689       os << "sparse<";
1690       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1691       if (indices.getNumElements() != 0) {
1692         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1693         os << ", ";
1694         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1695       }
1696       os << '>';
1697     }
1698 
1699   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1700     printLocation(locAttr);
1701 
1702   } else {
1703     return printDialectAttribute(attr);
1704   }
1705 
1706   // Don't print the type if we must elide it, or if it is a None type.
1707   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1708     os << " : ";
1709     printType(attrType);
1710   }
1711 }
1712 
1713 /// Print the integer element of a DenseElementsAttr.
1714 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1715                                  bool isSigned) {
1716   if (value.getBitWidth() == 1)
1717     os << (value.getBoolValue() ? "true" : "false");
1718   else
1719     value.print(os, isSigned);
1720 }
1721 
/// Print the elements of a dense elements attribute of shape `type` as nested,
/// bracketed, comma-separated lists (e.g. a 2x2 shape prints as
/// `[[e0, e1], [e2, e3]]`). `printEltFn(i)` is invoked to print the element at
/// flat (row-major) index `i`.
///
/// - If `isSplat` is set, only element 0 is printed, with no brackets.
/// - If the shape has zero elements, nothing is printed.
///
/// NOTE(review): the non-splat path indexes `counter[rank - 1]`, so it assumes
/// rank >= 1; 0-d attributes are presumably always reported as splat by the
/// callers. The element index is an `unsigned`, so element counts above
/// UINT_MAX would be truncated.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed by the previous bump (all `rank` of them
    // for the very first element).
    while (openBrackets++ < rank)
      os << '[';
    // The post-increment in the loop condition above overshoots by one, so
    // reset the count to exactly `rank` open brackets.
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets remain open after the final element.
  while (openBrackets-- > 0)
    os << ']';
}
1771 
1772 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1773                                            bool allowHex) {
1774   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1775     return printDenseStringElementsAttr(stringAttr);
1776 
1777   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1778                                 allowHex);
1779 }
1780 
1781 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1782                                                   bool allowHex) {
1783   auto type = attr.getType();
1784   auto elementType = type.getElementType();
1785 
1786   // Check to see if we should format this attribute as a hex string.
1787   auto numElements = type.getNumElements();
1788   if (!attr.isSplat() && allowHex &&
1789       shouldPrintElementsAttrWithHex(numElements)) {
1790     ArrayRef<char> rawData = attr.getRawData();
1791     if (llvm::support::endian::system_endianness() ==
1792         llvm::support::endianness::big) {
1793       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1794       // machines. It is converted here to print in LE format.
1795       SmallVector<char, 64> outDataVec(rawData.size());
1796       MutableArrayRef<char> convRawData(outDataVec);
1797       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1798           rawData, convRawData, type);
1799       os << '"' << "0x"
1800          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1801          << "\"";
1802     } else {
1803       os << '"' << "0x"
1804          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1805     }
1806 
1807     return;
1808   }
1809 
1810   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1811     Type complexElementType = complexTy.getElementType();
1812     // Note: The if and else below had a common lambda function which invoked
1813     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1814     // and hence was replaced.
1815     if (complexElementType.isa<IntegerType>()) {
1816       bool isSigned = !complexElementType.isUnsignedInteger();
1817       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1818         auto complexValue = *(attr.getComplexIntValues().begin() + index);
1819         os << "(";
1820         printDenseIntElement(complexValue.real(), os, isSigned);
1821         os << ",";
1822         printDenseIntElement(complexValue.imag(), os, isSigned);
1823         os << ")";
1824       });
1825     } else {
1826       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1827         auto complexValue = *(attr.getComplexFloatValues().begin() + index);
1828         os << "(";
1829         printFloatValue(complexValue.real(), os);
1830         os << ",";
1831         printFloatValue(complexValue.imag(), os);
1832         os << ")";
1833       });
1834     }
1835   } else if (elementType.isIntOrIndex()) {
1836     bool isSigned = !elementType.isUnsignedInteger();
1837     auto intValues = attr.getIntValues();
1838     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1839       printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1840     });
1841   } else {
1842     assert(elementType.isa<FloatType>() && "unexpected element type");
1843     auto floatValues = attr.getFloatValues();
1844     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1845       printFloatValue(*(floatValues.begin() + index), os);
1846     });
1847   }
1848 }
1849 
1850 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1851   ArrayRef<StringRef> data = attr.getRawStringData();
1852   auto printFn = [&](unsigned index) {
1853     os << "\"";
1854     printEscapedString(data[index], os);
1855     os << "\"";
1856   };
1857   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1858 }
1859 
1860 void ModulePrinter::printType(Type type) {
1861   if (!type) {
1862     os << "<<NULL TYPE>>";
1863     return;
1864   }
1865 
1866   // Try to print an alias for this type.
1867   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1868     return;
1869 
1870   TypeSwitch<Type>(type)
1871       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1872         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1873                            opaqueTy.getTypeData());
1874       })
1875       .Case<IndexType>([&](Type) { os << "index"; })
1876       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1877       .Case<Float16Type>([&](Type) { os << "f16"; })
1878       .Case<Float32Type>([&](Type) { os << "f32"; })
1879       .Case<Float64Type>([&](Type) { os << "f64"; })
1880       .Case<Float80Type>([&](Type) { os << "f80"; })
1881       .Case<Float128Type>([&](Type) { os << "f128"; })
1882       .Case<IntegerType>([&](IntegerType integerTy) {
1883         if (integerTy.isSigned())
1884           os << 's';
1885         else if (integerTy.isUnsigned())
1886           os << 'u';
1887         os << 'i' << integerTy.getWidth();
1888       })
1889       .Case<FunctionType>([&](FunctionType funcTy) {
1890         os << '(';
1891         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1892         os << ") -> ";
1893         ArrayRef<Type> results = funcTy.getResults();
1894         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1895           os << results[0];
1896         } else {
1897           os << '(';
1898           interleaveComma(results, [&](Type ty) { printType(ty); });
1899           os << ')';
1900         }
1901       })
1902       .Case<VectorType>([&](VectorType vectorTy) {
1903         os << "vector<";
1904         for (int64_t dim : vectorTy.getShape())
1905           os << dim << 'x';
1906         os << vectorTy.getElementType() << '>';
1907       })
1908       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1909         os << "tensor<";
1910         for (int64_t dim : tensorTy.getShape()) {
1911           if (ShapedType::isDynamic(dim))
1912             os << '?';
1913           else
1914             os << dim;
1915           os << 'x';
1916         }
1917         os << tensorTy.getElementType();
1918         // Only print the encoding attribute value if set.
1919         if (tensorTy.getEncoding()) {
1920           os << ", ";
1921           printAttribute(tensorTy.getEncoding());
1922         }
1923         os << '>';
1924       })
1925       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1926         os << "tensor<*x";
1927         printType(tensorTy.getElementType());
1928         os << '>';
1929       })
1930       .Case<MemRefType>([&](MemRefType memrefTy) {
1931         os << "memref<";
1932         for (int64_t dim : memrefTy.getShape()) {
1933           if (ShapedType::isDynamic(dim))
1934             os << '?';
1935           else
1936             os << dim;
1937           os << 'x';
1938         }
1939         printType(memrefTy.getElementType());
1940         for (auto map : memrefTy.getAffineMaps()) {
1941           os << ", ";
1942           printAttribute(AffineMapAttr::get(map));
1943         }
1944         // Only print the memory space if it is the non-default one.
1945         if (memrefTy.getMemorySpace()) {
1946           os << ", ";
1947           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1948         }
1949         os << '>';
1950       })
1951       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1952         os << "memref<*x";
1953         printType(memrefTy.getElementType());
1954         // Only print the memory space if it is the non-default one.
1955         if (memrefTy.getMemorySpace()) {
1956           os << ", ";
1957           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1958         }
1959         os << '>';
1960       })
1961       .Case<ComplexType>([&](ComplexType complexTy) {
1962         os << "complex<";
1963         printType(complexTy.getElementType());
1964         os << '>';
1965       })
1966       .Case<TupleType>([&](TupleType tupleTy) {
1967         os << "tuple<";
1968         interleaveComma(tupleTy.getTypes(),
1969                         [&](Type type) { printType(type); });
1970         os << '>';
1971       })
1972       .Case<NoneType>([&](Type) { os << "none"; })
1973       .Default([&](Type type) { return printDialectType(type); });
1974 }
1975 
1976 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1977                                           ArrayRef<StringRef> elidedAttrs,
1978                                           bool withKeyword) {
1979   // If there are no attributes, then there is nothing to be done.
1980   if (attrs.empty())
1981     return;
1982 
1983   // Functor used to print a filtered attribute list.
1984   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
1985     // Print the 'attributes' keyword if necessary.
1986     if (withKeyword)
1987       os << " attributes";
1988 
1989     // Otherwise, print them all out in braces.
1990     os << " {";
1991     interleaveComma(filteredAttrs,
1992                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1993     os << '}';
1994   };
1995 
1996   // If no attributes are elided, we can directly print with no filtering.
1997   if (elidedAttrs.empty())
1998     return printFilteredAttributesFn(attrs);
1999 
2000   // Otherwise, filter out any attributes that shouldn't be included.
2001   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2002                                                 elidedAttrs.end());
2003   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2004     return !elidedAttrsSet.contains(attr.first.strref());
2005   });
2006   if (!filteredAttrs.empty())
2007     printFilteredAttributesFn(filteredAttrs);
2008 }
2009 
2010 void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
2011   if (isBareIdentifier(attr.first)) {
2012     os << attr.first;
2013   } else {
2014     os << '"';
2015     printEscapedString(attr.first.strref(), os);
2016     os << '"';
2017   }
2018 
2019   // Pretty printing elides the attribute value for unit attributes.
2020   if (attr.second.isa<UnitAttr>())
2021     return;
2022 
2023   os << " = ";
2024   printAttribute(attr.second);
2025 }
2026 
2027 //===----------------------------------------------------------------------===//
2028 // CustomDialectAsmPrinter
2029 //===----------------------------------------------------------------------===//
2030 
2031 namespace {
2032 /// This class provides the main specialization of the DialectAsmPrinter that is
2033 /// used to provide support for print attributes and types. This hooks allows
2034 /// for dialects to hook into the main ModulePrinter.
2035 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
2036 public:
2037   CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
2038   ~CustomDialectAsmPrinter() override {}
2039 
2040   raw_ostream &getStream() const override { return printer.getStream(); }
2041 
2042   /// Print the given attribute to the stream.
2043   void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
2044 
2045   /// Print the given floating point value in a stablized form.
2046   void printFloat(const APFloat &value) override {
2047     printFloatValue(value, getStream());
2048   }
2049 
2050   /// Print the given type to the stream.
2051   void printType(Type type) override { printer.printType(type); }
2052 
2053   /// The main module printer.
2054   ModulePrinter &printer;
2055 };
2056 } // end anonymous namespace
2057 
2058 void ModulePrinter::printDialectAttribute(Attribute attr) {
2059   auto &dialect = attr.getDialect();
2060 
2061   // Ask the dialect to serialize the attribute to a string.
2062   std::string attrName;
2063   {
2064     llvm::raw_string_ostream attrNameStr(attrName);
2065     ModulePrinter subPrinter(attrNameStr, printerFlags, state);
2066     CustomDialectAsmPrinter printer(subPrinter);
2067     dialect.printAttribute(attr, printer);
2068   }
2069   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2070 }
2071 
2072 void ModulePrinter::printDialectType(Type type) {
2073   auto &dialect = type.getDialect();
2074 
2075   // Ask the dialect to serialize the type to a string.
2076   std::string typeName;
2077   {
2078     llvm::raw_string_ostream typeNameStr(typeName);
2079     ModulePrinter subPrinter(typeNameStr, printerFlags, state);
2080     CustomDialectAsmPrinter printer(subPrinter);
2081     dialect.printType(type, printer);
2082   }
2083   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // Affine expressions and maps
2088 //===----------------------------------------------------------------------===//
2089 
2090 void ModulePrinter::printAffineExpr(
2091     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2092   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2093 }
2094 
/// Recursively print `expr`.
///
/// `enclosingTightness` is the binding strength of the surrounding context:
/// when it is `Strong`, binary expressions are wrapped in parentheses.
/// Identifiers are printed through `printValueName` when it is set (with
/// `isSymbol` telling dims and symbols apart), and as `dN` / `sN` otherwise.
/// Several additions are pretty-printed as subtractions:
///   - `a + b * -1`  prints as `a - b`
///   - `a + b * -c`  (c > 1) prints as `a - b * c`
///   - `a + -c`      (constant c > 0) prints as `a - c`
/// and `a * -1` prints as `-a`.
///
/// NOTE(review): the subtraction forms negate the constant with unary `-`; for
/// an INT64_MIN constant that negation would overflow. Confirm whether such
/// constants can reach here.
void ModulePrinter::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  // Every remaining kind is a binary operator; record its spelling and fall
  // through to the shared binary-expression printing below.
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    // Both operands are printed Strong so that nested binary operations get
    // parenthesized.
    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A subtracted sum must be parenthesized: `a - (b + c)`.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition: both operands print Weak, since `+` is the loosest
  // operator here.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2227 
2228 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
2229   printAffineExprInternal(expr, BindingStrength::Weak);
2230   isEq ? os << " == 0" : os << " >= 0";
2231 }
2232 
2233 void ModulePrinter::printAffineMap(AffineMap map) {
2234   // Dimension identifiers.
2235   os << '(';
2236   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2237     os << 'd' << i << ", ";
2238   if (map.getNumDims() >= 1)
2239     os << 'd' << map.getNumDims() - 1;
2240   os << ')';
2241 
2242   // Symbolic identifiers.
2243   if (map.getNumSymbols() != 0) {
2244     os << '[';
2245     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2246       os << 's' << i << ", ";
2247     if (map.getNumSymbols() >= 1)
2248       os << 's' << map.getNumSymbols() - 1;
2249     os << ']';
2250   }
2251 
2252   // Result affine expressions.
2253   os << " -> (";
2254   interleaveComma(map.getResults(),
2255                   [&](AffineExpr expr) { printAffineExpr(expr); });
2256   os << ')';
2257 }
2258 
2259 void ModulePrinter::printIntegerSet(IntegerSet set) {
2260   // Dimension identifiers.
2261   os << '(';
2262   for (unsigned i = 1; i < set.getNumDims(); ++i)
2263     os << 'd' << i - 1 << ", ";
2264   if (set.getNumDims() >= 1)
2265     os << 'd' << set.getNumDims() - 1;
2266   os << ')';
2267 
2268   // Symbolic identifiers.
2269   if (set.getNumSymbols() != 0) {
2270     os << '[';
2271     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2272       os << 's' << i << ", ";
2273     if (set.getNumSymbols() >= 1)
2274       os << 's' << set.getNumSymbols() - 1;
2275     os << ']';
2276   }
2277 
2278   // Print constraints.
2279   os << " : (";
2280   int numConstraints = set.getNumConstraints();
2281   for (int i = 1; i < numConstraints; ++i) {
2282     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2283     os << ", ";
2284   }
2285   if (numConstraints >= 1)
2286     printAffineConstraint(set.getConstraint(numConstraints - 1),
2287                           set.isEq(numConstraints - 1));
2288   os << ')';
2289 }
2290 
2291 //===----------------------------------------------------------------------===//
2292 // OperationPrinter
2293 //===----------------------------------------------------------------------===//
2294 
2295 namespace {
2296 /// This class contains the logic for printing operations, regions, and blocks.
2297 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
2298 public:
2299   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2300                             AsmStateImpl &state)
2301       : ModulePrinter(os, flags, &state) {}
2302 
2303   /// Print the given top-level operation.
2304   void printTopLevelOperation(Operation *op);
2305 
2306   /// Print the given operation with its indent and location.
2307   void print(Operation *op);
2308   /// Print the bare location, not including indentation/location/etc.
2309   void printOperation(Operation *op);
2310   /// Print the given operation in the generic form.
2311   void printGenericOp(Operation *op) override;
2312 
2313   /// Print the name of the given block.
2314   void printBlockName(Block *block);
2315 
2316   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2317   /// block are not printed. If 'printBlockTerminator' is false, the terminator
2318   /// operation of the block is not printed.
2319   void print(Block *block, bool printBlockArgs = true,
2320              bool printBlockTerminator = true);
2321 
2322   /// Print the ID of the given value, optionally with its result number.
2323   void printValueID(Value value, bool printResultNo = true,
2324                     raw_ostream *streamOverride = nullptr) const;
2325 
2326   //===--------------------------------------------------------------------===//
2327   // OpAsmPrinter methods
2328   //===--------------------------------------------------------------------===//
2329 
2330   /// Return the current stream of the printer.
2331   raw_ostream &getStream() const override { return os; }
2332 
2333   /// Print a newline and indent the printer to the start of the current
2334   /// operation.
2335   void printNewline() override {
2336     os << newLine;
2337     os.indent(currentIndent);
2338   }
2339 
2340   /// Print the given type.
2341   void printType(Type type) override { ModulePrinter::printType(type); }
2342 
2343   /// Print the given attribute.
2344   void printAttribute(Attribute attr) override {
2345     ModulePrinter::printAttribute(attr);
2346   }
2347 
2348   /// Print the given attribute without its type. The corresponding parser must
2349   /// provide a valid type for the attribute.
2350   void printAttributeWithoutType(Attribute attr) override {
2351     ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
2352   }
2353 
2354   /// Print a block argument in the usual format of:
2355   ///   %ssaName : type {attr1=42} loc("here")
2356   /// where location printing is controlled by the standard internal option.
2357   /// You may pass omitType=true to not print a type, and pass an empty
2358   /// attribute list if you don't care for attributes.
2359   void printRegionArgument(BlockArgument arg,
2360                            ArrayRef<NamedAttribute> argAttrs = {},
2361                            bool omitType = false) override;
2362 
2363   /// Print the ID for the given value.
2364   void printOperand(Value value) override { printValueID(value); }
2365   void printOperand(Value value, raw_ostream &os) override {
2366     printValueID(value, /*printResultNo=*/true, &os);
2367   }
2368 
2369   /// Print an optional attribute dictionary with a given set of elided values.
2370   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2371                              ArrayRef<StringRef> elidedAttrs = {}) override {
2372     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
2373   }
2374   void printOptionalAttrDictWithKeyword(
2375       ArrayRef<NamedAttribute> attrs,
2376       ArrayRef<StringRef> elidedAttrs = {}) override {
2377     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
2378                                          /*withKeyword=*/true);
2379   }
2380 
2381   /// Print the given successor.
2382   void printSuccessor(Block *successor) override;
2383 
2384   /// Print an operation successor with the operands used for the block
2385   /// arguments.
2386   void printSuccessorAndUseList(Block *successor,
2387                                 ValueRange succOperands) override;
2388 
2389   /// Print the given region.
2390   void printRegion(Region &region, bool printEntryBlockArgs,
2391                    bool printBlockTerminators, bool printEmptyBlock) override;
2392 
  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  /// The work is delegated entirely to the shared SSA name state, so the
  /// renaming is visible to anything else that prints through the same state.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }
2400 
2401   /// Print the given affine map with the symbol and dimension operands printed
2402   /// inline with the map.
2403   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2404                               ValueRange operands) override;
2405 
2406   /// Print the given affine expression with the symbol and dimension operands
2407   /// printed inline with the expression.
2408   void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
2409                                ValueRange symOperands) override;
2410 
  /// Print the given string as a symbol reference.
  /// The leading `::` selects the file-scope printSymbolReference helper
  /// rather than any class member of the same name.
  void printSymbolName(StringRef symbolRef) override {
    ::printSymbolReference(symbolRef, os);
  }
2415 
2416 private:
2417   /// The number of spaces used for indenting nested operations.
2418   const static unsigned indentWidth = 2;
2419 
2420   // This is the current indentation level for nested structures.
2421   unsigned currentIndent = 0;
2422 };
2423 } // end anonymous namespace
2424 
2425 void OperationPrinter::printTopLevelOperation(Operation *op) {
2426   // Output the aliases at the top level that can't be deferred.
2427   state->getAliasState().printNonDeferredAliases(os, newLine);
2428 
2429   // Print the module.
2430   print(op);
2431   os << newLine;
2432 
2433   // Output the aliases at the top level that can be deferred.
2434   state->getAliasState().printDeferredAliases(os, newLine);
2435 }
2436 
2437 /// Print a block argument in the usual format of:
2438 ///   %ssaName : type {attr1=42} loc("here")
2439 /// where location printing is controlled by the standard internal option.
2440 /// You may pass omitType=true to not print a type, and pass an empty
2441 /// attribute list if you don't care for attributes.
2442 void OperationPrinter::printRegionArgument(BlockArgument arg,
2443                                            ArrayRef<NamedAttribute> argAttrs,
2444                                            bool omitType) {
2445   printOperand(arg);
2446   if (!omitType) {
2447     os << ": ";
2448     printType(arg.getType());
2449   }
2450   printOptionalAttrDict(argAttrs);
2451   // TODO: We should allow location aliases on block arguments.
2452   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2453 }
2454 
2455 void OperationPrinter::print(Operation *op) {
2456   // Track the location of this operation.
2457   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2458 
2459   os.indent(currentIndent);
2460   printOperation(op);
2461   printTrailingLocation(op->getLoc());
2462 }
2463 
2464 void OperationPrinter::printOperation(Operation *op) {
2465   if (size_t numResults = op->getNumResults()) {
2466     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2467       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2468       if (resultCount > 1)
2469         os << ':' << resultCount;
2470     };
2471 
2472     // Check to see if this operation has multiple result groups.
2473     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2474     if (!resultGroups.empty()) {
2475       // Interleave the groups excluding the last one, this one will be handled
2476       // separately.
2477       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2478         printResultGroup(resultGroups[i],
2479                          resultGroups[i + 1] - resultGroups[i]);
2480       });
2481       os << ", ";
2482       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2483 
2484     } else {
2485       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2486     }
2487 
2488     os << " = ";
2489   }
2490 
2491   // If requested, always print the generic form.
2492   if (!printerFlags.shouldPrintGenericOpForm()) {
2493     // Check to see if this is a known operation.  If so, use the registered
2494     // custom printer hook.
2495     if (auto *opInfo = op->getAbstractOperation()) {
2496       opInfo->printAssembly(op, *this);
2497       return;
2498     }
2499     // Otherwise try to dispatch to the dialect, if available.
2500     if (Dialect *dialect = op->getDialect()) {
2501       if (succeeded(dialect->printOperation(op, *this)))
2502         return;
2503     }
2504   }
2505 
2506   // Otherwise print with the generic assembly form.
2507   printGenericOp(op);
2508 }
2509 
2510 void OperationPrinter::printGenericOp(Operation *op) {
2511   os << '"';
2512   printEscapedString(op->getName().getStringRef(), os);
2513   os << "\"(";
2514   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2515   os << ')';
2516 
2517   // For terminators, print the list of successors and their operands.
2518   if (op->getNumSuccessors() != 0) {
2519     os << '[';
2520     interleaveComma(op->getSuccessors(),
2521                     [&](Block *successor) { printBlockName(successor); });
2522     os << ']';
2523   }
2524 
2525   // Print regions.
2526   if (op->getNumRegions() != 0) {
2527     os << " (";
2528     interleaveComma(op->getRegions(), [&](Region &region) {
2529       printRegion(region, /*printEntryBlockArgs=*/true,
2530                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2531     });
2532     os << ')';
2533   }
2534 
2535   auto attrs = op->getAttrs();
2536   printOptionalAttrDict(attrs);
2537 
2538   // Print the type signature of the operation.
2539   os << " : ";
2540   printFunctionalType(op);
2541 }
2542 
2543 void OperationPrinter::printBlockName(Block *block) {
2544   auto id = state->getSSANameState().getBlockID(block);
2545   if (id != SSANameState::NameSentinel)
2546     os << "^bb" << id;
2547   else
2548     os << "^INVALIDBLOCK";
2549 }
2550 
/// Print `block` and its operations.
///
/// When `printBlockArgs` is set, the block header is emitted first: the block
/// label, the `(%arg: type loc, ...)` argument list when non-empty, a `:`,
/// and a trailing comment describing the block's predecessors. The operations
/// are then printed one per line, one indentation level deeper. If the block
/// ends in an IsTerminator operation and `printBlockTerminator` is false, that
/// final operation is not printed.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument arg) {
        printValueID(arg);
        os << ": ";
        printType(arg.getType());
        // TODO: We should allow location aliases on block arguments.
        printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    if (!block->getParent()) {
      os << "  // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      os << "  // no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << "  // pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in increasing numeric order, not in
      // whatever order the use-list is in, so gather and sort them.
      // NOTE(review): one entry is collected per element yielded by
      // getPredecessors(); if a block can appear there more than once, it is
      // listed and counted more than once below.
      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
      llvm::array_pod_sort(predIDs.begin(), predIDs.end());

      os << "  // " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
        printBlockName(pred.second);
      });
    }
    os << newLine;
  }

  // Operations are nested one indentation level below the block header.
  currentIndent += indentWidth;
  // Shorten the printed range by one only when the block really ends in a
  // terminator and the caller asked for terminators to be elided.
  bool hasTerminator =
      !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
  auto range = llvm::make_range(
      block->begin(),
      std::prev(block->end(),
                (!hasTerminator || printBlockTerminator) ? 0 : 1));
  for (auto &op : range) {
    print(&op);
    os << newLine;
  }
  currentIndent -= indentWidth;
}
2610 
2611 void OperationPrinter::printValueID(Value value, bool printResultNo,
2612                                     raw_ostream *streamOverride) const {
2613   state->getSSANameState().printValueID(value, printResultNo,
2614                                         streamOverride ? *streamOverride : os);
2615 }
2616 
/// Print a successor block; a successor is referenced purely by its block
/// label (see printBlockName).
void OperationPrinter::printSuccessor(Block *successor) {
  printBlockName(successor);
}
2620 
2621 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2622                                                 ValueRange succOperands) {
2623   printBlockName(successor);
2624   if (succOperands.empty())
2625     return;
2626 
2627   os << '(';
2628   interleaveComma(succOperands,
2629                   [this](Value operand) { printValueID(operand); });
2630   os << " : ";
2631   interleaveComma(succOperands,
2632                   [this](Value operand) { printType(operand.getType()); });
2633   os << ')';
2634 }
2635 
2636 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2637                                    bool printBlockTerminators,
2638                                    bool printEmptyBlock) {
2639   os << " {" << newLine;
2640   if (!region.empty()) {
2641     auto *entryBlock = &region.front();
2642     // Force printing the block header if printEmptyBlock is set and the block
2643     // is empty or if printEntryBlockArgs is set and there are arguments to
2644     // print.
2645     bool shouldAlwaysPrintBlockHeader =
2646         (printEmptyBlock && entryBlock->empty()) ||
2647         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2648     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2649     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2650       print(&b);
2651   }
2652   os.indent(currentIndent) << "}";
2653 }
2654 
2655 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2656                                               ValueRange operands) {
2657   AffineMap map = mapAttr.getValue();
2658   unsigned numDims = map.getNumDims();
2659   auto printValueName = [&](unsigned pos, bool isSymbol) {
2660     unsigned index = isSymbol ? numDims + pos : pos;
2661     assert(index < operands.size());
2662     if (isSymbol)
2663       os << "symbol(";
2664     printValueID(operands[index]);
2665     if (isSymbol)
2666       os << ')';
2667   };
2668 
2669   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2670     printAffineExpr(expr, printValueName);
2671   });
2672 }
2673 
2674 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2675                                                ValueRange dimOperands,
2676                                                ValueRange symOperands) {
2677   auto printValueName = [&](unsigned pos, bool isSymbol) {
2678     if (!isSymbol)
2679       return printValueID(dimOperands[pos]);
2680     os << "symbol(";
2681     printValueID(symOperands[pos]);
2682     os << ')';
2683   };
2684   printAffineExpr(expr, printValueName);
2685 }
2686 
2687 //===----------------------------------------------------------------------===//
2688 // print and dump methods
2689 //===----------------------------------------------------------------------===//
2690 
2691 void Attribute::print(raw_ostream &os) const {
2692   ModulePrinter(os).printAttribute(*this);
2693 }
2694 
2695 void Attribute::dump() const {
2696   print(llvm::errs());
2697   llvm::errs() << "\n";
2698 }
2699 
2700 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2701 
2702 void Type::dump() { print(llvm::errs()); }
2703 
2704 void AffineMap::dump() const {
2705   print(llvm::errs());
2706   llvm::errs() << "\n";
2707 }
2708 
2709 void IntegerSet::dump() const {
2710   print(llvm::errs());
2711   llvm::errs() << "\n";
2712 }
2713 
2714 void AffineExpr::print(raw_ostream &os) const {
2715   if (!expr) {
2716     os << "<<NULL AFFINE EXPR>>";
2717     return;
2718   }
2719   ModulePrinter(os).printAffineExpr(*this);
2720 }
2721 
2722 void AffineExpr::dump() const {
2723   print(llvm::errs());
2724   llvm::errs() << "\n";
2725 }
2726 
2727 void AffineMap::print(raw_ostream &os) const {
2728   if (!map) {
2729     os << "<<NULL AFFINE MAP>>";
2730     return;
2731   }
2732   ModulePrinter(os).printAffineMap(*this);
2733 }
2734 
2735 void IntegerSet::print(raw_ostream &os) const {
2736   ModulePrinter(os).printIntegerSet(*this);
2737 }
2738 
2739 void Value::print(raw_ostream &os) {
2740   if (auto *op = getDefiningOp())
2741     return op->print(os);
2742   // TODO: Improve BlockArgument print'ing.
2743   BlockArgument arg = this->cast<BlockArgument>();
2744   os << "<block argument> of type '" << arg.getType()
2745      << "' at index: " << arg.getArgNumber() << '\n';
2746 }
2747 void Value::print(raw_ostream &os, AsmState &state) {
2748   if (auto *op = getDefiningOp())
2749     return op->print(os, state);
2750 
2751   // TODO: Improve BlockArgument print'ing.
2752   BlockArgument arg = this->cast<BlockArgument>();
2753   os << "<block argument> of type '" << arg.getType()
2754      << "' at index: " << arg.getArgNumber() << '\n';
2755 }
2756 
2757 void Value::dump() {
2758   print(llvm::errs());
2759   llvm::errs() << "\n";
2760 }
2761 
2762 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2763   // TODO: This doesn't necessarily capture all potential cases.
2764   // Currently, region arguments can be shadowed when printing the main
2765   // operation. If the IR hasn't been printed, this will produce the old SSA
2766   // name and not the shadowed name.
2767   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2768                                                  os);
2769 }
2770 
2771 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2772   // If this is a top level operation, we also print aliases.
2773   if (!getParent() && !flags.shouldUseLocalScope()) {
2774     AsmState state(this);
2775     state.getImpl().initializeAliases(this, flags);
2776     print(os, state, flags);
2777     return;
2778   }
2779 
2780   // Find the operation to number from based upon the provided flags.
2781   Operation *op = this;
2782   bool shouldUseLocalScope = flags.shouldUseLocalScope();
2783   do {
2784     // If we are printing local scope, stop at the first operation that is
2785     // isolated from above.
2786     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2787       break;
2788 
2789     // Otherwise, traverse up to the next parent.
2790     Operation *parentOp = op->getParentOp();
2791     if (!parentOp)
2792       break;
2793     op = parentOp;
2794   } while (true);
2795 
2796   AsmState state(op);
2797   print(os, state, flags);
2798 }
2799 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2800   OperationPrinter printer(os, flags, state.getImpl());
2801   if (!getParent() && !flags.shouldUseLocalScope())
2802     printer.printTopLevelOperation(this);
2803   else
2804     printer.print(this);
2805 }
2806 
2807 void Operation::dump() {
2808   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2809   llvm::errs() << "\n";
2810 }
2811 
2812 void Block::print(raw_ostream &os) {
2813   Operation *parentOp = getParentOp();
2814   if (!parentOp) {
2815     os << "<<UNLINKED BLOCK>>\n";
2816     return;
2817   }
2818   // Get the top-level op.
2819   while (auto *nextOp = parentOp->getParentOp())
2820     parentOp = nextOp;
2821 
2822   AsmState state(parentOp);
2823   print(os, state);
2824 }
2825 void Block::print(raw_ostream &os, AsmState &state) {
2826   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2827 }
2828 
/// Print this block to stderr. No extra newline is appended here; unlike most
/// other dump() helpers in this file, it relies on the block printer's own
/// output.
void Block::dump() { print(llvm::errs()); }
2830 
2831 /// Print out the name of the block without printing its body.
2832 void Block::printAsOperand(raw_ostream &os, bool printType) {
2833   Operation *parentOp = getParentOp();
2834   if (!parentOp) {
2835     os << "<<UNLINKED BLOCK>>\n";
2836     return;
2837   }
2838   AsmState state(parentOp);
2839   printAsOperand(os, state);
2840 }
2841 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2842   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2843   printer.printBlockName(this);
2844 }
2845