1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "mlir/IR/Verifier.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/ScopedHashTable.h"
35 #include "llvm/ADT/SetVector.h"
36 #include "llvm/ADT/SmallString.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/StringSet.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Endian.h"
42 #include "llvm/Support/Regex.h"
43 #include "llvm/Support/SaveAndRestore.h"
44 #include "llvm/Support/Threading.h"
45 
46 #include <tuple>
47 
48 using namespace mlir;
49 using namespace mlir::detail;
50 
51 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
52 
53 void OperationName::dump() const { print(llvm::errs()); }
54 
55 //===--------------------------------------------------------------------===//
56 // AsmParser
57 //===--------------------------------------------------------------------===//
58 
59 AsmParser::~AsmParser() = default;
60 DialectAsmParser::~DialectAsmParser() = default;
61 OpAsmParser::~OpAsmParser() = default;
62 
63 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
64 
65 //===----------------------------------------------------------------------===//
66 // DialectAsmPrinter
67 //===----------------------------------------------------------------------===//
68 
/// Out-of-line defaulted destructor. NOTE(review): presumably kept out of line
/// so the class's vtable is emitted in this translation unit — confirm.
DialectAsmPrinter::~DialectAsmPrinter() = default;
70 
71 //===----------------------------------------------------------------------===//
72 // OpAsmPrinter
73 //===----------------------------------------------------------------------===//
74 
/// Out-of-line defaulted destructor. NOTE(review): presumably kept out of line
/// so the class's vtable is emitted in this translation unit — confirm.
OpAsmPrinter::~OpAsmPrinter() = default;
76 
77 void OpAsmPrinter::printFunctionalType(Operation *op) {
78   auto &os = getStream();
79   os << '(';
80   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
81     // Print the types of null values as <<NULL TYPE>>.
82     *this << (operand ? operand.getType() : Type());
83   });
84   os << ") -> ";
85 
86   // Print the result list.  We don't parenthesize single result types unless
87   // it is a function (avoiding a grammar ambiguity).
88   bool wrapped = op->getNumResults() != 1;
89   if (!wrapped && op->getResult(0).getType() &&
90       op->getResult(0).getType().isa<FunctionType>())
91     wrapped = true;
92 
93   if (wrapped)
94     os << '(';
95 
96   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
97     // Print the types of null values as <<NULL TYPE>>.
98     *this << (result ? result.getType() : Type());
99   });
100 
101   if (wrapped)
102     os << ')';
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Operation OpAsm interface.
107 //===----------------------------------------------------------------------===//
108 
109 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
110 #include "mlir/IR/OpAsmInterface.cpp.inc"
111 
112 //===----------------------------------------------------------------------===//
113 // OpPrintingFlags
114 //===----------------------------------------------------------------------===//
115 
116 namespace {
117 /// This struct contains command line options that can be used to initialize
118 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
119 /// for global command line options.
120 struct AsmPrinterOptions {
121   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
122       "mlir-print-elementsattrs-with-hex-if-larger",
123       llvm::cl::desc(
124           "Print DenseElementsAttrs with a hex string that have "
125           "more elements than the given upper limit (use -1 to disable)")};
126 
127   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
128       "mlir-elide-elementsattrs-if-larger",
129       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
130                      "more elements than the given upper limit")};
131 
132   llvm::cl::opt<bool> printDebugInfoOpt{
133       "mlir-print-debuginfo", llvm::cl::init(false),
134       llvm::cl::desc("Print debug info in MLIR output")};
135 
136   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
137       "mlir-pretty-debuginfo", llvm::cl::init(false),
138       llvm::cl::desc("Print pretty debug info in MLIR output")};
139 
140   // Use the generic op output form in the operation printer even if the custom
141   // form is defined.
142   llvm::cl::opt<bool> printGenericOpFormOpt{
143       "mlir-print-op-generic", llvm::cl::init(false),
144       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
145 
146   llvm::cl::opt<bool> assumeVerifiedOpt{
147       "mlir-print-assume-verified", llvm::cl::init(false),
148       llvm::cl::desc("Skip op verification when using custom printers"),
149       llvm::cl::Hidden};
150 
151   llvm::cl::opt<bool> printLocalScopeOpt{
152       "mlir-print-local-scope", llvm::cl::init(false),
153       llvm::cl::desc("Print with local scope and inline information (eliding "
154                      "aliases for attributes, types, and locations")};
155 };
156 } // namespace
157 
158 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
159 
160 /// Register a set of useful command-line options that can be used to configure
161 /// various flags within the AsmPrinter.
162 void mlir::registerAsmPrinterCLOptions() {
163   // Make sure that the options struct has been initialized.
164   *clOptions;
165 }
166 
167 /// Initialize the printing flags with default supplied by the cl::opts above.
168 OpPrintingFlags::OpPrintingFlags()
169     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
170       printGenericOpFormFlag(false), assumeVerifiedFlag(false),
171       printLocalScope(false) {
172   // Initialize based upon command line options, if they are available.
173   if (!clOptions.isConstructed())
174     return;
175   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
176     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
177   printDebugInfoFlag = clOptions->printDebugInfoOpt;
178   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
179   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
180   assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
181   printLocalScope = clOptions->printLocalScopeOpt;
182 }
183 
184 /// Enable the elision of large elements attributes, by printing a '...'
185 /// instead of the element data, when the number of elements is greater than
186 /// `largeElementLimit`. Note: The IR generated with this option is not
187 /// parsable.
188 OpPrintingFlags &
189 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
190   elementsAttrElementLimit = largeElementLimit;
191   return *this;
192 }
193 
194 /// Enable printing of debug information. If 'prettyForm' is set to true,
195 /// debug information is printed in a more readable 'pretty' form.
196 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
197   printDebugInfoFlag = true;
198   printDebugInfoPrettyFormFlag = prettyForm;
199   return *this;
200 }
201 
202 /// Always print operations in the generic form.
203 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
204   printGenericOpFormFlag = true;
205   return *this;
206 }
207 
208 /// Do not verify the operation when using custom operation printers.
209 OpPrintingFlags &OpPrintingFlags::assumeVerified() {
210   assumeVerifiedFlag = true;
211   return *this;
212 }
213 
214 /// Use local scope when printing the operation. This allows for using the
215 /// printer in a more localized and thread-safe setting, but may not necessarily
216 /// be identical of what the IR will look like when dumping the full module.
217 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
218   printLocalScope = true;
219   return *this;
220 }
221 
222 /// Return if the given ElementsAttr should be elided.
223 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
224   return elementsAttrElementLimit.hasValue() &&
225          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
226          !attr.isa<SplatElementsAttr>();
227 }
228 
229 /// Return the size limit for printing large ElementsAttr.
230 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
231   return elementsAttrElementLimit;
232 }
233 
234 /// Return if debug information should be printed.
235 bool OpPrintingFlags::shouldPrintDebugInfo() const {
236   return printDebugInfoFlag;
237 }
238 
239 /// Return if debug information should be printed in the pretty form.
240 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
241   return printDebugInfoPrettyFormFlag;
242 }
243 
244 /// Return if operations should be printed in the generic form.
245 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
246   return printGenericOpFormFlag;
247 }
248 
249 /// Return if operation verification should be skipped.
250 bool OpPrintingFlags::shouldAssumeVerified() const {
251   return assumeVerifiedFlag;
252 }
253 
254 /// Return if the printer should use local scope when dumping the IR.
255 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
256 
257 /// Returns true if an ElementsAttr with the given number of elements should be
258 /// printed with hex.
259 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
260   // Check to see if a command line option was provided for the limit.
261   if (clOptions.isConstructed()) {
262     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
263       // -1 is used to disable hex printing.
264       if (clOptions->printElementsAttrWithHexIfLarger == -1)
265         return false;
266       return numElements > clOptions->printElementsAttrWithHexIfLarger;
267     }
268   }
269 
270   // Otherwise, default to printing with hex if the number of elements is >100.
271   return numElements > 100;
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // NewLineCounter
276 //===----------------------------------------------------------------------===//
277 
278 namespace {
279 /// This class is a simple formatter that emits a new line when inputted into a
280 /// stream, that enables counting the number of newlines emitted. This class
281 /// should be used whenever emitting newlines in the printer.
282 struct NewLineCounter {
283   unsigned curLine = 1;
284 };
285 
286 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
287   ++newLine.curLine;
288   return os << '\n';
289 }
290 } // namespace
291 
292 //===----------------------------------------------------------------------===//
293 // AliasInitializer
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 /// This class represents a specific instance of a symbol Alias.
298 class SymbolAlias {
299 public:
300   SymbolAlias(StringRef name, bool isDeferrable)
301       : name(name), suffixIndex(0), hasSuffixIndex(false),
302         isDeferrable(isDeferrable) {}
303   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
304       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
305         isDeferrable(isDeferrable) {}
306 
307   /// Print this alias to the given stream.
308   void print(raw_ostream &os) const {
309     os << name;
310     if (hasSuffixIndex)
311       os << suffixIndex;
312   }
313 
314   /// Returns true if this alias supports deferred resolution when parsing.
315   bool canBeDeferred() const { return isDeferrable; }
316 
317 private:
318   /// The main name of the alias.
319   StringRef name;
320   /// The optional suffix index of the alias, if multiple aliases had the same
321   /// name.
322   uint32_t suffixIndex : 30;
323   /// A flag indicating whether this alias has a suffix or not.
324   bool hasSuffixIndex : 1;
325   /// A flag indicating whether this alias may be deferred or not.
326   bool isDeferrable : 1;
327 };
328 
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// Create an initializer that asks `interfaces` for alias names and copies
  /// accepted names into `aliasAllocator`.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` (honoring `printerFlags`) to collect aliasable attributes and
  /// types. Then fill `attrToAlias` and `typeToAlias` with the final aliases,
  /// ordered by alias name.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the symbol is appended to the entry for that name in
  /// `aliasToSymbol`. Returns success if an alias was generated, failure
  /// otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
  /// NOTE: `aliasBuffer` must stay declared before `aliasOS`. Members are
  /// constructed in declaration order, and the constructor builds `aliasOS` on
  /// top of `aliasBuffer`.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
384 
385 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
386 /// and merely collects the attributes and types that *would* be printed in a
387 /// normal print invocation so that we can generate proper aliases. This allows
388 /// for us to generate aliases only for the attributes and types that would be
389 /// in the output, and trims down unnecessary output.
390 class DummyAliasOperationPrinter : private OpAsmPrinter {
391 public:
392   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
393                                       AliasInitializer &initializer)
394       : printerFlags(printerFlags), initializer(initializer) {}
395 
396   /// Print the given operation.
397   void print(Operation *op) {
398     // Visit the operation location.
399     if (printerFlags.shouldPrintDebugInfo())
400       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
401 
402     // If requested, always print the generic form.
403     if (!printerFlags.shouldPrintGenericOpForm()) {
404       // Check to see if this is a known operation.  If so, use the registered
405       // custom printer hook.
406       if (auto opInfo = op->getRegisteredInfo()) {
407         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
408         return;
409       }
410     }
411 
412     // Otherwise print with the generic assembly form.
413     printGenericOp(op);
414   }
415 
416 private:
417   /// Print the given operation in the generic form.
418   void printGenericOp(Operation *op, bool printOpName = true) override {
419     // Consider nested operations for aliases.
420     if (op->getNumRegions() != 0) {
421       for (Region &region : op->getRegions())
422         printRegion(region, /*printEntryBlockArgs=*/true,
423                     /*printBlockTerminators=*/true);
424     }
425 
426     // Visit all the types used in the operation.
427     for (Type type : op->getOperandTypes())
428       printType(type);
429     for (Type type : op->getResultTypes())
430       printType(type);
431 
432     // Consider the attributes of the operation for aliases.
433     for (const NamedAttribute &attr : op->getAttrs())
434       printAttribute(attr.getValue());
435   }
436 
437   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
438   /// block are not printed. If 'printBlockTerminator' is false, the terminator
439   /// operation of the block is not printed.
440   void print(Block *block, bool printBlockArgs = true,
441              bool printBlockTerminator = true) {
442     // Consider the types of the block arguments for aliases if 'printBlockArgs'
443     // is set to true.
444     if (printBlockArgs) {
445       for (BlockArgument arg : block->getArguments()) {
446         printType(arg.getType());
447 
448         // Visit the argument location.
449         if (printerFlags.shouldPrintDebugInfo())
450           // TODO: Allow deferring argument locations.
451           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
452       }
453     }
454 
455     // Consider the operations within this block, ignoring the terminator if
456     // requested.
457     bool hasTerminator =
458         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
459     auto range = llvm::make_range(
460         block->begin(),
461         std::prev(block->end(),
462                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
463     for (Operation &op : range)
464       print(&op);
465   }
466 
467   /// Print the given region.
468   void printRegion(Region &region, bool printEntryBlockArgs,
469                    bool printBlockTerminators,
470                    bool printEmptyBlock = false) override {
471     if (region.empty())
472       return;
473 
474     auto *entryBlock = &region.front();
475     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
476     for (Block &b : llvm::drop_begin(region, 1))
477       print(&b);
478   }
479 
480   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
481                            bool omitType) override {
482     printType(arg.getType());
483     // Visit the argument location.
484     if (printerFlags.shouldPrintDebugInfo())
485       // TODO: Allow deferring argument locations.
486       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
487   }
488 
489   /// Consider the given type to be printed for an alias.
490   void printType(Type type) override { initializer.visit(type); }
491 
492   /// Consider the given attribute to be printed for an alias.
493   void printAttribute(Attribute attr) override { initializer.visit(attr); }
494   void printAttributeWithoutType(Attribute attr) override {
495     printAttribute(attr);
496   }
497   LogicalResult printAlias(Attribute attr) override {
498     initializer.visit(attr);
499     return success();
500   }
501   LogicalResult printAlias(Type type) override {
502     initializer.visit(type);
503     return success();
504   }
505 
506   /// Print the given set of attributes with names not included within
507   /// 'elidedAttrs'.
508   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
509                              ArrayRef<StringRef> elidedAttrs = {}) override {
510     if (attrs.empty())
511       return;
512     if (elidedAttrs.empty()) {
513       for (const NamedAttribute &attr : attrs)
514         printAttribute(attr.getValue());
515       return;
516     }
517     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
518                                                   elidedAttrs.end());
519     for (const NamedAttribute &attr : attrs)
520       if (!elidedAttrsSet.contains(attr.getName().strref()))
521         printAttribute(attr.getValue());
522   }
523   void printOptionalAttrDictWithKeyword(
524       ArrayRef<NamedAttribute> attrs,
525       ArrayRef<StringRef> elidedAttrs = {}) override {
526     printOptionalAttrDict(attrs, elidedAttrs);
527   }
528 
529   /// Return a null stream as the output stream, this will ignore any data fed
530   /// to it.
531   raw_ostream &getStream() const override { return os; }
532 
533   /// The following are hooks of `OpAsmPrinter` that are not necessary for
534   /// determining potential aliases.
535   void printFloat(const APFloat &value) override {}
536   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
537   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
538   void printNewline() override {}
539   void printOperand(Value) override {}
540   void printOperand(Value, raw_ostream &os) override {
541     // Users expect the output string to have at least the prefixed % to signal
542     // a value name. To maintain this invariant, emit a name even if it is
543     // guaranteed to go unused.
544     os << "%";
545   }
546   void printKeywordOrString(StringRef) override {}
547   void printSymbolName(StringRef) override {}
548   void printSuccessor(Block *) override {}
549   void printSuccessorAndUseList(Block *, ValueRange) override {}
550   void shadowRegionArgs(Region &, ValueRange) override {}
551 
552   /// The printer flags to use when determining potential aliases.
553   const OpPrintingFlags &printerFlags;
554 
555   /// The initializer to use when identifying aliases.
556   AliasInitializer &initializer;
557 
558   /// A dummy output stream.
559   mutable llvm::raw_null_ostream os;
560 };
561 } // namespace
562 
563 /// Sanitize the given name such that it can be used as a valid identifier. If
564 /// the string needs to be modified in any way, the provided buffer is used to
565 /// store the new copy,
566 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
567                                     StringRef allowedPunctChars = "$._-",
568                                     bool allowTrailingDigit = true) {
569   assert(!name.empty() && "Shouldn't have an empty name here");
570 
571   auto copyNameToBuffer = [&] {
572     for (char ch : name) {
573       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
574         buffer.push_back(ch);
575       else if (ch == ' ')
576         buffer.push_back('_');
577       else
578         buffer.append(llvm::utohexstr((unsigned char)ch));
579     }
580   };
581 
582   // Check to see if this name is valid. If it starts with a digit, then it
583   // could conflict with the autogenerated numeric ID's, so add an underscore
584   // prefix to avoid problems.
585   if (isdigit(name[0])) {
586     buffer.push_back('_');
587     copyNameToBuffer();
588     return buffer;
589   }
590 
591   // If the name ends with a trailing digit, add a '_' to avoid potential
592   // conflicts with autogenerated ID's.
593   if (!allowTrailingDigit && isdigit(name.back())) {
594     copyNameToBuffer();
595     buffer.push_back('_');
596     return buffer;
597   }
598 
599   // Check to see that the name consists of only valid identifier characters.
600   for (char ch : name) {
601     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
602       copyNameToBuffer();
603       return buffer;
604     }
605   }
606 
607   // If there are no invalid characters, return the original name.
608   return name;
609 }
610 
611 /// Given a collection of aliases and symbols, initialize a mapping from a
612 /// symbol to a given alias.
613 template <typename T>
614 static void
615 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
616                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
617                   DenseSet<T> *deferrableAliases = nullptr) {
618   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
619       aliasToSymbol.takeVector();
620   llvm::array_pod_sort(aliases.begin(), aliases.end(),
621                        [](const auto *lhs, const auto *rhs) {
622                          return lhs->first.compare(rhs->first);
623                        });
624 
625   for (auto &it : aliases) {
626     // If there is only one instance for this alias, use the name directly.
627     if (it.second.size() == 1) {
628       T symbol = it.second.front();
629       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
630       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
631       continue;
632     }
633     // Otherwise, add the index to the name.
634     for (int i = 0, e = it.second.size(); i < e; ++i) {
635       T symbol = it.second[i];
636       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
637       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
638     }
639   }
640 }
641 
642 void AliasInitializer::initialize(
643     Operation *op, const OpPrintingFlags &printerFlags,
644     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
645     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
646   // Use a dummy printer when walking the IR so that we can collect the
647   // attributes/types that will actually be used during printing when
648   // considering aliases.
649   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
650   aliasPrinter.print(op);
651 
652   // Initialize the aliases sorted by name.
653   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
654   initializeAliases(aliasToType, typeToAlias);
655 }
656 
657 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
658   if (!visitedAttributes.insert(attr).second) {
659     // If this attribute already has an alias and this instance can't be
660     // deferred, make sure that the alias isn't deferred.
661     if (!canBeDeferred)
662       deferrableAttributes.erase(attr);
663     return;
664   }
665 
666   // Try to generate an alias for this attribute.
667   if (succeeded(generateAlias(attr, aliasToAttr))) {
668     if (canBeDeferred)
669       deferrableAttributes.insert(attr);
670     return;
671   }
672 
673   // Check for any sub elements.
674   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
675     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
676                                         [&](Type type) { visit(type); });
677   }
678 }
679 
680 void AliasInitializer::visit(Type type) {
681   if (!visitedTypes.insert(type).second)
682     return;
683 
684   // Try to generate an alias for this type.
685   if (succeeded(generateAlias(type, aliasToType)))
686     return;
687 
688   // Check for any sub elements.
689   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
690     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
691                                         [&](Type type) { visit(type); });
692   }
693 }
694 
695 template <typename T>
696 LogicalResult AliasInitializer::generateAlias(
697     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
698   SmallString<32> nameBuffer;
699   for (const auto &interface : interfaces) {
700     OpAsmDialectInterface::AliasResult result =
701         interface.getAlias(symbol, aliasOS);
702     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
703       continue;
704     nameBuffer = std::move(aliasBuffer);
705     assert(!nameBuffer.empty() && "expected valid alias name");
706     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
707       break;
708   }
709 
710   if (nameBuffer.empty())
711     return failure();
712 
713   SmallString<16> tempBuffer;
714   StringRef name =
715       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
716                          /*allowTrailingDigit=*/false);
717   name = name.copy(aliasAllocator);
718   aliasToSymbol[name].push_back(symbol);
719   return success();
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // AliasState
724 //===----------------------------------------------------------------------===//
725 
726 namespace {
727 /// This class manages the state for type and attribute aliases.
728 class AliasState {
729 public:
730   // Initialize the internal aliases.
731   void
732   initialize(Operation *op, const OpPrintingFlags &printerFlags,
733              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
734 
735   /// Get an alias for the given attribute if it has one and print it in `os`.
736   /// Returns success if an alias was printed, failure otherwise.
737   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
738 
739   /// Get an alias for the given type if it has one and print it in `os`.
740   /// Returns success if an alias was printed, failure otherwise.
741   LogicalResult getAlias(Type ty, raw_ostream &os) const;
742 
743   /// Print all of the referenced aliases that can not be resolved in a deferred
744   /// manner.
745   void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
746     printAliases(os, newLine, /*isDeferred=*/false);
747   }
748 
749   /// Print all of the referenced aliases that support deferred resolution.
750   void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
751     printAliases(os, newLine, /*isDeferred=*/true);
752   }
753 
754 private:
755   /// Print all of the referenced aliases that support the provided resolution
756   /// behavior.
757   void printAliases(raw_ostream &os, NewLineCounter &newLine,
758                     bool isDeferred) const;
759 
760   /// Mapping between attribute and alias.
761   llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
762   /// Mapping between type and alias.
763   llvm::MapVector<Type, SymbolAlias> typeToAlias;
764 
765   /// An allocator used for alias names.
766   llvm::BumpPtrAllocator aliasAllocator;
767 };
768 } // namespace
769 
770 void AliasState::initialize(
771     Operation *op, const OpPrintingFlags &printerFlags,
772     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
773   AliasInitializer initializer(interfaces, aliasAllocator);
774   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
775 }
776 
777 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
778   auto it = attrToAlias.find(attr);
779   if (it == attrToAlias.end())
780     return failure();
781   it->second.print(os << '#');
782   return success();
783 }
784 
785 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
786   auto it = typeToAlias.find(ty);
787   if (it == typeToAlias.end())
788     return failure();
789 
790   it->second.print(os << '!');
791   return success();
792 }
793 
794 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
795                               bool isDeferred) const {
796   auto filterFn = [=](const auto &aliasIt) {
797     return aliasIt.second.canBeDeferred() == isDeferred;
798   };
799   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
800     it.second.print(os << '#');
801     os << " = " << it.first << newLine;
802   }
803   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
804     it.second.print(os << '!');
805     os << " = type " << it.first << newLine;
806   }
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // SSANameState
811 //===----------------------------------------------------------------------===//
812 
813 namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
/// An `ordering` of -1 marks a block that has not been (or could not be)
/// numbered; `name` points into storage owned by the SSA name state's
/// allocator.
struct BlockInfo {
  /// Position of the block within its parent region, or -1 if unknown.
  int ordering;
  /// Name used when referencing the block, including the leading '^'.
  StringRef name;
};
820 
/// This class manages the state of SSA value names: the numeric IDs or custom
/// names assigned to values, the names and ordering of blocks, and the result
/// groups of multi-result operations. Values are numbered by the constructor;
/// `shadowRegionArgs` can later overwrite the names of region arguments.
class SSANameState {
public:
  /// A sentinel value used for values with names set. It is stored in
  /// `valueIDs` for values whose custom name lives in `valueNames`.
  enum : unsigned { NameSentinel = ~0U };

  /// Number all values and blocks nested under `op`, as well as `op`'s own
  /// results.
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block. Blocks that were never numbered yield
  /// an ordering of -1 and the name "INVALIDBLOCK".
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name gives the
  /// value the next default numeric ID instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed by appending `_<N>`.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// values of this map are the sorted result numbers that start each result
  /// group (the first entry is always 0).
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Owns the storage for the names referenced by `usedNames`, `valueNames`
  /// and `blockNames`.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
897 } // namespace
898 
/// Number every value and block reachable from `op`.
///
/// Regions are processed from an explicit worklist (LIFO, via pop_back_val).
/// Each worklist entry snapshots the naming context in effect when it was
/// pushed: the three ID counters and the `usedNames` scope of the enclosing
/// region. As a result, sibling regions nested in the same parent region all
/// start numbering from the same counters and see the same set of used names.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are overwritten while walking the worklist below; these
  // savers put them back to their on-entry values when the constructor
  // returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'd here and
  // destroyed by explicitly calling their destructor, which removes the scope
  // (and its names) from `usedNames`; every scope is destroyed before this
  // allocator goes out of scope.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // NOTE(review): these contexts capture the counters *before*
  // numberValuesInOp(*op) numbers `op`'s own results, so values inside `op`'s
  // regions may start at the same numeric IDs as `op`'s results. Confirm this
  // is intended.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    // Restore the naming context snapshot taken when this region was queued.
    Region *region;
    UsedNamesScopeTy *parentScope;
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes that are no
    // longer needed until the parent scope is current again.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue nested regions with the counters as they stand after numbering
    // this whole region, so nested IDs do not collide with this region's.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
958 
959 void SSANameState::printValueID(Value value, bool printResultNo,
960                                 raw_ostream &stream) const {
961   if (!value) {
962     stream << "<<NULL VALUE>>";
963     return;
964   }
965 
966   Optional<int> resultNo;
967   auto lookupValue = value;
968 
969   // If this is an operation result, collect the head lookup value of the result
970   // group and the result number of 'result' within that group.
971   if (OpResult result = value.dyn_cast<OpResult>())
972     getResultIDAndNumber(result, lookupValue, resultNo);
973 
974   auto it = valueIDs.find(lookupValue);
975   if (it == valueIDs.end()) {
976     stream << "<<UNKNOWN SSA VALUE>>";
977     return;
978   }
979 
980   stream << '%';
981   if (it->second != NameSentinel) {
982     stream << it->second;
983   } else {
984     auto nameIt = valueNames.find(lookupValue);
985     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
986     stream << nameIt->second;
987   }
988 
989   if (resultNo.hasValue() && printResultNo)
990     stream << '#' << resultNo;
991 }
992 
993 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
994   auto it = opResultGroups.find(op);
995   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
996 }
997 
998 BlockInfo SSANameState::getBlockInfo(Block *block) {
999   auto it = blockNames.find(block);
1000   BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1001   return it != blockNames.end() ? it->second : invalidBlock;
1002 }
1003 
1004 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1005   assert(!region.empty() && "cannot shadow arguments of an empty region");
1006   assert(region.getNumArguments() == namesToUse.size() &&
1007          "incorrect number of names passed in");
1008   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1009          "only KnownIsolatedFromAbove ops can shadow names");
1010 
1011   SmallVector<char, 16> nameStr;
1012   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1013     auto nameToUse = namesToUse[i];
1014     if (nameToUse == nullptr)
1015       continue;
1016     auto nameToReplace = region.getArgument(i);
1017 
1018     nameStr.clear();
1019     llvm::raw_svector_ostream nameStream(nameStr);
1020     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1021 
1022     // Entry block arguments should already have a pretty "arg" name.
1023     assert(valueIDs[nameToReplace] == NameSentinel);
1024 
1025     // Use the name without the leading %.
1026     auto name = StringRef(nameStream.str()).drop_front();
1027 
1028     // Overwrite the name.
1029     valueNames[nameToReplace] = name.copy(usedNameAllocator);
1030   }
1031 }
1032 
1033 void SSANameState::numberValuesInRegion(Region &region) {
1034   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1035     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1036     assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
1037            "arg not defined in current region");
1038     setValueName(arg, name);
1039   };
1040 
1041   if (!printerFlags.shouldPrintGenericOpForm()) {
1042     if (Operation *op = region.getParentOp()) {
1043       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1044         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1045     }
1046   }
1047 
1048   // Number the values within this region in a breadth-first order.
1049   unsigned nextBlockID = 0;
1050   for (auto &block : region) {
1051     // Each block gets a unique ID, and all of the operations within it get
1052     // numbered as well.
1053     auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1054     if (blockInfoIt.second) {
1055       // This block hasn't been named through `getAsmBlockArgumentNames`, use
1056       // default `^bbNNN` format.
1057       std::string name;
1058       llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1059       blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1060     }
1061     blockInfoIt.first->second.ordering = nextBlockID++;
1062 
1063     numberValuesInBlock(block);
1064   }
1065 }
1066 
1067 void SSANameState::numberValuesInBlock(Block &block) {
1068   // Number the block arguments. We give entry block arguments a special name
1069   // 'arg'.
1070   bool isEntryBlock = block.isEntryBlock();
1071   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1072   llvm::raw_svector_ostream specialName(specialNameBuffer);
1073   for (auto arg : block.getArguments()) {
1074     if (valueIDs.count(arg))
1075       continue;
1076     if (isEntryBlock) {
1077       specialNameBuffer.resize(strlen("arg"));
1078       specialName << nextArgumentID++;
1079     }
1080     setValueName(arg, specialName.str());
1081   }
1082 
1083   // Number the operations in this block.
1084   for (auto &op : block)
1085     numberValuesInOp(op);
1086 }
1087 
1088 void SSANameState::numberValuesInOp(Operation &op) {
1089   // Function used to set the special result names for the operation.
1090   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1091   auto setResultNameFn = [&](Value result, StringRef name) {
1092     assert(!valueIDs.count(result) && "result numbered multiple times");
1093     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1094     setValueName(result, name);
1095 
1096     // Record the result number for groups not anchored at 0.
1097     if (int resultNo = result.cast<OpResult>().getResultNumber())
1098       resultGroups.push_back(resultNo);
1099   };
1100   // Operations can customize the printing of block names in OpAsmOpInterface.
1101   auto setBlockNameFn = [&](Block *block, StringRef name) {
1102     assert(block->getParentOp() == &op &&
1103            "getAsmBlockArgumentNames callback invoked on a block not directly "
1104            "nested under the current operation");
1105     assert(!blockNames.count(block) && "block numbered multiple times");
1106     SmallString<16> tmpBuffer{"^"};
1107     name = sanitizeIdentifier(name, tmpBuffer);
1108     if (name.data() != tmpBuffer.data()) {
1109       tmpBuffer.append(name);
1110       name = tmpBuffer.str();
1111     }
1112     name = name.copy(usedNameAllocator);
1113     blockNames[block] = {-1, name};
1114   };
1115 
1116   if (!printerFlags.shouldPrintGenericOpForm()) {
1117     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1118       asmInterface.getAsmBlockNames(setBlockNameFn);
1119       asmInterface.getAsmResultNames(setResultNameFn);
1120     }
1121   }
1122 
1123   unsigned numResults = op.getNumResults();
1124   if (numResults == 0)
1125     return;
1126   Value resultBegin = op.getResult(0);
1127 
1128   // If the first result wasn't numbered, give it a default number.
1129   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1130     ++nextValueID;
1131 
1132   // If this operation has multiple result groups, mark it.
1133   if (resultGroups.size() != 1) {
1134     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1135     opResultGroups.try_emplace(&op, std::move(resultGroups));
1136   }
1137 }
1138 
1139 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1140                                         Optional<int> &lookupResultNo) const {
1141   Operation *owner = result.getOwner();
1142   if (owner->getNumResults() == 1)
1143     return;
1144   int resultNo = result.getResultNumber();
1145 
1146   // If this operation has multiple result groups, we will need to find the
1147   // one corresponding to this result.
1148   auto resultGroupIt = opResultGroups.find(owner);
1149   if (resultGroupIt == opResultGroups.end()) {
1150     // If not, just use the first result.
1151     lookupResultNo = resultNo;
1152     lookupValue = owner->getResult(0);
1153     return;
1154   }
1155 
1156   // Find the correct index using a binary search, as the groups are ordered.
1157   ArrayRef<int> resultGroups = resultGroupIt->second;
1158   const auto *it = llvm::upper_bound(resultGroups, resultNo);
1159   int groupResultNo = 0, groupSize = 0;
1160 
1161   // If there are no smaller elements, the last result group is the lookup.
1162   if (it == resultGroups.end()) {
1163     groupResultNo = resultGroups.back();
1164     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1165   } else {
1166     // Otherwise, the previous element is the lookup.
1167     groupResultNo = *std::prev(it);
1168     groupSize = *it - groupResultNo;
1169   }
1170 
1171   // We only record the result number for a group of size greater than 1.
1172   if (groupSize != 1)
1173     lookupResultNo = resultNo - groupResultNo;
1174   lookupValue = owner->getResult(groupResultNo);
1175 }
1176 
1177 void SSANameState::setValueName(Value value, StringRef name) {
1178   // If the name is empty, the value uses the default numbering.
1179   if (name.empty()) {
1180     valueIDs[value] = nextValueID++;
1181     return;
1182   }
1183 
1184   valueIDs[value] = NameSentinel;
1185   valueNames[value] = uniqueValueName(name);
1186 }
1187 
1188 StringRef SSANameState::uniqueValueName(StringRef name) {
1189   SmallString<16> tmpBuffer;
1190   name = sanitizeIdentifier(name, tmpBuffer);
1191 
1192   // Check to see if this name is already unique.
1193   if (!usedNames.count(name)) {
1194     name = name.copy(usedNameAllocator);
1195   } else {
1196     // Otherwise, we had a conflict - probe until we find a unique name. This
1197     // is guaranteed to terminate (and usually in a single iteration) because it
1198     // generates new names by incrementing nextConflictID.
1199     SmallString<64> probeName(name);
1200     probeName.push_back('_');
1201     while (true) {
1202       probeName += llvm::utostr(nextConflictID++);
1203       if (!usedNames.count(probeName)) {
1204         name = probeName.str().copy(usedNameAllocator);
1205         break;
1206       }
1207       probeName.resize(name.size() + 1);
1208     }
1209   }
1210 
1211   usedNames.insert(name, char());
1212   return name;
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // AsmState
1217 //===----------------------------------------------------------------------===//
1218 
1219 namespace mlir {
1220 namespace detail {
/// Shared printing state for an operation tree: the dialect OpAsm interfaces,
/// the attribute/type alias state, the SSA value/block name state, the
/// effective printing flags, and an optional map that records where each
/// operation was printed.
class AsmStateImpl {
public:
  /// Note: `nameState` is constructed from the `printerFlags` *parameter*; the
  /// `printerFlags` member is declared after `nameState` and so is initialized
  /// later.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. Does nothing if no location map was provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated (may be null).
  AsmState::LocationMap *locationMap;
};
1265 } // namespace detail
1266 } // namespace mlir
1267 
1268 /// Verifies the operation and switches to generic op printing if verification
1269 /// fails. We need to do this because custom print functions may fail for
1270 /// invalid ops.
1271 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1272                                               OpPrintingFlags printerFlags) {
1273   if (printerFlags.shouldPrintGenericOpForm() ||
1274       printerFlags.shouldAssumeVerified())
1275     return printerFlags;
1276 
1277   // Ignore errors emitted by the verifier. We check the thread id to avoid
1278   // consuming other threads' errors.
1279   auto parentThreadId = llvm::get_threadid();
1280   ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) {
1281     return success(parentThreadId == llvm::get_threadid());
1282   });
1283   if (failed(verify(op)))
1284     printerFlags.printGenericOpForm();
1285 
1286   return printerFlags;
1287 }
1288 
1289 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1290                    LocationMap *locationMap)
1291     : impl(std::make_unique<AsmStateImpl>(
1292           op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
1293 AsmState::~AsmState() = default;
1294 
1295 const OpPrintingFlags &AsmState::getPrinterFlags() const {
1296   return impl->getPrinterFlags();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // AsmPrinter::Impl
1301 //===----------------------------------------------------------------------===//
1302 
1303 namespace mlir {
/// Implementation of the printer for attributes, types, locations, and affine
/// constructs. It writes to an output stream, is controlled by a set of
/// printing flags, and optionally consults shared AsmStateImpl state (aliases
/// and SSA names).
class AsmPrinter::Impl {
public:
  /// Create a printer writing to `os`. `state` may be null, in which case no
  /// alias state is available to the printer.
  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
       AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing the stream, flags and state of `other`.
  explicit Impl(Impl &other)
      : Impl(other.os, other.printerFlags, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of `c` to the stream, separated by commas, using
  /// `eachFn` to print each element.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type.
  void printType(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print affine maps, expressions, constraints and integer sets.
  /// NOTE(review): `printValueName`, when provided, presumably overrides how
  /// dimension/symbol identifiers are printed; confirm against the definition.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  /// Print `attrs` as an attribute dictionary, omitting the names listed in
  /// `elidedAttrs`. NOTE(review): `withKeyword` presumably prefixes the
  /// dictionary with a keyword; confirm against the definition.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  /// Print " " followed by `loc`, but only when debug info printing is
  /// enabled.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print the body of `loc`; `pretty` selects the human-readable form used
  /// by the pretty debug-info mode.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense integer or floating-point elements attribute. If
  /// 'allowHex' is true, a hex string is used instead of individual elements
  /// when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print an attribute/type that is not handled by the builtin printing
  /// logic (presumably dialect-defined ones).
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module (may be null).
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
1403 } // namespace mlir
1404 
1405 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1406   // Check to see if we are printing debug information.
1407   if (!printerFlags.shouldPrintDebugInfo())
1408     return;
1409 
1410   os << " ";
1411   printLocation(loc, /*allowAlias=*/allowAlias);
1412 }
1413 
1414 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1415   TypeSwitch<LocationAttr>(loc)
1416       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1417         printLocationInternal(loc.getFallbackLocation(), pretty);
1418       })
1419       .Case<UnknownLoc>([&](UnknownLoc loc) {
1420         if (pretty)
1421           os << "[unknown]";
1422         else
1423           os << "unknown";
1424       })
1425       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1426         if (pretty) {
1427           os << loc.getFilename().getValue();
1428         } else {
1429           os << "\"";
1430           printEscapedString(loc.getFilename(), os);
1431           os << "\"";
1432         }
1433         os << ':' << loc.getLine() << ':' << loc.getColumn();
1434       })
1435       .Case<NameLoc>([&](NameLoc loc) {
1436         os << '\"';
1437         printEscapedString(loc.getName(), os);
1438         os << '\"';
1439 
1440         // Print the child if it isn't unknown.
1441         auto childLoc = loc.getChildLoc();
1442         if (!childLoc.isa<UnknownLoc>()) {
1443           os << '(';
1444           printLocationInternal(childLoc, pretty);
1445           os << ')';
1446         }
1447       })
1448       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1449         Location caller = loc.getCaller();
1450         Location callee = loc.getCallee();
1451         if (!pretty)
1452           os << "callsite(";
1453         printLocationInternal(callee, pretty);
1454         if (pretty) {
1455           if (callee.isa<NameLoc>()) {
1456             if (caller.isa<FileLineColLoc>()) {
1457               os << " at ";
1458             } else {
1459               os << newLine << " at ";
1460             }
1461           } else {
1462             os << newLine << " at ";
1463           }
1464         } else {
1465           os << " at ";
1466         }
1467         printLocationInternal(caller, pretty);
1468         if (!pretty)
1469           os << ")";
1470       })
1471       .Case<FusedLoc>([&](FusedLoc loc) {
1472         if (!pretty)
1473           os << "fused";
1474         if (Attribute metadata = loc.getMetadata())
1475           os << '<' << metadata << '>';
1476         os << '[';
1477         interleave(
1478             loc.getLocations(),
1479             [&](Location loc) { printLocationInternal(loc, pretty); },
1480             [&]() { os << ", "; });
1481         os << ']';
1482       });
1483 }
1484 
1485 /// Print a floating point value in a way that the parser will be able to
1486 /// round-trip losslessly.
1487 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1488   // We would like to output the FP constant value in exponential notation,
1489   // but we cannot do this if doing so will lose precision.  Check here to
1490   // make sure that we only output it in exponential format if we can parse
1491   // the value back and get the same value.
1492   bool isInf = apValue.isInfinity();
1493   bool isNaN = apValue.isNaN();
1494   if (!isInf && !isNaN) {
1495     SmallString<128> strValue;
1496     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1497                      /*TruncateZero=*/false);
1498 
1499     // Check to make sure that the stringized number is not some string like
1500     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1501     // that the string matches the "[-+]?[0-9]" regex.
1502     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1503             ((strValue[0] == '-' || strValue[0] == '+') &&
1504              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1505            "[-+]?[0-9] regex does not match!");
1506 
1507     // Parse back the stringized version and check that the value is equal
1508     // (i.e., there is no precision loss).
1509     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1510       os << strValue;
1511       return;
1512     }
1513 
1514     // If it is not, use the default format of APFloat instead of the
1515     // exponential notation.
1516     strValue.clear();
1517     apValue.toString(strValue);
1518 
1519     // Make sure that we can parse the default form as a float.
1520     if (strValue.str().contains('.')) {
1521       os << strValue;
1522       return;
1523     }
1524   }
1525 
1526   // Print special values in hexadecimal format. The sign bit should be included
1527   // in the literal.
1528   SmallVector<char, 16> str;
1529   APInt apInt = apValue.bitcastToAPInt();
1530   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1531                  /*formatAsCLiteral=*/true);
1532   os << str;
1533 }
1534 
1535 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1536   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1537     return printLocationInternal(loc, /*pretty=*/true);
1538 
1539   os << "loc(";
1540   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1541     printLocationInternal(loc);
1542   os << ')';
1543 }
1544 
1545 /// Returns true if the given dialect symbol data is simple enough to print in
1546 /// the pretty form, i.e. without the enclosing "".
1547 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1548   // The name must start with an identifier.
1549   if (symName.empty() || !isalpha(symName.front()))
1550     return false;
1551 
1552   // Ignore all the characters that are valid in an identifier in the symbol
1553   // name.
1554   symName = symName.drop_while(
1555       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1556   if (symName.empty())
1557     return true;
1558 
1559   // If we got to an unexpected character, then it must be a <>.  Check those
1560   // recursively.
1561   if (symName.front() != '<' || symName.back() != '>')
1562     return false;
1563 
1564   SmallVector<char, 8> nestedPunctuation;
1565   do {
1566     // If we ran out of characters, then we had a punctuation mismatch.
1567     if (symName.empty())
1568       return false;
1569 
1570     auto c = symName.front();
1571     symName = symName.drop_front();
1572 
1573     switch (c) {
1574     // We never allow null characters. This is an EOF indicator for the lexer
1575     // which we could handle, but isn't important for any known dialect.
1576     case '\0':
1577       return false;
1578     case '<':
1579     case '[':
1580     case '(':
1581     case '{':
1582       nestedPunctuation.push_back(c);
1583       continue;
1584     case '-':
1585       // Treat `->` as a special token.
1586       if (!symName.empty() && symName.front() == '>') {
1587         symName = symName.drop_front();
1588         continue;
1589       }
1590       break;
1591     // Reject types with mismatched brackets.
1592     case '>':
1593       if (nestedPunctuation.pop_back_val() != '<')
1594         return false;
1595       break;
1596     case ']':
1597       if (nestedPunctuation.pop_back_val() != '[')
1598         return false;
1599       break;
1600     case ')':
1601       if (nestedPunctuation.pop_back_val() != '(')
1602         return false;
1603       break;
1604     case '}':
1605       if (nestedPunctuation.pop_back_val() != '{')
1606         return false;
1607       break;
1608     default:
1609       continue;
1610     }
1611 
1612     // We're done when the punctuation is fully matched.
1613   } while (!nestedPunctuation.empty());
1614 
1615   // If there were extra characters, then we failed.
1616   return symName.empty();
1617 }
1618 
1619 /// Print the given dialect symbol to the stream.
1620 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1621                                StringRef dialectName, StringRef symString) {
1622   os << symPrefix << dialectName;
1623 
1624   // If this symbol name is simple enough, print it directly in pretty form,
1625   // otherwise, we print it as an escaped string.
1626   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1627     os << '.' << symString;
1628     return;
1629   }
1630 
1631   os << "<\"";
1632   llvm::printEscapedString(symString, os);
1633   os << "\">";
1634 }
1635 
1636 /// Returns true if the given string can be represented as a bare identifier.
1637 static bool isBareIdentifier(StringRef name) {
1638   // By making this unsigned, the value passed in to isalnum will always be
1639   // in the range 0-255. This is important when building with MSVC because
1640   // its implementation will assert. This situation can arise when dealing
1641   // with UTF-8 multibyte characters.
1642   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1643     return false;
1644   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1645     return isalnum(c) || c == '_' || c == '$' || c == '.';
1646   });
1647 }
1648 
1649 /// Print the given string as a keyword, or a quoted and escaped string if it
1650 /// has any special or non-printable characters in it.
1651 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1652   // If it can be represented as a bare identifier, write it directly.
1653   if (isBareIdentifier(keyword)) {
1654     os << keyword;
1655     return;
1656   }
1657 
1658   // Otherwise, output the keyword wrapped in quotes with proper escaping.
1659   os << "\"";
1660   printEscapedString(keyword, os);
1661   os << '"';
1662 }
1663 
/// Print the given string as a symbol reference. A symbol reference is
/// represented as a string prefixed with '@'. The reference is surrounded with
/// ""'s and escaped if it has any special or non-printable characters in it.
/// The quoting decision is delegated to printKeywordOrString.
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  // An empty reference would print as a bare `@` followed by `""`; callers are
  // expected never to pass one.
  assert(!symbolRef.empty() && "expected valid symbol reference");
  os << '@';
  printKeywordOrString(symbolRef, os);
}
1672 
1673 // Print out a valid ElementsAttr that is succinct and can represent any
1674 // potential shape/type, for use when eliding a large ElementsAttr.
1675 //
1676 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1677 // hopefully alert readers to the fact that this has been elided.
1678 //
1679 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1680 // accept the string "elided". The first string must be a registered dialect
1681 // name and the latter must be a hex constant.
1682 static void printElidedElementsAttr(raw_ostream &os) {
1683   os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
1684 }
1685 
1686 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
1687   return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
1688 }
1689 
1690 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
1691   return success(state && succeeded(state->getAliasState().getAlias(type, os)));
1692 }
1693 
/// Print the given attribute to the stream.
///
/// If an alias is available for `attr`, only the alias is printed. Attributes
/// from non-builtin dialects are delegated to printDialectAttribute. Builtin
/// attributes are printed in their textual form. Most are followed by a
/// ` : <type>` suffix, which is omitted when:
///  - `typeElision` is `Must`, or the attribute's type is `NoneType`;
///  - the branch returns early: unit, signless i1 integers, affine maps and
///    integer sets always; signless i64 integers and f64 floats when
///    `typeElision` is `May`.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  // Attributes of other dialects are rendered through their owning dialect.
  if (!isa<BuiltinDialect>(attr.getDialect()))
    return printDialectAttribute(attr);

  // Queried once up front; used by several branches and by the trailing
  // ` : type` suffix at the end of the function.
  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elides the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE: signless i1 has already returned above, so the second operand of
    // the `||` below is always false at this point.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Array elements may elide their own type suffix (e.g. i64/f64 values).
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Nested symbol references are printed as `@root::@nested::...`.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
         << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide if either the index or the value payload is too large.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      // With no indices, the body is empty (`sparse<>`) and the values are not
      // printed at all. Indices never use the hex form; values may.
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}
1835 
1836 /// Print the integer element of a DenseElementsAttr.
1837 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1838                                  bool isSigned) {
1839   if (value.getBitWidth() == 1)
1840     os << (value.getBoolValue() ? "true" : "false");
1841   else
1842     value.print(os, isSigned);
1843 }
1844 
1845 static void
1846 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1847                            function_ref<void(unsigned)> printEltFn) {
1848   // Special case for 0-d and splat tensors.
1849   if (isSplat)
1850     return printEltFn(0);
1851 
1852   // Special case for degenerate tensors.
1853   auto numElements = type.getNumElements();
1854   if (numElements == 0)
1855     return;
1856 
1857   // We use a mixed-radix counter to iterate through the shape. When we bump a
1858   // non-least-significant digit, we emit a close bracket. When we next emit an
1859   // element we re-open all closed brackets.
1860 
1861   // The mixed-radix counter, with radices in 'shape'.
1862   int64_t rank = type.getRank();
1863   SmallVector<unsigned, 4> counter(rank, 0);
1864   // The number of brackets that have been opened and not closed.
1865   unsigned openBrackets = 0;
1866 
1867   auto shape = type.getShape();
1868   auto bumpCounter = [&] {
1869     // Bump the least significant digit.
1870     ++counter[rank - 1];
1871     // Iterate backwards bubbling back the increment.
1872     for (unsigned i = rank - 1; i > 0; --i)
1873       if (counter[i] >= shape[i]) {
1874         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1875         counter[i] = 0;
1876         ++counter[i - 1];
1877         --openBrackets;
1878         os << ']';
1879       }
1880   };
1881 
1882   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1883     if (idx != 0)
1884       os << ", ";
1885     while (openBrackets++ < rank)
1886       os << '[';
1887     openBrackets = rank;
1888     printEltFn(idx);
1889     bumpCounter();
1890   }
1891   while (openBrackets-- > 0)
1892     os << ']';
1893 }
1894 
1895 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1896                                               bool allowHex) {
1897   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1898     return printDenseStringElementsAttr(stringAttr);
1899 
1900   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1901                                 allowHex);
1902 }
1903 
1904 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
1905     DenseIntOrFPElementsAttr attr, bool allowHex) {
1906   auto type = attr.getType();
1907   auto elementType = type.getElementType();
1908 
1909   // Check to see if we should format this attribute as a hex string.
1910   auto numElements = type.getNumElements();
1911   if (!attr.isSplat() && allowHex &&
1912       shouldPrintElementsAttrWithHex(numElements)) {
1913     ArrayRef<char> rawData = attr.getRawData();
1914     if (llvm::support::endian::system_endianness() ==
1915         llvm::support::endianness::big) {
1916       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1917       // machines. It is converted here to print in LE format.
1918       SmallVector<char, 64> outDataVec(rawData.size());
1919       MutableArrayRef<char> convRawData(outDataVec);
1920       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1921           rawData, convRawData, type);
1922       os << '"' << "0x"
1923          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1924          << "\"";
1925     } else {
1926       os << '"' << "0x"
1927          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1928     }
1929 
1930     return;
1931   }
1932 
1933   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1934     Type complexElementType = complexTy.getElementType();
1935     // Note: The if and else below had a common lambda function which invoked
1936     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1937     // and hence was replaced.
1938     if (complexElementType.isa<IntegerType>()) {
1939       bool isSigned = !complexElementType.isUnsignedInteger();
1940       auto valueIt = attr.value_begin<std::complex<APInt>>();
1941       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1942         auto complexValue = *(valueIt + index);
1943         os << "(";
1944         printDenseIntElement(complexValue.real(), os, isSigned);
1945         os << ",";
1946         printDenseIntElement(complexValue.imag(), os, isSigned);
1947         os << ")";
1948       });
1949     } else {
1950       auto valueIt = attr.value_begin<std::complex<APFloat>>();
1951       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1952         auto complexValue = *(valueIt + index);
1953         os << "(";
1954         printFloatValue(complexValue.real(), os);
1955         os << ",";
1956         printFloatValue(complexValue.imag(), os);
1957         os << ")";
1958       });
1959     }
1960   } else if (elementType.isIntOrIndex()) {
1961     bool isSigned = !elementType.isUnsignedInteger();
1962     auto valueIt = attr.value_begin<APInt>();
1963     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1964       printDenseIntElement(*(valueIt + index), os, isSigned);
1965     });
1966   } else {
1967     assert(elementType.isa<FloatType>() && "unexpected element type");
1968     auto valueIt = attr.value_begin<APFloat>();
1969     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1970       printFloatValue(*(valueIt + index), os);
1971     });
1972   }
1973 }
1974 
1975 void AsmPrinter::Impl::printDenseStringElementsAttr(
1976     DenseStringElementsAttr attr) {
1977   ArrayRef<StringRef> data = attr.getRawStringData();
1978   auto printFn = [&](unsigned index) {
1979     os << "\"";
1980     printEscapedString(data[index], os);
1981     os << "\"";
1982   };
1983   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1984 }
1985 
1986 void AsmPrinter::Impl::printType(Type type) {
1987   if (!type) {
1988     os << "<<NULL TYPE>>";
1989     return;
1990   }
1991 
1992   // Try to print an alias for this type.
1993   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1994     return;
1995 
1996   TypeSwitch<Type>(type)
1997       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1998         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1999                            opaqueTy.getTypeData());
2000       })
2001       .Case<IndexType>([&](Type) { os << "index"; })
2002       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
2003       .Case<Float16Type>([&](Type) { os << "f16"; })
2004       .Case<Float32Type>([&](Type) { os << "f32"; })
2005       .Case<Float64Type>([&](Type) { os << "f64"; })
2006       .Case<Float80Type>([&](Type) { os << "f80"; })
2007       .Case<Float128Type>([&](Type) { os << "f128"; })
2008       .Case<IntegerType>([&](IntegerType integerTy) {
2009         if (integerTy.isSigned())
2010           os << 's';
2011         else if (integerTy.isUnsigned())
2012           os << 'u';
2013         os << 'i' << integerTy.getWidth();
2014       })
2015       .Case<FunctionType>([&](FunctionType funcTy) {
2016         os << '(';
2017         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
2018         os << ") -> ";
2019         ArrayRef<Type> results = funcTy.getResults();
2020         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
2021           printType(results[0]);
2022         } else {
2023           os << '(';
2024           interleaveComma(results, [&](Type ty) { printType(ty); });
2025           os << ')';
2026         }
2027       })
2028       .Case<VectorType>([&](VectorType vectorTy) {
2029         os << "vector<";
2030         auto vShape = vectorTy.getShape();
2031         unsigned lastDim = vShape.size();
2032         unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
2033         unsigned dimIdx = 0;
2034         for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
2035           os << vShape[dimIdx] << 'x';
2036         if (vectorTy.isScalable()) {
2037           os << '[';
2038           unsigned secondToLastDim = lastDim - 1;
2039           for (; dimIdx < secondToLastDim; dimIdx++)
2040             os << vShape[dimIdx] << 'x';
2041           os << vShape[dimIdx] << "]x";
2042         }
2043         printType(vectorTy.getElementType());
2044         os << '>';
2045       })
2046       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
2047         os << "tensor<";
2048         for (int64_t dim : tensorTy.getShape()) {
2049           if (ShapedType::isDynamic(dim))
2050             os << '?';
2051           else
2052             os << dim;
2053           os << 'x';
2054         }
2055         printType(tensorTy.getElementType());
2056         // Only print the encoding attribute value if set.
2057         if (tensorTy.getEncoding()) {
2058           os << ", ";
2059           printAttribute(tensorTy.getEncoding());
2060         }
2061         os << '>';
2062       })
2063       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
2064         os << "tensor<*x";
2065         printType(tensorTy.getElementType());
2066         os << '>';
2067       })
2068       .Case<MemRefType>([&](MemRefType memrefTy) {
2069         os << "memref<";
2070         for (int64_t dim : memrefTy.getShape()) {
2071           if (ShapedType::isDynamic(dim))
2072             os << '?';
2073           else
2074             os << dim;
2075           os << 'x';
2076         }
2077         printType(memrefTy.getElementType());
2078         if (!memrefTy.getLayout().isIdentity()) {
2079           os << ", ";
2080           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2081         }
2082         // Only print the memory space if it is the non-default one.
2083         if (memrefTy.getMemorySpace()) {
2084           os << ", ";
2085           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2086         }
2087         os << '>';
2088       })
2089       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2090         os << "memref<*x";
2091         printType(memrefTy.getElementType());
2092         // Only print the memory space if it is the non-default one.
2093         if (memrefTy.getMemorySpace()) {
2094           os << ", ";
2095           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2096         }
2097         os << '>';
2098       })
2099       .Case<ComplexType>([&](ComplexType complexTy) {
2100         os << "complex<";
2101         printType(complexTy.getElementType());
2102         os << '>';
2103       })
2104       .Case<TupleType>([&](TupleType tupleTy) {
2105         os << "tuple<";
2106         interleaveComma(tupleTy.getTypes(),
2107                         [&](Type type) { printType(type); });
2108         os << '>';
2109       })
2110       .Case<NoneType>([&](Type) { os << "none"; })
2111       .Default([&](Type type) { return printDialectType(type); });
2112 }
2113 
2114 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2115                                              ArrayRef<StringRef> elidedAttrs,
2116                                              bool withKeyword) {
2117   // If there are no attributes, then there is nothing to be done.
2118   if (attrs.empty())
2119     return;
2120 
2121   // Functor used to print a filtered attribute list.
2122   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2123     // Print the 'attributes' keyword if necessary.
2124     if (withKeyword)
2125       os << " attributes";
2126 
2127     // Otherwise, print them all out in braces.
2128     os << " {";
2129     interleaveComma(filteredAttrs,
2130                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2131     os << '}';
2132   };
2133 
2134   // If no attributes are elided, we can directly print with no filtering.
2135   if (elidedAttrs.empty())
2136     return printFilteredAttributesFn(attrs);
2137 
2138   // Otherwise, filter out any attributes that shouldn't be included.
2139   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2140                                                 elidedAttrs.end());
2141   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2142     return !elidedAttrsSet.contains(attr.getName().strref());
2143   });
2144   if (!filteredAttrs.empty())
2145     printFilteredAttributesFn(filteredAttrs);
2146 }
2147 
2148 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2149   // Print the name without quotes if possible.
2150   ::printKeywordOrString(attr.getName().strref(), os);
2151 
2152   // Pretty printing elides the attribute value for unit attributes.
2153   if (attr.getValue().isa<UnitAttr>())
2154     return;
2155 
2156   os << " = ";
2157   printAttribute(attr.getValue());
2158 }
2159 
2160 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2161   auto &dialect = attr.getDialect();
2162 
2163   // Ask the dialect to serialize the attribute to a string.
2164   std::string attrName;
2165   {
2166     llvm::raw_string_ostream attrNameStr(attrName);
2167     Impl subPrinter(attrNameStr, printerFlags, state);
2168     DialectAsmPrinter printer(subPrinter);
2169     dialect.printAttribute(attr, printer);
2170   }
2171   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2172 }
2173 
2174 void AsmPrinter::Impl::printDialectType(Type type) {
2175   auto &dialect = type.getDialect();
2176 
2177   // Ask the dialect to serialize the type to a string.
2178   std::string typeName;
2179   {
2180     llvm::raw_string_ostream typeNameStr(typeName);
2181     Impl subPrinter(typeNameStr, printerFlags, state);
2182     DialectAsmPrinter printer(subPrinter);
2183     dialect.printType(type, printer);
2184   }
2185   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2186 }
2187 
2188 //===--------------------------------------------------------------------===//
2189 // AsmPrinter
2190 //===--------------------------------------------------------------------===//
2191 
// Defaulted out of line, in this file that defines AsmPrinter::Impl.
AsmPrinter::~AsmPrinter() = default;
2193 
/// Return the output stream of the underlying implementation. Subclasses
/// constructed without an implementation must override this.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  return impl->getStream();
}
2198 
2199 /// Print the given floating point value in a stablized form.
2200 void AsmPrinter::printFloat(const APFloat &value) {
2201   assert(impl && "expected AsmPrinter::printFloat to be overriden");
2202   printFloatValue(value, impl->getStream());
2203 }
2204 
/// Print `type` through the underlying implementation, including alias
/// substitution.
void AsmPrinter::printType(Type type) {
  assert(impl && "expected AsmPrinter::printType to be overriden");
  impl->printType(type);
}
2209 
/// Print `attr` through the underlying implementation, using its default type
/// elision behavior.
void AsmPrinter::printAttribute(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
  impl->printAttribute(attr);
}
2214 
/// Try to print an alias for `attr`. Returns failure if none was printed.
LogicalResult AsmPrinter::printAlias(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(attr);
}
2219 
/// Try to print an alias for `type`. Returns failure if none was printed.
LogicalResult AsmPrinter::printAlias(Type type) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(type);
}
2224 
2225 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2226   assert(impl &&
2227          "expected AsmPrinter::printAttributeWithoutType to be overriden");
2228   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2229 }
2230 
2231 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2232   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2233   ::printKeywordOrString(keyword, impl->getStream());
2234 }
2235 
2236 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2237   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2238   ::printSymbolReference(symbolRef, impl->getStream());
2239 }
2240 
2241 //===----------------------------------------------------------------------===//
2242 // Affine expressions and maps
2243 //===----------------------------------------------------------------------===//
2244 
/// Print an affine expression at the outermost (weak) binding level.
/// `printValueName`, when provided, replaces the default `dN`/`sN` names for
/// dimension and symbol identifiers.
void AsmPrinter::Impl::printAffineExpr(
    AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
2249 
/// Recursively print `expr`. `enclosingTightness` says whether the caller
/// binds tighter than '+'. When it is Strong, binary expressions are wrapped
/// in parentheses. Several "pretty" rewrites are applied:
///  - `x * -1` prints as `-x`;
///  - `a + x * -1` prints as `a - x`;
///  - `a + x * c` with `c < -1` prints as `a - x * |c|`;
///  - `a + c` with constant `c < 0` prints as `a - |c|`.
void AsmPrinter::Impl::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  // Leaf expressions are printed directly and return.
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  // Binary expressions only record their operator spelling here.
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  // Both operands of mul/floordiv/ceildiv/mod are printed with Strong binding,
  // so nested additions get parenthesized.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A subtracted sum must be parenthesized: `a - (b + c)`.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          // NOTE(review): if getValue() is int64_t and equals INT64_MIN, this
          // negation overflows (UB); presumably such constants do not occur.
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      // NOTE(review): same INT64_MIN negation concern as above.
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition: both sides bind weakly since '+' is associative here.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2382 
2383 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2384   printAffineExprInternal(expr, BindingStrength::Weak);
2385   isEq ? os << " == 0" : os << " >= 0";
2386 }
2387 
2388 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2389   // Dimension identifiers.
2390   os << '(';
2391   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2392     os << 'd' << i << ", ";
2393   if (map.getNumDims() >= 1)
2394     os << 'd' << map.getNumDims() - 1;
2395   os << ')';
2396 
2397   // Symbolic identifiers.
2398   if (map.getNumSymbols() != 0) {
2399     os << '[';
2400     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2401       os << 's' << i << ", ";
2402     if (map.getNumSymbols() >= 1)
2403       os << 's' << map.getNumSymbols() - 1;
2404     os << ']';
2405   }
2406 
2407   // Result affine expressions.
2408   os << " -> (";
2409   interleaveComma(map.getResults(),
2410                   [&](AffineExpr expr) { printAffineExpr(expr); });
2411   os << ')';
2412 }
2413 
2414 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2415   // Dimension identifiers.
2416   os << '(';
2417   for (unsigned i = 1; i < set.getNumDims(); ++i)
2418     os << 'd' << i - 1 << ", ";
2419   if (set.getNumDims() >= 1)
2420     os << 'd' << set.getNumDims() - 1;
2421   os << ')';
2422 
2423   // Symbolic identifiers.
2424   if (set.getNumSymbols() != 0) {
2425     os << '[';
2426     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2427       os << 's' << i << ", ";
2428     if (set.getNumSymbols() >= 1)
2429       os << 's' << set.getNumSymbols() - 1;
2430     os << ']';
2431   }
2432 
2433   // Print constraints.
2434   os << " : (";
2435   int numConstraints = set.getNumConstraints();
2436   for (int i = 1; i < numConstraints; ++i) {
2437     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2438     os << ", ";
2439   }
2440   if (numConstraints >= 1)
2441     printAffineConstraint(set.getConstraint(numConstraints - 1),
2442                           set.isEq(numConstraints - 1));
2443   os << ')';
2444 }
2445 
2446 //===----------------------------------------------------------------------===//
2447 // OperationPrinter
2448 //===----------------------------------------------------------------------===//
2449 
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
/// It derives from AsmPrinter::Impl (attribute/type/affine printing) and
/// privately from OpAsmPrinter, the interface handed to custom operation
/// printers. It also tracks the current indentation and the stack of default
/// dialects while walking nested regions.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  // Expose Impl's printType directly in this class (presumably to
  // disambiguate from the OpAsmPrinter base's printType).
  using Impl::printType;

  /// Create a printer that writes to `os` using the printing flags and
  /// numbering/alias state stored in `state`. The OpAsmPrinter base is bound
  /// to this object's Impl part.
  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
      : Impl(os, state.getPrinterFlags(), &state),
        OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print the given top-level operation, surrounded by the non-deferred
  /// aliases (before) and deferred aliases (after).
  void printTopLevelOperation(Operation *op);

  /// Print the given operation with its indent and location.
  void print(Operation *op);
  /// Print the bare operation (result list, then custom or generic form),
  /// not including indentation or the trailing location.
  void printOperation(Operation *op);
  /// Print the given operation in the generic form.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number.
  /// Output goes to `streamOverride` when non-null, otherwise to `os`.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }

  /// Print a block argument in the usual format of:
  ///   %ssaName: type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;

  /// Print the ID for the given value (including its result number), either
  /// to the printer's stream or to the explicitly provided `os`.
  void printOperand(Value value) override { printValueID(value); }
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as printOptionalAttrDict, but requests the keyword form from Impl
  /// (`withKeyword=true`).
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

private:
  // Contains the stack of default dialects to use when printing regions.
  // An entry is pushed before printing each non-empty region: the default
  // dialect of the parent operation if it implements `OpAsmOpInterface`, or
  // the empty string otherwise. It is popped when the region is done. At the
  // top-level we start with "builtin" as the default, so that the top-level
  // `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  unsigned currentIndent = 0;
};
} // namespace
2567 
2568 void OperationPrinter::printTopLevelOperation(Operation *op) {
2569   // Output the aliases at the top level that can't be deferred.
2570   state->getAliasState().printNonDeferredAliases(os, newLine);
2571 
2572   // Print the module.
2573   print(op);
2574   os << newLine;
2575 
2576   // Output the aliases at the top level that can be deferred.
2577   state->getAliasState().printDeferredAliases(os, newLine);
2578 }
2579 
2580 /// Print a block argument in the usual format of:
2581 ///   %ssaName : type {attr1=42} loc("here")
2582 /// where location printing is controlled by the standard internal option.
2583 /// You may pass omitType=true to not print a type, and pass an empty
2584 /// attribute list if you don't care for attributes.
2585 void OperationPrinter::printRegionArgument(BlockArgument arg,
2586                                            ArrayRef<NamedAttribute> argAttrs,
2587                                            bool omitType) {
2588   printOperand(arg);
2589   if (!omitType) {
2590     os << ": ";
2591     printType(arg.getType());
2592   }
2593   printOptionalAttrDict(argAttrs);
2594   // TODO: We should allow location aliases on block arguments.
2595   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2596 }
2597 
2598 void OperationPrinter::print(Operation *op) {
2599   // Track the location of this operation.
2600   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2601 
2602   os.indent(currentIndent);
2603   printOperation(op);
2604   printTrailingLocation(op->getLoc());
2605 }
2606 
2607 void OperationPrinter::printOperation(Operation *op) {
2608   if (size_t numResults = op->getNumResults()) {
2609     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2610       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2611       if (resultCount > 1)
2612         os << ':' << resultCount;
2613     };
2614 
2615     // Check to see if this operation has multiple result groups.
2616     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2617     if (!resultGroups.empty()) {
2618       // Interleave the groups excluding the last one, this one will be handled
2619       // separately.
2620       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2621         printResultGroup(resultGroups[i],
2622                          resultGroups[i + 1] - resultGroups[i]);
2623       });
2624       os << ", ";
2625       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2626 
2627     } else {
2628       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2629     }
2630 
2631     os << " = ";
2632   }
2633 
2634   // If requested, always print the generic form.
2635   if (!printerFlags.shouldPrintGenericOpForm()) {
2636     // Check to see if this is a known operation. If so, use the registered
2637     // custom printer hook.
2638     if (auto opInfo = op->getRegisteredInfo()) {
2639       opInfo->printAssembly(op, *this, defaultDialectStack.back());
2640       return;
2641     }
2642     // Otherwise try to dispatch to the dialect, if available.
2643     if (Dialect *dialect = op->getDialect()) {
2644       if (auto opPrinter = dialect->getOperationPrinter(op)) {
2645         // Print the op name first.
2646         StringRef name = op->getName().getStringRef();
2647         name.consume_front((defaultDialectStack.back() + ".").str());
2648         printEscapedString(name, os);
2649         // Print the rest of the op now.
2650         opPrinter(op, *this);
2651         return;
2652       }
2653     }
2654   }
2655 
2656   // Otherwise print with the generic assembly form.
2657   printGenericOp(op, /*printOpName=*/true);
2658 }
2659 
2660 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2661   if (printOpName) {
2662     os << '"';
2663     printEscapedString(op->getName().getStringRef(), os);
2664     os << '"';
2665   }
2666   os << '(';
2667   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2668   os << ')';
2669 
2670   // For terminators, print the list of successors and their operands.
2671   if (op->getNumSuccessors() != 0) {
2672     os << '[';
2673     interleaveComma(op->getSuccessors(),
2674                     [&](Block *successor) { printBlockName(successor); });
2675     os << ']';
2676   }
2677 
2678   // Print regions.
2679   if (op->getNumRegions() != 0) {
2680     os << " (";
2681     interleaveComma(op->getRegions(), [&](Region &region) {
2682       printRegion(region, /*printEntryBlockArgs=*/true,
2683                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2684     });
2685     os << ')';
2686   }
2687 
2688   auto attrs = op->getAttrs();
2689   printOptionalAttrDict(attrs);
2690 
2691   // Print the type signature of the operation.
2692   os << " : ";
2693   printFunctionalType(op);
2694 }
2695 
2696 void OperationPrinter::printBlockName(Block *block) {
2697   os << state->getSSANameState().getBlockInfo(block).name;
2698 }
2699 
2700 void OperationPrinter::print(Block *block, bool printBlockArgs,
2701                              bool printBlockTerminator) {
2702   // Print the block label and argument list if requested.
2703   if (printBlockArgs) {
2704     os.indent(currentIndent);
2705     printBlockName(block);
2706 
2707     // Print the argument list if non-empty.
2708     if (!block->args_empty()) {
2709       os << '(';
2710       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2711         printValueID(arg);
2712         os << ": ";
2713         printType(arg.getType());
2714         // TODO: We should allow location aliases on block arguments.
2715         printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2716       });
2717       os << ')';
2718     }
2719     os << ':';
2720 
2721     // Print out some context information about the predecessors of this block.
2722     if (!block->getParent()) {
2723       os << "  // block is not in a region!";
2724     } else if (block->hasNoPredecessors()) {
2725       if (!block->isEntryBlock())
2726         os << "  // no predecessors";
2727     } else if (auto *pred = block->getSinglePredecessor()) {
2728       os << "  // pred: ";
2729       printBlockName(pred);
2730     } else {
2731       // We want to print the predecessors in a stable order, not in
2732       // whatever order the use-list is in, so gather and sort them.
2733       SmallVector<BlockInfo, 4> predIDs;
2734       for (auto *pred : block->getPredecessors())
2735         predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
2736       llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
2737         return lhs.ordering < rhs.ordering;
2738       });
2739 
2740       os << "  // " << predIDs.size() << " preds: ";
2741 
2742       interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
2743     }
2744     os << newLine;
2745   }
2746 
2747   currentIndent += indentWidth;
2748   bool hasTerminator =
2749       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
2750   auto range = llvm::make_range(
2751       block->begin(),
2752       std::prev(block->end(),
2753                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
2754   for (auto &op : range) {
2755     print(&op);
2756     os << newLine;
2757   }
2758   currentIndent -= indentWidth;
2759 }
2760 
2761 void OperationPrinter::printValueID(Value value, bool printResultNo,
2762                                     raw_ostream *streamOverride) const {
2763   state->getSSANameState().printValueID(value, printResultNo,
2764                                         streamOverride ? *streamOverride : os);
2765 }
2766 
2767 void OperationPrinter::printSuccessor(Block *successor) {
2768   printBlockName(successor);
2769 }
2770 
2771 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2772                                                 ValueRange succOperands) {
2773   printBlockName(successor);
2774   if (succOperands.empty())
2775     return;
2776 
2777   os << '(';
2778   interleaveComma(succOperands,
2779                   [this](Value operand) { printValueID(operand); });
2780   os << " : ";
2781   interleaveComma(succOperands,
2782                   [this](Value operand) { printType(operand.getType()); });
2783   os << ')';
2784 }
2785 
2786 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2787                                    bool printBlockTerminators,
2788                                    bool printEmptyBlock) {
2789   os << "{" << newLine;
2790   if (!region.empty()) {
2791     auto restoreDefaultDialect =
2792         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
2793     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
2794       defaultDialectStack.push_back(iface.getDefaultDialect());
2795     else
2796       defaultDialectStack.push_back("");
2797 
2798     auto *entryBlock = &region.front();
2799     // Force printing the block header if printEmptyBlock is set and the block
2800     // is empty or if printEntryBlockArgs is set and there are arguments to
2801     // print.
2802     bool shouldAlwaysPrintBlockHeader =
2803         (printEmptyBlock && entryBlock->empty()) ||
2804         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2805     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2806     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2807       print(&b);
2808   }
2809   os.indent(currentIndent) << "}";
2810 }
2811 
2812 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2813                                               ValueRange operands) {
2814   AffineMap map = mapAttr.getValue();
2815   unsigned numDims = map.getNumDims();
2816   auto printValueName = [&](unsigned pos, bool isSymbol) {
2817     unsigned index = isSymbol ? numDims + pos : pos;
2818     assert(index < operands.size());
2819     if (isSymbol)
2820       os << "symbol(";
2821     printValueID(operands[index]);
2822     if (isSymbol)
2823       os << ')';
2824   };
2825 
2826   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2827     printAffineExpr(expr, printValueName);
2828   });
2829 }
2830 
2831 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2832                                                ValueRange dimOperands,
2833                                                ValueRange symOperands) {
2834   auto printValueName = [&](unsigned pos, bool isSymbol) {
2835     if (!isSymbol)
2836       return printValueID(dimOperands[pos]);
2837     os << "symbol(";
2838     printValueID(symOperands[pos]);
2839     os << ')';
2840   };
2841   printAffineExpr(expr, printValueName);
2842 }
2843 
2844 //===----------------------------------------------------------------------===//
2845 // print and dump methods
2846 //===----------------------------------------------------------------------===//
2847 
2848 void Attribute::print(raw_ostream &os) const {
2849   AsmPrinter::Impl(os).printAttribute(*this);
2850 }
2851 
2852 void Attribute::dump() const {
2853   print(llvm::errs());
2854   llvm::errs() << "\n";
2855 }
2856 
2857 void Type::print(raw_ostream &os) const {
2858   AsmPrinter::Impl(os).printType(*this);
2859 }
2860 
2861 void Type::dump() const { print(llvm::errs()); }
2862 
2863 void AffineMap::dump() const {
2864   print(llvm::errs());
2865   llvm::errs() << "\n";
2866 }
2867 
2868 void IntegerSet::dump() const {
2869   print(llvm::errs());
2870   llvm::errs() << "\n";
2871 }
2872 
2873 void AffineExpr::print(raw_ostream &os) const {
2874   if (!expr) {
2875     os << "<<NULL AFFINE EXPR>>";
2876     return;
2877   }
2878   AsmPrinter::Impl(os).printAffineExpr(*this);
2879 }
2880 
2881 void AffineExpr::dump() const {
2882   print(llvm::errs());
2883   llvm::errs() << "\n";
2884 }
2885 
2886 void AffineMap::print(raw_ostream &os) const {
2887   if (!map) {
2888     os << "<<NULL AFFINE MAP>>";
2889     return;
2890   }
2891   AsmPrinter::Impl(os).printAffineMap(*this);
2892 }
2893 
2894 void IntegerSet::print(raw_ostream &os) const {
2895   AsmPrinter::Impl(os).printIntegerSet(*this);
2896 }
2897 
2898 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
2899 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
2900   if (!impl) {
2901     os << "<<NULL VALUE>>";
2902     return;
2903   }
2904 
2905   if (auto *op = getDefiningOp())
2906     return op->print(os, flags);
2907   // TODO: Improve BlockArgument print'ing.
2908   BlockArgument arg = this->cast<BlockArgument>();
2909   os << "<block argument> of type '" << arg.getType()
2910      << "' at index: " << arg.getArgNumber();
2911 }
2912 void Value::print(raw_ostream &os, AsmState &state) {
2913   if (!impl) {
2914     os << "<<NULL VALUE>>";
2915     return;
2916   }
2917 
2918   if (auto *op = getDefiningOp())
2919     return op->print(os, state);
2920 
2921   // TODO: Improve BlockArgument print'ing.
2922   BlockArgument arg = this->cast<BlockArgument>();
2923   os << "<block argument> of type '" << arg.getType()
2924      << "' at index: " << arg.getArgNumber();
2925 }
2926 
2927 void Value::dump() {
2928   print(llvm::errs());
2929   llvm::errs() << "\n";
2930 }
2931 
2932 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2933   // TODO: This doesn't necessarily capture all potential cases.
2934   // Currently, region arguments can be shadowed when printing the main
2935   // operation. If the IR hasn't been printed, this will produce the old SSA
2936   // name and not the shadowed name.
2937   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2938                                                  os);
2939 }
2940 
2941 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
2942   // If this is a top level operation, we also print aliases.
2943   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
2944     AsmState state(this, printerFlags);
2945     state.getImpl().initializeAliases(this);
2946     print(os, state);
2947     return;
2948   }
2949 
2950   // Find the operation to number from based upon the provided flags.
2951   Operation *op = this;
2952   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
2953   do {
2954     // If we are printing local scope, stop at the first operation that is
2955     // isolated from above.
2956     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2957       break;
2958 
2959     // Otherwise, traverse up to the next parent.
2960     Operation *parentOp = op->getParentOp();
2961     if (!parentOp)
2962       break;
2963     op = parentOp;
2964   } while (true);
2965 
2966   AsmState state(op, printerFlags);
2967   print(os, state);
2968 }
2969 void Operation::print(raw_ostream &os, AsmState &state) {
2970   OperationPrinter printer(os, state.getImpl());
2971   if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
2972     printer.printTopLevelOperation(this);
2973   else
2974     printer.print(this);
2975 }
2976 
2977 void Operation::dump() {
2978   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2979   llvm::errs() << "\n";
2980 }
2981 
2982 void Block::print(raw_ostream &os) {
2983   Operation *parentOp = getParentOp();
2984   if (!parentOp) {
2985     os << "<<UNLINKED BLOCK>>\n";
2986     return;
2987   }
2988   // Get the top-level op.
2989   while (auto *nextOp = parentOp->getParentOp())
2990     parentOp = nextOp;
2991 
2992   AsmState state(parentOp);
2993   print(os, state);
2994 }
2995 void Block::print(raw_ostream &os, AsmState &state) {
2996   OperationPrinter(os, state.getImpl()).print(this);
2997 }
2998 
2999 void Block::dump() { print(llvm::errs()); }
3000 
3001 /// Print out the name of the block without printing its body.
3002 void Block::printAsOperand(raw_ostream &os, bool printType) {
3003   Operation *parentOp = getParentOp();
3004   if (!parentOp) {
3005     os << "<<UNLINKED BLOCK>>\n";
3006     return;
3007   }
3008   AsmState state(parentOp);
3009   printAsOperand(os, state);
3010 }
3011 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3012   OperationPrinter printer(os, state.getImpl());
3013   printer.printBlockName(this);
3014 }
3015