1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/ScopeExit.h"
33 #include "llvm/ADT/ScopedHashTable.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/SmallString.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Endian.h"
41 #include "llvm/Support/Regex.h"
42 #include "llvm/Support/SaveAndRestore.h"
43 
44 #include <tuple>
45 
46 using namespace mlir;
47 using namespace mlir::detail;
48 
49 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
50 
51 void OperationName::dump() const { print(llvm::errs()); }
52 
53 //===--------------------------------------------------------------------===//
54 // AsmParser
55 //===--------------------------------------------------------------------===//
56 
// Destructors of the parser interfaces. They are defined out-of-line in this
// file and use the compiler-generated behavior.
AsmParser::~AsmParser() = default;
DialectAsmParser::~DialectAsmParser() = default;
OpAsmParser::~OpAsmParser() = default;
60 
61 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
62 
63 //===----------------------------------------------------------------------===//
64 // DialectAsmPrinter
65 //===----------------------------------------------------------------------===//
66 
// Out-of-line destructor definition; uses the compiler-generated behavior.
DialectAsmPrinter::~DialectAsmPrinter() = default;
68 
69 //===----------------------------------------------------------------------===//
70 // OpAsmPrinter
71 //===----------------------------------------------------------------------===//
72 
// Out-of-line destructor definition; uses the compiler-generated behavior.
OpAsmPrinter::~OpAsmPrinter() = default;
74 
75 void OpAsmPrinter::printFunctionalType(Operation *op) {
76   auto &os = getStream();
77   os << '(';
78   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
79     // Print the types of null values as <<NULL TYPE>>.
80     *this << (operand ? operand.getType() : Type());
81   });
82   os << ") -> ";
83 
84   // Print the result list.  We don't parenthesize single result types unless
85   // it is a function (avoiding a grammar ambiguity).
86   bool wrapped = op->getNumResults() != 1;
87   if (!wrapped && op->getResult(0).getType() &&
88       op->getResult(0).getType().isa<FunctionType>())
89     wrapped = true;
90 
91   if (wrapped)
92     os << '(';
93 
94   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
95     // Print the types of null values as <<NULL TYPE>>.
96     *this << (result ? result.getType() : Type());
97   });
98 
99   if (wrapped)
100     os << ')';
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // Operation OpAsm interface.
105 //===----------------------------------------------------------------------===//
106 
107 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
108 #include "mlir/IR/OpAsmInterface.cpp.inc"
109 
110 //===----------------------------------------------------------------------===//
111 // OpPrintingFlags
112 //===----------------------------------------------------------------------===//
113 
114 namespace {
115 /// This struct contains command line options that can be used to initialize
116 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
117 /// for global command line options.
118 struct AsmPrinterOptions {
119   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
120       "mlir-print-elementsattrs-with-hex-if-larger",
121       llvm::cl::desc(
122           "Print DenseElementsAttrs with a hex string that have "
123           "more elements than the given upper limit (use -1 to disable)")};
124 
125   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
126       "mlir-elide-elementsattrs-if-larger",
127       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
128                      "more elements than the given upper limit")};
129 
130   llvm::cl::opt<bool> printDebugInfoOpt{
131       "mlir-print-debuginfo", llvm::cl::init(false),
132       llvm::cl::desc("Print debug info in MLIR output")};
133 
134   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
135       "mlir-pretty-debuginfo", llvm::cl::init(false),
136       llvm::cl::desc("Print pretty debug info in MLIR output")};
137 
138   // Use the generic op output form in the operation printer even if the custom
139   // form is defined.
140   llvm::cl::opt<bool> printGenericOpFormOpt{
141       "mlir-print-op-generic", llvm::cl::init(false),
142       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
143 
144   llvm::cl::opt<bool> printLocalScopeOpt{
145       "mlir-print-local-scope", llvm::cl::init(false),
146       llvm::cl::desc("Print with local scope and inline information (eliding "
147                      "aliases for attributes, types, and locations")};
148 };
149 } // namespace
150 
/// Holder for the AsmPrinter command line options. Readers in this file check
/// `isConstructed()` before using the options.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
152 
153 /// Register a set of useful command-line options that can be used to configure
154 /// various flags within the AsmPrinter.
155 void mlir::registerAsmPrinterCLOptions() {
156   // Make sure that the options struct has been initialized.
157   *clOptions;
158 }
159 
160 /// Initialize the printing flags with default supplied by the cl::opts above.
161 OpPrintingFlags::OpPrintingFlags()
162     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
163       printGenericOpFormFlag(false), printLocalScope(false) {
164   // Initialize based upon command line options, if they are available.
165   if (!clOptions.isConstructed())
166     return;
167   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
168     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
169   printDebugInfoFlag = clOptions->printDebugInfoOpt;
170   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
171   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
172   printLocalScope = clOptions->printLocalScopeOpt;
173 }
174 
/// Enable the elision of large elements attributes, by printing a '...'
/// instead of the element data, when the number of elements is greater than
/// `largeElementLimit`. Note: The IR generated with this option is not
/// parsable.
/// Returns a reference to this object so that calls can be chained.
OpPrintingFlags &
OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
  elementsAttrElementLimit = largeElementLimit;
  return *this;
}
184 
185 /// Enable printing of debug information. If 'prettyForm' is set to true,
186 /// debug information is printed in a more readable 'pretty' form.
187 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
188   printDebugInfoFlag = true;
189   printDebugInfoPrettyFormFlag = prettyForm;
190   return *this;
191 }
192 
/// Always print operations in the generic form.
/// Returns a reference to this object so that calls can be chained.
OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
  printGenericOpFormFlag = true;
  return *this;
}
198 
/// Use local scope when printing the operation. This allows for using the
/// printer in a more localized and thread-safe setting, but may not necessarily
/// be identical to what the IR will look like when dumping the full module.
/// Returns a reference to this object so that calls can be chained.
OpPrintingFlags &OpPrintingFlags::useLocalScope() {
  printLocalScope = true;
  return *this;
}
206 
207 /// Return if the given ElementsAttr should be elided.
208 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
209   return elementsAttrElementLimit.hasValue() &&
210          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
211          !attr.isa<SplatElementsAttr>();
212 }
213 
/// Return the size limit for printing large ElementsAttr. The result has no
/// value when no limit has been set.
Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
  return elementsAttrElementLimit;
}
218 
/// Return true if debug information should be printed.
bool OpPrintingFlags::shouldPrintDebugInfo() const {
  return printDebugInfoFlag;
}
223 
/// Return true if debug information should be printed in the pretty form.
bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
  return printDebugInfoPrettyFormFlag;
}
228 
/// Return true if operations should be printed in the generic form.
bool OpPrintingFlags::shouldPrintGenericOpForm() const {
  return printGenericOpFormFlag;
}
233 
/// Return true if the printer should use local scope when dumping the IR.
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
236 
237 /// Returns true if an ElementsAttr with the given number of elements should be
238 /// printed with hex.
239 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
240   // Check to see if a command line option was provided for the limit.
241   if (clOptions.isConstructed()) {
242     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
243       // -1 is used to disable hex printing.
244       if (clOptions->printElementsAttrWithHexIfLarger == -1)
245         return false;
246       return numElements > clOptions->printElementsAttrWithHexIfLarger;
247     }
248   }
249 
250   // Otherwise, default to printing with hex if the number of elements is >100.
251   return numElements > 100;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // NewLineCounter
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class is a simple formatter that emits a new line when inputted into a
260 /// stream, that enables counting the number of newlines emitted. This class
261 /// should be used whenever emitting newlines in the printer.
262 struct NewLineCounter {
263   unsigned curLine = 1;
264 };
265 
266 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
267   ++newLine.curLine;
268   return os << '\n';
269 }
270 } // namespace
271 
272 //===----------------------------------------------------------------------===//
273 // AliasInitializer
274 //===----------------------------------------------------------------------===//
275 
276 namespace {
277 /// This class represents a specific instance of a symbol Alias.
278 class SymbolAlias {
279 public:
280   SymbolAlias(StringRef name, bool isDeferrable)
281       : name(name), suffixIndex(0), hasSuffixIndex(false),
282         isDeferrable(isDeferrable) {}
283   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
284       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
285         isDeferrable(isDeferrable) {}
286 
287   /// Print this alias to the given stream.
288   void print(raw_ostream &os) const {
289     os << name;
290     if (hasSuffixIndex)
291       os << suffixIndex;
292   }
293 
294   /// Returns true if this alias supports deferred resolution when parsing.
295   bool canBeDeferred() const { return isDeferrable; }
296 
297 private:
298   /// The main name of the alias.
299   StringRef name;
300   /// The optional suffix index of the alias, if multiple aliases had the same
301   /// name.
302   uint32_t suffixIndex : 30;
303   /// A flag indicating whether this alias has a suffix or not.
304   bool hasSuffixIndex : 1;
305   /// A flag indicating whether this alias may be deferred or not.
306   bool isDeferrable : 1;
307 };
308 
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// Construct an initializer that queries `interfaces` for alias names and
  /// copies the chosen names into `aliasAllocator`. Both are held by
  /// reference.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Collect the attributes and types that are referenced when printing `op`
  /// with `printerFlags`, then populate `attrToAlias` and `typeToAlias` with
  /// the aliases generated for them.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias. `aliasOS` writes into
  /// `aliasBuffer`, so the buffer has to stay declared before the stream.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
364 
365 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
366 /// and merely collects the attributes and types that *would* be printed in a
367 /// normal print invocation so that we can generate proper aliases. This allows
368 /// for us to generate aliases only for the attributes and types that would be
369 /// in the output, and trims down unnecessary output.
370 class DummyAliasOperationPrinter : private OpAsmPrinter {
371 public:
372   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
373                                       AliasInitializer &initializer)
374       : printerFlags(printerFlags), initializer(initializer) {}
375 
376   /// Print the given operation.
377   void print(Operation *op) {
378     // Visit the operation location.
379     if (printerFlags.shouldPrintDebugInfo())
380       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
381 
382     // If requested, always print the generic form.
383     if (!printerFlags.shouldPrintGenericOpForm()) {
384       // Check to see if this is a known operation.  If so, use the registered
385       // custom printer hook.
386       if (auto opInfo = op->getRegisteredInfo()) {
387         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
388         return;
389       }
390     }
391 
392     // Otherwise print with the generic assembly form.
393     printGenericOp(op);
394   }
395 
396 private:
397   /// Print the given operation in the generic form.
398   void printGenericOp(Operation *op, bool printOpName = true) override {
399     // Consider nested operations for aliases.
400     if (op->getNumRegions() != 0) {
401       for (Region &region : op->getRegions())
402         printRegion(region, /*printEntryBlockArgs=*/true,
403                     /*printBlockTerminators=*/true);
404     }
405 
406     // Visit all the types used in the operation.
407     for (Type type : op->getOperandTypes())
408       printType(type);
409     for (Type type : op->getResultTypes())
410       printType(type);
411 
412     // Consider the attributes of the operation for aliases.
413     for (const NamedAttribute &attr : op->getAttrs())
414       printAttribute(attr.getValue());
415   }
416 
417   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
418   /// block are not printed. If 'printBlockTerminator' is false, the terminator
419   /// operation of the block is not printed.
420   void print(Block *block, bool printBlockArgs = true,
421              bool printBlockTerminator = true) {
422     // Consider the types of the block arguments for aliases if 'printBlockArgs'
423     // is set to true.
424     if (printBlockArgs) {
425       for (BlockArgument arg : block->getArguments()) {
426         printType(arg.getType());
427 
428         // Visit the argument location.
429         if (printerFlags.shouldPrintDebugInfo())
430           // TODO: Allow deferring argument locations.
431           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
432       }
433     }
434 
435     // Consider the operations within this block, ignoring the terminator if
436     // requested.
437     bool hasTerminator =
438         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
439     auto range = llvm::make_range(
440         block->begin(),
441         std::prev(block->end(),
442                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
443     for (Operation &op : range)
444       print(&op);
445   }
446 
447   /// Print the given region.
448   void printRegion(Region &region, bool printEntryBlockArgs,
449                    bool printBlockTerminators,
450                    bool printEmptyBlock = false) override {
451     if (region.empty())
452       return;
453 
454     auto *entryBlock = &region.front();
455     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
456     for (Block &b : llvm::drop_begin(region, 1))
457       print(&b);
458   }
459 
460   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
461                            bool omitType) override {
462     printType(arg.getType());
463     // Visit the argument location.
464     if (printerFlags.shouldPrintDebugInfo())
465       // TODO: Allow deferring argument locations.
466       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
467   }
468 
469   /// Consider the given type to be printed for an alias.
470   void printType(Type type) override { initializer.visit(type); }
471 
472   /// Consider the given attribute to be printed for an alias.
473   void printAttribute(Attribute attr) override { initializer.visit(attr); }
474   void printAttributeWithoutType(Attribute attr) override {
475     printAttribute(attr);
476   }
477   LogicalResult printAlias(Attribute attr) override {
478     initializer.visit(attr);
479     return success();
480   }
481   LogicalResult printAlias(Type type) override {
482     initializer.visit(type);
483     return success();
484   }
485 
486   /// Print the given set of attributes with names not included within
487   /// 'elidedAttrs'.
488   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
489                              ArrayRef<StringRef> elidedAttrs = {}) override {
490     if (attrs.empty())
491       return;
492     if (elidedAttrs.empty()) {
493       for (const NamedAttribute &attr : attrs)
494         printAttribute(attr.getValue());
495       return;
496     }
497     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
498                                                   elidedAttrs.end());
499     for (const NamedAttribute &attr : attrs)
500       if (!elidedAttrsSet.contains(attr.getName().strref()))
501         printAttribute(attr.getValue());
502   }
503   void printOptionalAttrDictWithKeyword(
504       ArrayRef<NamedAttribute> attrs,
505       ArrayRef<StringRef> elidedAttrs = {}) override {
506     printOptionalAttrDict(attrs, elidedAttrs);
507   }
508 
509   /// Return a null stream as the output stream, this will ignore any data fed
510   /// to it.
511   raw_ostream &getStream() const override { return os; }
512 
513   /// The following are hooks of `OpAsmPrinter` that are not necessary for
514   /// determining potential aliases.
515   void printFloat(const APFloat &value) override {}
516   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
517   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
518   void printNewline() override {}
519   void printOperand(Value) override {}
520   void printOperand(Value, raw_ostream &os) override {
521     // Users expect the output string to have at least the prefixed % to signal
522     // a value name. To maintain this invariant, emit a name even if it is
523     // guaranteed to go unused.
524     os << "%";
525   }
526   void printKeywordOrString(StringRef) override {}
527   void printSymbolName(StringRef) override {}
528   void printSuccessor(Block *) override {}
529   void printSuccessorAndUseList(Block *, ValueRange) override {}
530   void shadowRegionArgs(Region &, ValueRange) override {}
531 
532   /// The printer flags to use when determining potential aliases.
533   const OpPrintingFlags &printerFlags;
534 
535   /// The initializer to use when identifying aliases.
536   AliasInitializer &initializer;
537 
538   /// A dummy output stream.
539   mutable llvm::raw_null_ostream os;
540 };
541 } // namespace
542 
543 /// Sanitize the given name such that it can be used as a valid identifier. If
544 /// the string needs to be modified in any way, the provided buffer is used to
545 /// store the new copy,
546 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
547                                     StringRef allowedPunctChars = "$._-",
548                                     bool allowTrailingDigit = true) {
549   assert(!name.empty() && "Shouldn't have an empty name here");
550 
551   auto copyNameToBuffer = [&] {
552     for (char ch : name) {
553       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
554         buffer.push_back(ch);
555       else if (ch == ' ')
556         buffer.push_back('_');
557       else
558         buffer.append(llvm::utohexstr((unsigned char)ch));
559     }
560   };
561 
562   // Check to see if this name is valid. If it starts with a digit, then it
563   // could conflict with the autogenerated numeric ID's, so add an underscore
564   // prefix to avoid problems.
565   if (isdigit(name[0])) {
566     buffer.push_back('_');
567     copyNameToBuffer();
568     return buffer;
569   }
570 
571   // If the name ends with a trailing digit, add a '_' to avoid potential
572   // conflicts with autogenerated ID's.
573   if (!allowTrailingDigit && isdigit(name.back())) {
574     copyNameToBuffer();
575     buffer.push_back('_');
576     return buffer;
577   }
578 
579   // Check to see that the name consists of only valid identifier characters.
580   for (char ch : name) {
581     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
582       copyNameToBuffer();
583       return buffer;
584     }
585   }
586 
587   // If there are no invalid characters, return the original name.
588   return name;
589 }
590 
591 /// Given a collection of aliases and symbols, initialize a mapping from a
592 /// symbol to a given alias.
593 template <typename T>
594 static void
595 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
596                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
597                   DenseSet<T> *deferrableAliases = nullptr) {
598   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
599       aliasToSymbol.takeVector();
600   llvm::array_pod_sort(aliases.begin(), aliases.end(),
601                        [](const auto *lhs, const auto *rhs) {
602                          return lhs->first.compare(rhs->first);
603                        });
604 
605   for (auto &it : aliases) {
606     // If there is only one instance for this alias, use the name directly.
607     if (it.second.size() == 1) {
608       T symbol = it.second.front();
609       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
610       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
611       continue;
612     }
613     // Otherwise, add the index to the name.
614     for (int i = 0, e = it.second.size(); i < e; ++i) {
615       T symbol = it.second[i];
616       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
617       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
618     }
619   }
620 }
621 
622 void AliasInitializer::initialize(
623     Operation *op, const OpPrintingFlags &printerFlags,
624     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
625     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
626   // Use a dummy printer when walking the IR so that we can collect the
627   // attributes/types that will actually be used during printing when
628   // considering aliases.
629   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
630   aliasPrinter.print(op);
631 
632   // Initialize the aliases sorted by name.
633   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
634   initializeAliases(aliasToType, typeToAlias);
635 }
636 
637 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
638   if (!visitedAttributes.insert(attr).second) {
639     // If this attribute already has an alias and this instance can't be
640     // deferred, make sure that the alias isn't deferred.
641     if (!canBeDeferred)
642       deferrableAttributes.erase(attr);
643     return;
644   }
645 
646   // Try to generate an alias for this attribute.
647   if (succeeded(generateAlias(attr, aliasToAttr))) {
648     if (canBeDeferred)
649       deferrableAttributes.insert(attr);
650     return;
651   }
652 
653   // Check for any sub elements.
654   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
655     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
656                                         [&](Type type) { visit(type); });
657   }
658 }
659 
660 void AliasInitializer::visit(Type type) {
661   if (!visitedTypes.insert(type).second)
662     return;
663 
664   // Try to generate an alias for this type.
665   if (succeeded(generateAlias(type, aliasToType)))
666     return;
667 
668   // Check for any sub elements.
669   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
670     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
671                                         [&](Type type) { visit(type); });
672   }
673 }
674 
675 template <typename T>
676 LogicalResult AliasInitializer::generateAlias(
677     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
678   SmallString<32> nameBuffer;
679   for (const auto &interface : interfaces) {
680     OpAsmDialectInterface::AliasResult result =
681         interface.getAlias(symbol, aliasOS);
682     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
683       continue;
684     nameBuffer = std::move(aliasBuffer);
685     assert(!nameBuffer.empty() && "expected valid alias name");
686     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
687       break;
688   }
689 
690   if (nameBuffer.empty())
691     return failure();
692 
693   SmallString<16> tempBuffer;
694   StringRef name =
695       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
696                          /*allowTrailingDigit=*/false);
697   name = name.copy(aliasAllocator);
698   aliasToSymbol[name].push_back(symbol);
699   return success();
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // AliasState
704 //===----------------------------------------------------------------------===//
705 
706 namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
  /// Initialize the internal aliases for everything referenced when printing
  /// `op` with `printerFlags`, using `interfaces` to produce the alias names.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias.
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names. It is handed to the AliasInitializer
  /// created in `initialize`.
  llvm::BumpPtrAllocator aliasAllocator;
};
749 
750 void AliasState::initialize(
751     Operation *op, const OpPrintingFlags &printerFlags,
752     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
753   AliasInitializer initializer(interfaces, aliasAllocator);
754   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
755 }
756 
757 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
758   auto it = attrToAlias.find(attr);
759   if (it == attrToAlias.end())
760     return failure();
761   it->second.print(os << '#');
762   return success();
763 }
764 
765 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
766   auto it = typeToAlias.find(ty);
767   if (it == typeToAlias.end())
768     return failure();
769 
770   it->second.print(os << '!');
771   return success();
772 }
773 
774 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
775                               bool isDeferred) const {
776   auto filterFn = [=](const auto &aliasIt) {
777     return aliasIt.second.canBeDeferred() == isDeferred;
778   };
779   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
780     it.second.print(os << '#');
781     os << " = " << it.first << newLine;
782   }
783   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
784     it.second.print(os << '!');
785     os << " = type " << it.first << newLine;
786   }
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // SSANameState
791 //===----------------------------------------------------------------------===//
792 
793 namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print a reference to it, e.g. ^bb42.
struct BlockInfo {
  /// Position of the block in the visitation order of its region. -1 is used
  /// for blocks that have not been assigned an ordering (or are unknown).
  int ordering;
  /// The name printed when referring to the block. For blocks numbered by
  /// SSANameState this includes the leading '^'.
  StringRef name;
};
800 
/// This class manages the state of SSA value names.
class SSANameState {
public:
  /// A sentinel value used for values with names set: a value whose ID equals
  /// this sentinel has its textual name stored in `valueNames` instead.
  enum : unsigned { NameSentinel = ~0U };

  /// Number the values and blocks of `op` and of all regions nested within it.
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value. Null values and values that were never numbered are
  /// printed as placeholder markers.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block. If the block is unknown, a placeholder
  /// with ordering -1 and the name "INVALIDBLOCK" is returned.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name selects the
  /// default numeric naming instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  /// The textual names of the values whose ID is NameSentinel. The strings
  /// are copied into `usedNameAllocator`.
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Storage for the value and block name strings handed out by this class.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags.  They control, eg., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
877 } // namespace
878 
/// Compute the SSA value and block names for `op` and for every region nested
/// beneath it. Regions are processed from an explicit worklist; each worklist
/// entry records the ID counters and the name scope that were current when the
/// region was queued, so numbering of that region resumes from that state.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are overwritten while walking the worklist below; put them
  // back to their initial values once construction finishes.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'ed into this
  // allocator and torn down by calling their destructor explicitly.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Seed the worklist with the regions of the top level operation. They are
  // queued before `op` itself is numbered, so they capture the counters as
  // they were on entry.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Restore the counters that were recorded when this region was queued.
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes that are no
    // longer needed until the parent scope of this region is the current one.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions nested directly under the operations of this region,
    // recording the counters and scope reached after numbering this region.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
938 
939 void SSANameState::printValueID(Value value, bool printResultNo,
940                                 raw_ostream &stream) const {
941   if (!value) {
942     stream << "<<NULL VALUE>>";
943     return;
944   }
945 
946   Optional<int> resultNo;
947   auto lookupValue = value;
948 
949   // If this is an operation result, collect the head lookup value of the result
950   // group and the result number of 'result' within that group.
951   if (OpResult result = value.dyn_cast<OpResult>())
952     getResultIDAndNumber(result, lookupValue, resultNo);
953 
954   auto it = valueIDs.find(lookupValue);
955   if (it == valueIDs.end()) {
956     stream << "<<UNKNOWN SSA VALUE>>";
957     return;
958   }
959 
960   stream << '%';
961   if (it->second != NameSentinel) {
962     stream << it->second;
963   } else {
964     auto nameIt = valueNames.find(lookupValue);
965     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
966     stream << nameIt->second;
967   }
968 
969   if (resultNo.hasValue() && printResultNo)
970     stream << '#' << resultNo;
971 }
972 
973 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
974   auto it = opResultGroups.find(op);
975   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
976 }
977 
978 BlockInfo SSANameState::getBlockInfo(Block *block) {
979   auto it = blockNames.find(block);
980   BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
981   return it != blockNames.end() ? it->second : invalidBlock;
982 }
983 
984 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
985   assert(!region.empty() && "cannot shadow arguments of an empty region");
986   assert(region.getNumArguments() == namesToUse.size() &&
987          "incorrect number of names passed in");
988   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
989          "only KnownIsolatedFromAbove ops can shadow names");
990 
991   SmallVector<char, 16> nameStr;
992   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
993     auto nameToUse = namesToUse[i];
994     if (nameToUse == nullptr)
995       continue;
996     auto nameToReplace = region.getArgument(i);
997 
998     nameStr.clear();
999     llvm::raw_svector_ostream nameStream(nameStr);
1000     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1001 
1002     // Entry block arguments should already have a pretty "arg" name.
1003     assert(valueIDs[nameToReplace] == NameSentinel);
1004 
1005     // Use the name without the leading %.
1006     auto name = StringRef(nameStream.str()).drop_front();
1007 
1008     // Overwrite the name.
1009     valueNames[nameToReplace] = name.copy(usedNameAllocator);
1010   }
1011 }
1012 
1013 void SSANameState::numberValuesInRegion(Region &region) {
1014   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1015     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1016     assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
1017            "arg not defined in current region");
1018     setValueName(arg, name);
1019   };
1020 
1021   if (!printerFlags.shouldPrintGenericOpForm()) {
1022     if (Operation *op = region.getParentOp()) {
1023       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1024         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1025     }
1026   }
1027 
1028   // Number the values within this region in a breadth-first order.
1029   unsigned nextBlockID = 0;
1030   for (auto &block : region) {
1031     // Each block gets a unique ID, and all of the operations within it get
1032     // numbered as well.
1033     auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1034     if (blockInfoIt.second) {
1035       // This block hasn't been named through `getAsmBlockArgumentNames`, use
1036       // default `^bbNNN` format.
1037       std::string name;
1038       llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1039       blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1040     }
1041     blockInfoIt.first->second.ordering = nextBlockID++;
1042 
1043     numberValuesInBlock(block);
1044   }
1045 }
1046 
1047 void SSANameState::numberValuesInBlock(Block &block) {
1048   // Number the block arguments. We give entry block arguments a special name
1049   // 'arg'.
1050   bool isEntryBlock = block.isEntryBlock();
1051   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1052   llvm::raw_svector_ostream specialName(specialNameBuffer);
1053   for (auto arg : block.getArguments()) {
1054     if (valueIDs.count(arg))
1055       continue;
1056     if (isEntryBlock) {
1057       specialNameBuffer.resize(strlen("arg"));
1058       specialName << nextArgumentID++;
1059     }
1060     setValueName(arg, specialName.str());
1061   }
1062 
1063   // Number the operations in this block.
1064   for (auto &op : block)
1065     numberValuesInOp(op);
1066 }
1067 
1068 void SSANameState::numberValuesInOp(Operation &op) {
1069   // Function used to set the special result names for the operation.
1070   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1071   auto setResultNameFn = [&](Value result, StringRef name) {
1072     assert(!valueIDs.count(result) && "result numbered multiple times");
1073     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1074     setValueName(result, name);
1075 
1076     // Record the result number for groups not anchored at 0.
1077     if (int resultNo = result.cast<OpResult>().getResultNumber())
1078       resultGroups.push_back(resultNo);
1079   };
1080   // Operations can customize the printing of block names in OpAsmOpInterface.
1081   auto setBlockNameFn = [&](Block *block, StringRef name) {
1082     assert(block->getParentOp() == &op &&
1083            "getAsmBlockArgumentNames callback invoked on a block not directly "
1084            "nested under the current operation");
1085     assert(!blockNames.count(block) && "block numbered multiple times");
1086     SmallString<16> tmpBuffer{"^"};
1087     name = sanitizeIdentifier(name, tmpBuffer);
1088     if (name.data() != tmpBuffer.data()) {
1089       tmpBuffer.append(name);
1090       name = tmpBuffer.str();
1091     }
1092     name = name.copy(usedNameAllocator);
1093     blockNames[block] = {-1, name};
1094   };
1095 
1096   if (!printerFlags.shouldPrintGenericOpForm()) {
1097     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1098       asmInterface.getAsmBlockNames(setBlockNameFn);
1099       asmInterface.getAsmResultNames(setResultNameFn);
1100     }
1101   }
1102 
1103   unsigned numResults = op.getNumResults();
1104   if (numResults == 0)
1105     return;
1106   Value resultBegin = op.getResult(0);
1107 
1108   // If the first result wasn't numbered, give it a default number.
1109   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1110     ++nextValueID;
1111 
1112   // If this operation has multiple result groups, mark it.
1113   if (resultGroups.size() != 1) {
1114     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1115     opResultGroups.try_emplace(&op, std::move(resultGroups));
1116   }
1117 }
1118 
1119 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1120                                         Optional<int> &lookupResultNo) const {
1121   Operation *owner = result.getOwner();
1122   if (owner->getNumResults() == 1)
1123     return;
1124   int resultNo = result.getResultNumber();
1125 
1126   // If this operation has multiple result groups, we will need to find the
1127   // one corresponding to this result.
1128   auto resultGroupIt = opResultGroups.find(owner);
1129   if (resultGroupIt == opResultGroups.end()) {
1130     // If not, just use the first result.
1131     lookupResultNo = resultNo;
1132     lookupValue = owner->getResult(0);
1133     return;
1134   }
1135 
1136   // Find the correct index using a binary search, as the groups are ordered.
1137   ArrayRef<int> resultGroups = resultGroupIt->second;
1138   const auto *it = llvm::upper_bound(resultGroups, resultNo);
1139   int groupResultNo = 0, groupSize = 0;
1140 
1141   // If there are no smaller elements, the last result group is the lookup.
1142   if (it == resultGroups.end()) {
1143     groupResultNo = resultGroups.back();
1144     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1145   } else {
1146     // Otherwise, the previous element is the lookup.
1147     groupResultNo = *std::prev(it);
1148     groupSize = *it - groupResultNo;
1149   }
1150 
1151   // We only record the result number for a group of size greater than 1.
1152   if (groupSize != 1)
1153     lookupResultNo = resultNo - groupResultNo;
1154   lookupValue = owner->getResult(groupResultNo);
1155 }
1156 
1157 void SSANameState::setValueName(Value value, StringRef name) {
1158   // If the name is empty, the value uses the default numbering.
1159   if (name.empty()) {
1160     valueIDs[value] = nextValueID++;
1161     return;
1162   }
1163 
1164   valueIDs[value] = NameSentinel;
1165   valueNames[value] = uniqueValueName(name);
1166 }
1167 
1168 StringRef SSANameState::uniqueValueName(StringRef name) {
1169   SmallString<16> tmpBuffer;
1170   name = sanitizeIdentifier(name, tmpBuffer);
1171 
1172   // Check to see if this name is already unique.
1173   if (!usedNames.count(name)) {
1174     name = name.copy(usedNameAllocator);
1175   } else {
1176     // Otherwise, we had a conflict - probe until we find a unique name. This
1177     // is guaranteed to terminate (and usually in a single iteration) because it
1178     // generates new names by incrementing nextConflictID.
1179     SmallString<64> probeName(name);
1180     probeName.push_back('_');
1181     while (true) {
1182       probeName += llvm::utostr(nextConflictID++);
1183       if (!usedNames.count(probeName)) {
1184         name = probeName.str().copy(usedNameAllocator);
1185         break;
1186       }
1187       probeName.resize(name.size() + 1);
1188     }
1189   }
1190 
1191   usedNames.insert(name, char());
1192   return name;
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // AsmState
1197 //===----------------------------------------------------------------------===//
1198 
1199 namespace mlir {
1200 namespace detail {
/// The implementation behind AsmState: bundles the dialect asm interfaces, the
/// alias state, the SSA name state, the printing flags and an optional
/// location map.
class AsmStateImpl {
public:
  /// Build the state for `op`. The OpAsm dialect interfaces are collected from
  /// the context of `op` and the SSA names are computed immediately; the alias
  /// state is only populated once `initializeAliases` is called.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This does nothing when no location map was
  /// provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated.
  AsmState::LocationMap *locationMap;
};
1245 } // namespace detail
1246 } // namespace mlir
1247 
/// Create the printer state for `op`; all arguments are forwarded to the
/// AsmStateImpl that holds the actual state.
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap)
    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
// Defaulted here, in the file where AsmStateImpl is defined.
AsmState::~AsmState() = default;
1252 
/// Return the printing flags stored in the underlying implementation.
const OpPrintingFlags &AsmState::getPrinterFlags() const {
  return impl->getPrinterFlags();
}
1256 
1257 //===----------------------------------------------------------------------===//
1258 // AsmPrinter::Impl
1259 //===----------------------------------------------------------------------===//
1260 
1261 namespace mlir {
/// The implementation behind AsmPrinter: holds the output stream, the printing
/// flags and an optional AsmStateImpl, and declares the printing routines for
/// attributes, types, locations and affine structures.
class AsmPrinter::Impl {
public:
  /// Create a printer writing to `os`. `state` may be null, in which case no
  /// shared printer state (e.g. aliases) is available.
  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
       AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer that shares the stream, flags and state of `other`. The
  /// new-line counter is not copied; the new printer starts with its own.
  explicit Impl(Impl &other)
      : Impl(other.os, other.printerFlags, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Invoke `eachFn` on every element of `c`, with separators written to the
  /// printer's stream by `llvm::interleaveComma`.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided,
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type.
  void printType(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Printing routines for affine maps, affine expressions, affine constraints
  /// and integer sets.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  /// Print a space followed by the given location, but only if the printer
  /// flags request debug info.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print the body of the given location, i.e. without the enclosing
  /// `loc(...)`. `pretty` selects the pretty form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module.
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
1361 } // namespace mlir
1362 
1363 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1364   // Check to see if we are printing debug information.
1365   if (!printerFlags.shouldPrintDebugInfo())
1366     return;
1367 
1368   os << " ";
1369   printLocation(loc, /*allowAlias=*/allowAlias);
1370 }
1371 
1372 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1373   TypeSwitch<LocationAttr>(loc)
1374       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1375         printLocationInternal(loc.getFallbackLocation(), pretty);
1376       })
1377       .Case<UnknownLoc>([&](UnknownLoc loc) {
1378         if (pretty)
1379           os << "[unknown]";
1380         else
1381           os << "unknown";
1382       })
1383       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1384         if (pretty) {
1385           os << loc.getFilename().getValue();
1386         } else {
1387           os << "\"";
1388           printEscapedString(loc.getFilename(), os);
1389           os << "\"";
1390         }
1391         os << ':' << loc.getLine() << ':' << loc.getColumn();
1392       })
1393       .Case<NameLoc>([&](NameLoc loc) {
1394         os << '\"';
1395         printEscapedString(loc.getName(), os);
1396         os << '\"';
1397 
1398         // Print the child if it isn't unknown.
1399         auto childLoc = loc.getChildLoc();
1400         if (!childLoc.isa<UnknownLoc>()) {
1401           os << '(';
1402           printLocationInternal(childLoc, pretty);
1403           os << ')';
1404         }
1405       })
1406       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1407         Location caller = loc.getCaller();
1408         Location callee = loc.getCallee();
1409         if (!pretty)
1410           os << "callsite(";
1411         printLocationInternal(callee, pretty);
1412         if (pretty) {
1413           if (callee.isa<NameLoc>()) {
1414             if (caller.isa<FileLineColLoc>()) {
1415               os << " at ";
1416             } else {
1417               os << newLine << " at ";
1418             }
1419           } else {
1420             os << newLine << " at ";
1421           }
1422         } else {
1423           os << " at ";
1424         }
1425         printLocationInternal(caller, pretty);
1426         if (!pretty)
1427           os << ")";
1428       })
1429       .Case<FusedLoc>([&](FusedLoc loc) {
1430         if (!pretty)
1431           os << "fused";
1432         if (Attribute metadata = loc.getMetadata())
1433           os << '<' << metadata << '>';
1434         os << '[';
1435         interleave(
1436             loc.getLocations(),
1437             [&](Location loc) { printLocationInternal(loc, pretty); },
1438             [&]() { os << ", "; });
1439         os << ']';
1440       });
1441 }
1442 
1443 /// Print a floating point value in a way that the parser will be able to
1444 /// round-trip losslessly.
1445 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1446   // We would like to output the FP constant value in exponential notation,
1447   // but we cannot do this if doing so will lose precision.  Check here to
1448   // make sure that we only output it in exponential format if we can parse
1449   // the value back and get the same value.
1450   bool isInf = apValue.isInfinity();
1451   bool isNaN = apValue.isNaN();
1452   if (!isInf && !isNaN) {
1453     SmallString<128> strValue;
1454     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1455                      /*TruncateZero=*/false);
1456 
1457     // Check to make sure that the stringized number is not some string like
1458     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1459     // that the string matches the "[-+]?[0-9]" regex.
1460     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1461             ((strValue[0] == '-' || strValue[0] == '+') &&
1462              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1463            "[-+]?[0-9] regex does not match!");
1464 
1465     // Parse back the stringized version and check that the value is equal
1466     // (i.e., there is no precision loss).
1467     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1468       os << strValue;
1469       return;
1470     }
1471 
1472     // If it is not, use the default format of APFloat instead of the
1473     // exponential notation.
1474     strValue.clear();
1475     apValue.toString(strValue);
1476 
1477     // Make sure that we can parse the default form as a float.
1478     if (strValue.str().contains('.')) {
1479       os << strValue;
1480       return;
1481     }
1482   }
1483 
1484   // Print special values in hexadecimal format. The sign bit should be included
1485   // in the literal.
1486   SmallVector<char, 16> str;
1487   APInt apInt = apValue.bitcastToAPInt();
1488   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1489                  /*formatAsCLiteral=*/true);
1490   os << str;
1491 }
1492 
1493 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1494   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1495     return printLocationInternal(loc, /*pretty=*/true);
1496 
1497   os << "loc(";
1498   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1499     printLocationInternal(loc);
1500   os << ')';
1501 }
1502 
1503 /// Returns true if the given dialect symbol data is simple enough to print in
1504 /// the pretty form, i.e. without the enclosing "".
1505 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1506   // The name must start with an identifier.
1507   if (symName.empty() || !isalpha(symName.front()))
1508     return false;
1509 
1510   // Ignore all the characters that are valid in an identifier in the symbol
1511   // name.
1512   symName = symName.drop_while(
1513       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1514   if (symName.empty())
1515     return true;
1516 
1517   // If we got to an unexpected character, then it must be a <>.  Check those
1518   // recursively.
1519   if (symName.front() != '<' || symName.back() != '>')
1520     return false;
1521 
1522   SmallVector<char, 8> nestedPunctuation;
1523   do {
1524     // If we ran out of characters, then we had a punctuation mismatch.
1525     if (symName.empty())
1526       return false;
1527 
1528     auto c = symName.front();
1529     symName = symName.drop_front();
1530 
1531     switch (c) {
1532     // We never allow null characters. This is an EOF indicator for the lexer
1533     // which we could handle, but isn't important for any known dialect.
1534     case '\0':
1535       return false;
1536     case '<':
1537     case '[':
1538     case '(':
1539     case '{':
1540       nestedPunctuation.push_back(c);
1541       continue;
1542     case '-':
1543       // Treat `->` as a special token.
1544       if (!symName.empty() && symName.front() == '>') {
1545         symName = symName.drop_front();
1546         continue;
1547       }
1548       break;
1549     // Reject types with mismatched brackets.
1550     case '>':
1551       if (nestedPunctuation.pop_back_val() != '<')
1552         return false;
1553       break;
1554     case ']':
1555       if (nestedPunctuation.pop_back_val() != '[')
1556         return false;
1557       break;
1558     case ')':
1559       if (nestedPunctuation.pop_back_val() != '(')
1560         return false;
1561       break;
1562     case '}':
1563       if (nestedPunctuation.pop_back_val() != '{')
1564         return false;
1565       break;
1566     default:
1567       continue;
1568     }
1569 
1570     // We're done when the punctuation is fully matched.
1571   } while (!nestedPunctuation.empty());
1572 
1573   // If there were extra characters, then we failed.
1574   return symName.empty();
1575 }
1576 
1577 /// Print the given dialect symbol to the stream.
1578 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1579                                StringRef dialectName, StringRef symString) {
1580   os << symPrefix << dialectName;
1581 
1582   // If this symbol name is simple enough, print it directly in pretty form,
1583   // otherwise, we print it as an escaped string.
1584   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1585     os << '.' << symString;
1586     return;
1587   }
1588 
1589   os << "<\"";
1590   llvm::printEscapedString(symString, os);
1591   os << "\">";
1592 }
1593 
1594 /// Returns true if the given string can be represented as a bare identifier.
1595 static bool isBareIdentifier(StringRef name) {
1596   // By making this unsigned, the value passed in to isalnum will always be
1597   // in the range 0-255. This is important when building with MSVC because
1598   // its implementation will assert. This situation can arise when dealing
1599   // with UTF-8 multibyte characters.
1600   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1601     return false;
1602   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1603     return isalnum(c) || c == '_' || c == '$' || c == '.';
1604   });
1605 }
1606 
1607 /// Print the given string as a keyword, or a quoted and escaped string if it
1608 /// has any special or non-printable characters in it.
1609 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1610   // If it can be represented as a bare identifier, write it directly.
1611   if (isBareIdentifier(keyword)) {
1612     os << keyword;
1613     return;
1614   }
1615 
1616   // Otherwise, output the keyword wrapped in quotes with proper escaping.
1617   os << "\"";
1618   printEscapedString(keyword, os);
1619   os << '"';
1620 }
1621 
/// Print the given string as a symbol reference. A symbol reference is
/// represented as a string prefixed with '@'. The name following the '@' is
/// emitted through printKeywordOrString, so it is surrounded with ""'s and
/// escaped if it cannot be written as a bare identifier.
///
/// `symbolRef` must not be empty (checked by assertion).
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  assert(!symbolRef.empty() && "expected valid symbol reference");
  os << '@';
  printKeywordOrString(symbolRef, os);
}
1630 
1631 // Print out a valid ElementsAttr that is succinct and can represent any
1632 // potential shape/type, for use when eliding a large ElementsAttr.
1633 //
1634 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1635 // hopefully alert readers to the fact that this has been elided.
1636 //
1637 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1638 // accept the string "elided". The first string must be a registered dialect
1639 // name and the latter must be a hex constant.
1640 static void printElidedElementsAttr(raw_ostream &os) {
1641   os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
1642 }
1643 
1644 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
1645   return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
1646 }
1647 
1648 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
1649   return success(state && succeeded(state->getAliasState().getAlias(type, os)));
1650 }
1651 
/// Print `attr` to the output stream.
///
/// A null attribute prints as `<<NULL ATTRIBUTE>>`. If an alias can be printed
/// for the attribute, only the alias is emitted. Attributes that do not belong
/// to the builtin dialect are delegated to printDialectAttribute. Builtin
/// attributes are printed by the if/else chain below.
///
/// `typeElision` controls the trailing ` : <type>` suffix. Branches that
/// `return` early never print the suffix; branches that fall through reach the
/// common tail, which prints the type unless elision is `Must` or the type is
/// a NoneType.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  // Anything outside the builtin dialect is printed by its owning dialect.
  if (!isa<BuiltinDialect>(attr.getDialect()))
    return printDialectAttribute(attr);

  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    // Unit attributes never print a type.
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    // Strings are always quoted and escaped.
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Array elements are printed recursively and may elide their types.
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Root reference first, then each nested reference separated by `::`.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
         << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // The whole sparse attribute is elided if either its indices or its values
    // should be elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, the attribute prints as an empty `sparse<>`.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}
1793 
1794 /// Print the integer element of a DenseElementsAttr.
1795 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1796                                  bool isSigned) {
1797   if (value.getBitWidth() == 1)
1798     os << (value.getBoolValue() ? "true" : "false");
1799   else
1800     value.print(os, isSigned);
1801 }
1802 
/// Shared driver for printing the element list of a dense elements attribute.
///
/// `printEltFn(i)` is invoked to print the element at linear index `i`; this
/// function only emits the surrounding `[`, `]` and `, ` punctuation derived
/// from the shape of `type`. When `isSplat` is true, only element 0 is printed
/// and no brackets are emitted. When `type` has zero elements, nothing is
/// printed.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket that is currently closed. The post-increment in
    // the condition overshoots by one on exit, so the count is reset to 'rank'
    // right after the loop.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets are still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
1852 
1853 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1854                                               bool allowHex) {
1855   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1856     return printDenseStringElementsAttr(stringAttr);
1857 
1858   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1859                                 allowHex);
1860 }
1861 
/// Print the element list of a dense integer/float/complex elements attribute.
///
/// If `allowHex` is set, the attribute is not a splat, and
/// shouldPrintElementsAttrWithHex accepts the element count, the raw data is
/// printed as a quoted `"0x..."` hex string. Otherwise the elements are
/// printed one by one through printDenseElementsAttrImpl, formatted according
/// to the element type (complex, integer/index, or float).
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness in big-endian(BE) machines. `rawData` is BE in BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      os << '"' << "0x"
         << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
         << "\"";
    } else {
      // Not a big-endian host: the raw bytes are printed unchanged.
      os << '"' << "0x"
         << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    if (complexElementType.isa<IntegerType>()) {
      // Complex integers print as `(real,imag)`.
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      // Any other complex element type is read as a pair of APFloats and
      // printed as `(real,imag)`.
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Everything that is not explicitly unsigned is printed as signed.
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
1932 
1933 void AsmPrinter::Impl::printDenseStringElementsAttr(
1934     DenseStringElementsAttr attr) {
1935   ArrayRef<StringRef> data = attr.getRawStringData();
1936   auto printFn = [&](unsigned index) {
1937     os << "\"";
1938     printEscapedString(data[index], os);
1939     os << "\"";
1940   };
1941   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1942 }
1943 
/// Print `type` to the output stream.
///
/// A null type prints as `<<NULL TYPE>>`. If the alias state yields an alias
/// for the type, only the alias is emitted. Otherwise the type is dispatched
/// through the TypeSwitch below; any type not handled by an explicit case is
/// delegated to printDialectType.
void AsmPrinter::Impl::printType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Try to print an alias for this type.
  if (state && succeeded(state->getAliasState().getAlias(type, os)))
    return;

  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // Optional signedness prefix ('s' or 'u'), then 'i' and the width.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        ArrayRef<Type> results = funcTy.getResults();
        // A single result is printed without parentheses, unless that result
        // is itself a function type.
        if (results.size() == 1 && !results[0].isa<FunctionType>()) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
        unsigned dimIdx = 0;
        // Fixed dimensions come first, each followed by 'x'.
        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
          os << vShape[dimIdx] << 'x';
        // The remaining (scalable) dimensions are grouped inside '[' ... ']'.
        if (vectorTy.isScalable()) {
          os << '[';
          unsigned secondToLastDim = lastDim - 1;
          for (; dimIdx < secondToLastDim; dimIdx++)
            os << vShape[dimIdx] << 'x';
          os << vShape[dimIdx] << "]x";
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        // Dynamic dimensions print as '?'.
        for (int64_t dim : tensorTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        // Dynamic dimensions print as '?'.
        for (int64_t dim : memrefTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(memrefTy.getElementType());
        // Only print the layout if it is not the identity.
        if (!memrefTy.getLayout().isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Everything else is printed by printDialectType.
      .Default([&](Type type) { return printDialectType(type); });
}
2071 
2072 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2073                                              ArrayRef<StringRef> elidedAttrs,
2074                                              bool withKeyword) {
2075   // If there are no attributes, then there is nothing to be done.
2076   if (attrs.empty())
2077     return;
2078 
2079   // Functor used to print a filtered attribute list.
2080   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2081     // Print the 'attributes' keyword if necessary.
2082     if (withKeyword)
2083       os << " attributes";
2084 
2085     // Otherwise, print them all out in braces.
2086     os << " {";
2087     interleaveComma(filteredAttrs,
2088                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2089     os << '}';
2090   };
2091 
2092   // If no attributes are elided, we can directly print with no filtering.
2093   if (elidedAttrs.empty())
2094     return printFilteredAttributesFn(attrs);
2095 
2096   // Otherwise, filter out any attributes that shouldn't be included.
2097   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2098                                                 elidedAttrs.end());
2099   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2100     return !elidedAttrsSet.contains(attr.getName().strref());
2101   });
2102   if (!filteredAttrs.empty())
2103     printFilteredAttributesFn(filteredAttrs);
2104 }
2105 
2106 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2107   // Print the name without quotes if possible.
2108   ::printKeywordOrString(attr.getName().strref(), os);
2109 
2110   // Pretty printing elides the attribute value for unit attributes.
2111   if (attr.getValue().isa<UnitAttr>())
2112     return;
2113 
2114   os << " = ";
2115   printAttribute(attr.getValue());
2116 }
2117 
2118 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2119   auto &dialect = attr.getDialect();
2120 
2121   // Ask the dialect to serialize the attribute to a string.
2122   std::string attrName;
2123   {
2124     llvm::raw_string_ostream attrNameStr(attrName);
2125     Impl subPrinter(attrNameStr, printerFlags, state);
2126     DialectAsmPrinter printer(subPrinter);
2127     dialect.printAttribute(attr, printer);
2128   }
2129   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2130 }
2131 
2132 void AsmPrinter::Impl::printDialectType(Type type) {
2133   auto &dialect = type.getDialect();
2134 
2135   // Ask the dialect to serialize the type to a string.
2136   std::string typeName;
2137   {
2138     llvm::raw_string_ostream typeNameStr(typeName);
2139     Impl subPrinter(typeNameStr, printerFlags, state);
2140     DialectAsmPrinter printer(subPrinter);
2141     dialect.printType(type, printer);
2142   }
2143   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2144 }
2145 
2146 //===--------------------------------------------------------------------===//
2147 // AsmPrinter
2148 //===--------------------------------------------------------------------===//
2149 
// Out-of-line defaulted destructor for AsmPrinter.
AsmPrinter::~AsmPrinter() = default;
2151 
/// Return the output stream of the underlying `impl`. `impl` must be non-null
/// (checked by assertion); the assertion message indicates that printers
/// without an `impl` are expected to override this method.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  return impl->getStream();
}
2156 
/// Print the given floating point value in a stabilized form, by passing it
/// to printFloatValue together with the stream of the underlying `impl`.
/// `impl` must be non-null (checked by assertion).
void AsmPrinter::printFloat(const APFloat &value) {
  assert(impl && "expected AsmPrinter::printFloat to be overriden");
  printFloatValue(value, impl->getStream());
}
2162 
/// Print `type` by forwarding to the underlying `impl`. `impl` must be
/// non-null (checked by assertion).
void AsmPrinter::printType(Type type) {
  assert(impl && "expected AsmPrinter::printType to be overriden");
  impl->printType(type);
}
2167 
/// Print `attr` by forwarding to the underlying `impl`, without specifying a
/// type elision mode. `impl` must be non-null (checked by assertion).
void AsmPrinter::printAttribute(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
  impl->printAttribute(attr);
}
2172 
/// Try to print an alias for `attr` by forwarding to the underlying `impl`,
/// and return its result. `impl` must be non-null (checked by assertion).
LogicalResult AsmPrinter::printAlias(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(attr);
}
2177 
/// Try to print an alias for `type` by forwarding to the underlying `impl`,
/// and return its result. `impl` must be non-null (checked by assertion).
LogicalResult AsmPrinter::printAlias(Type type) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(type);
}
2182 
/// Print `attr` through the underlying `impl` with type elision forced
/// (`AttrTypeElision::Must`). `impl` must be non-null (checked by assertion).
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
  assert(impl &&
         "expected AsmPrinter::printAttributeWithoutType to be overriden");
  impl->printAttribute(attr, Impl::AttrTypeElision::Must);
}
2188 
/// Print `keyword` to the stream of the underlying `impl` using the file-local
/// printKeywordOrString helper. `impl` must be non-null (checked by
/// assertion).
void AsmPrinter::printKeywordOrString(StringRef keyword) {
  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
  ::printKeywordOrString(keyword, impl->getStream());
}
2193 
/// Print `symbolRef` as an '@'-prefixed symbol reference to the stream of the
/// underlying `impl`, using the file-local printSymbolReference helper (which
/// asserts that the name is non-empty). `impl` must be non-null (checked by
/// assertion).
void AsmPrinter::printSymbolName(StringRef symbolRef) {
  assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
  ::printSymbolReference(symbolRef, impl->getStream());
}
2198 
2199 //===----------------------------------------------------------------------===//
2200 // Affine expressions and maps
2201 //===----------------------------------------------------------------------===//
2202 
/// Print the affine expression `expr`. This is the top-level entry point: it
/// starts the recursive printer with weak enclosing binding strength, so the
/// outermost expression is not parenthesized. `printValueName` is forwarded
/// unchanged to the recursive printer.
void AsmPrinter::Impl::printAffineExpr(
    AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
2207 
/// Recursively print the affine expression `expr`.
///
/// `enclosingTightness` describes the context the expression is printed in:
/// when it is `BindingStrength::Strong`, every binary expression printed here
/// is wrapped in parentheses. If `printValueName` is non-null it is called as
/// `printValueName(position, isSymbol)` to print dimension and symbol
/// identifiers; otherwise they are printed as `d<position>` and `s<position>`.
///
/// Additions and multiplications involving negative constants are printed in
/// "pretty" form, as negation or subtraction.
void AsmPrinter::Impl::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  // Leaf expressions are printed and returned from directly inside the switch;
  // binary expressions only record their operator spelling and fall through
  // to the code below.
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    // General case: both operands are printed in a strong context.
    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        // `lhs + x * -1` prints as `lhs - x`.
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A subtracted sum is printed in a strong context so that it gets
          // parenthesized.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        // `lhs + x * -c` with `c > 1` prints as `lhs - x * c`.
        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition: both operands are printed in a weak context.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2340 
2341 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2342   printAffineExprInternal(expr, BindingStrength::Weak);
2343   isEq ? os << " == 0" : os << " >= 0";
2344 }
2345 
2346 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2347   // Dimension identifiers.
2348   os << '(';
2349   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2350     os << 'd' << i << ", ";
2351   if (map.getNumDims() >= 1)
2352     os << 'd' << map.getNumDims() - 1;
2353   os << ')';
2354 
2355   // Symbolic identifiers.
2356   if (map.getNumSymbols() != 0) {
2357     os << '[';
2358     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2359       os << 's' << i << ", ";
2360     if (map.getNumSymbols() >= 1)
2361       os << 's' << map.getNumSymbols() - 1;
2362     os << ']';
2363   }
2364 
2365   // Result affine expressions.
2366   os << " -> (";
2367   interleaveComma(map.getResults(),
2368                   [&](AffineExpr expr) { printAffineExpr(expr); });
2369   os << ')';
2370 }
2371 
2372 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2373   // Dimension identifiers.
2374   os << '(';
2375   for (unsigned i = 1; i < set.getNumDims(); ++i)
2376     os << 'd' << i - 1 << ", ";
2377   if (set.getNumDims() >= 1)
2378     os << 'd' << set.getNumDims() - 1;
2379   os << ')';
2380 
2381   // Symbolic identifiers.
2382   if (set.getNumSymbols() != 0) {
2383     os << '[';
2384     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2385       os << 's' << i << ", ";
2386     if (set.getNumSymbols() >= 1)
2387       os << 's' << set.getNumSymbols() - 1;
2388     os << ']';
2389   }
2390 
2391   // Print constraints.
2392   os << " : (";
2393   int numConstraints = set.getNumConstraints();
2394   for (int i = 1; i < numConstraints; ++i) {
2395     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2396     os << ", ";
2397   }
2398   if (numConstraints >= 1)
2399     printAffineConstraint(set.getConstraint(numConstraints - 1),
2400                           set.isEq(numConstraints - 1));
2401   os << ')';
2402 }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // OperationPrinter
2406 //===----------------------------------------------------------------------===//
2407 
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
/// It combines the generic printing machinery of `AsmPrinter::Impl` with the
/// `OpAsmPrinter` hooks handed to custom operation printers, and tracks the
/// current indentation and the default dialect while descending into regions.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  using Impl::printType;

  /// Construct a printer writing to `os`; the printer flags are taken from
  /// `state`, which is also kept for name and alias lookups.
  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
      : Impl(os, state.getPrinterFlags(), &state),
        OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print the given top-level operation, surrounded by the non-deferred and
  /// deferred alias definitions.
  void printTopLevelOperation(Operation *op);

  /// Print the given operation with its indent and location.
  void print(Operation *op);
  /// Print the bare operation (result names followed by the custom or generic
  /// form), not including indentation or the trailing location.
  void printOperation(Operation *op);
  /// Print the given operation in the generic form. The quoted operation name
  /// is only emitted when `printOpName` is true.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the block header
  /// (label, arguments and predecessor comment) is not printed. If
  /// 'printBlockTerminator' is false, the terminator operation of the block is
  /// not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number. When
  /// `streamOverride` is non-null the ID is written there instead of to the
  /// printer's own stream.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }

  /// Print a block argument in the usual format of:
  ///   %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;

  /// Print the ID for the given value.
  void printOperand(Value value) override { printValueID(value); }
  /// Print the ID for the given value to the provided stream `os` (which
  /// shadows the printer's own stream inside this method).
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as `printOptionalAttrDict`, but requests the keyword-prefixed form
  /// from the underlying implementation.
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

private:
  // Contains the stack of default dialects to use when printing regions.
  // An entry is pushed to the stack before printing a non-empty region and
  // popped when done: the default dialect reported by the parent operation if
  // it implements `OpAsmOpInterface`, and an empty name otherwise. At the
  // top-level we start with "builtin" as the default, so that the top-level
  // `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  unsigned currentIndent = 0;
};
} // namespace
2525 
2526 void OperationPrinter::printTopLevelOperation(Operation *op) {
2527   // Output the aliases at the top level that can't be deferred.
2528   state->getAliasState().printNonDeferredAliases(os, newLine);
2529 
2530   // Print the module.
2531   print(op);
2532   os << newLine;
2533 
2534   // Output the aliases at the top level that can be deferred.
2535   state->getAliasState().printDeferredAliases(os, newLine);
2536 }
2537 
2538 /// Print a block argument in the usual format of:
2539 ///   %ssaName : type {attr1=42} loc("here")
2540 /// where location printing is controlled by the standard internal option.
2541 /// You may pass omitType=true to not print a type, and pass an empty
2542 /// attribute list if you don't care for attributes.
2543 void OperationPrinter::printRegionArgument(BlockArgument arg,
2544                                            ArrayRef<NamedAttribute> argAttrs,
2545                                            bool omitType) {
2546   printOperand(arg);
2547   if (!omitType) {
2548     os << ": ";
2549     printType(arg.getType());
2550   }
2551   printOptionalAttrDict(argAttrs);
2552   // TODO: We should allow location aliases on block arguments.
2553   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2554 }
2555 
2556 void OperationPrinter::print(Operation *op) {
2557   // Track the location of this operation.
2558   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2559 
2560   os.indent(currentIndent);
2561   printOperation(op);
2562   printTrailingLocation(op->getLoc());
2563 }
2564 
2565 void OperationPrinter::printOperation(Operation *op) {
2566   if (size_t numResults = op->getNumResults()) {
2567     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2568       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2569       if (resultCount > 1)
2570         os << ':' << resultCount;
2571     };
2572 
2573     // Check to see if this operation has multiple result groups.
2574     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2575     if (!resultGroups.empty()) {
2576       // Interleave the groups excluding the last one, this one will be handled
2577       // separately.
2578       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2579         printResultGroup(resultGroups[i],
2580                          resultGroups[i + 1] - resultGroups[i]);
2581       });
2582       os << ", ";
2583       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2584 
2585     } else {
2586       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2587     }
2588 
2589     os << " = ";
2590   }
2591 
2592   // If requested, always print the generic form.
2593   if (!printerFlags.shouldPrintGenericOpForm()) {
2594     // Check to see if this is a known operation. If so, use the registered
2595     // custom printer hook.
2596     if (auto opInfo = op->getRegisteredInfo()) {
2597       opInfo->printAssembly(op, *this, defaultDialectStack.back());
2598       return;
2599     }
2600     // Otherwise try to dispatch to the dialect, if available.
2601     if (Dialect *dialect = op->getDialect()) {
2602       if (auto opPrinter = dialect->getOperationPrinter(op)) {
2603         // Print the op name first.
2604         StringRef name = op->getName().getStringRef();
2605         name.consume_front((defaultDialectStack.back() + ".").str());
2606         printEscapedString(name, os);
2607         // Print the rest of the op now.
2608         opPrinter(op, *this);
2609         return;
2610       }
2611     }
2612   }
2613 
2614   // Otherwise print with the generic assembly form.
2615   printGenericOp(op, /*printOpName=*/true);
2616 }
2617 
2618 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2619   if (printOpName) {
2620     os << '"';
2621     printEscapedString(op->getName().getStringRef(), os);
2622     os << '"';
2623   }
2624   os << '(';
2625   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2626   os << ')';
2627 
2628   // For terminators, print the list of successors and their operands.
2629   if (op->getNumSuccessors() != 0) {
2630     os << '[';
2631     interleaveComma(op->getSuccessors(),
2632                     [&](Block *successor) { printBlockName(successor); });
2633     os << ']';
2634   }
2635 
2636   // Print regions.
2637   if (op->getNumRegions() != 0) {
2638     os << " (";
2639     interleaveComma(op->getRegions(), [&](Region &region) {
2640       printRegion(region, /*printEntryBlockArgs=*/true,
2641                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2642     });
2643     os << ')';
2644   }
2645 
2646   auto attrs = op->getAttrs();
2647   printOptionalAttrDict(attrs);
2648 
2649   // Print the type signature of the operation.
2650   os << " : ";
2651   printFunctionalType(op);
2652 }
2653 
2654 void OperationPrinter::printBlockName(Block *block) {
2655   os << state->getSSANameState().getBlockInfo(block).name;
2656 }
2657 
2658 void OperationPrinter::print(Block *block, bool printBlockArgs,
2659                              bool printBlockTerminator) {
2660   // Print the block label and argument list if requested.
2661   if (printBlockArgs) {
2662     os.indent(currentIndent);
2663     printBlockName(block);
2664 
2665     // Print the argument list if non-empty.
2666     if (!block->args_empty()) {
2667       os << '(';
2668       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2669         printValueID(arg);
2670         os << ": ";
2671         printType(arg.getType());
2672         // TODO: We should allow location aliases on block arguments.
2673         printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2674       });
2675       os << ')';
2676     }
2677     os << ':';
2678 
2679     // Print out some context information about the predecessors of this block.
2680     if (!block->getParent()) {
2681       os << "  // block is not in a region!";
2682     } else if (block->hasNoPredecessors()) {
2683       if (!block->isEntryBlock())
2684         os << "  // no predecessors";
2685     } else if (auto *pred = block->getSinglePredecessor()) {
2686       os << "  // pred: ";
2687       printBlockName(pred);
2688     } else {
2689       // We want to print the predecessors in a stable order, not in
2690       // whatever order the use-list is in, so gather and sort them.
2691       SmallVector<BlockInfo, 4> predIDs;
2692       for (auto *pred : block->getPredecessors())
2693         predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
2694       llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
2695         return lhs.ordering < rhs.ordering;
2696       });
2697 
2698       os << "  // " << predIDs.size() << " preds: ";
2699 
2700       interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
2701     }
2702     os << newLine;
2703   }
2704 
2705   currentIndent += indentWidth;
2706   bool hasTerminator =
2707       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
2708   auto range = llvm::make_range(
2709       block->begin(),
2710       std::prev(block->end(),
2711                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
2712   for (auto &op : range) {
2713     print(&op);
2714     os << newLine;
2715   }
2716   currentIndent -= indentWidth;
2717 }
2718 
2719 void OperationPrinter::printValueID(Value value, bool printResultNo,
2720                                     raw_ostream *streamOverride) const {
2721   state->getSSANameState().printValueID(value, printResultNo,
2722                                         streamOverride ? *streamOverride : os);
2723 }
2724 
2725 void OperationPrinter::printSuccessor(Block *successor) {
2726   printBlockName(successor);
2727 }
2728 
2729 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2730                                                 ValueRange succOperands) {
2731   printBlockName(successor);
2732   if (succOperands.empty())
2733     return;
2734 
2735   os << '(';
2736   interleaveComma(succOperands,
2737                   [this](Value operand) { printValueID(operand); });
2738   os << " : ";
2739   interleaveComma(succOperands,
2740                   [this](Value operand) { printType(operand.getType()); });
2741   os << ')';
2742 }
2743 
2744 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2745                                    bool printBlockTerminators,
2746                                    bool printEmptyBlock) {
2747   os << "{" << newLine;
2748   if (!region.empty()) {
2749     auto restoreDefaultDialect =
2750         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
2751     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
2752       defaultDialectStack.push_back(iface.getDefaultDialect());
2753     else
2754       defaultDialectStack.push_back("");
2755 
2756     auto *entryBlock = &region.front();
2757     // Force printing the block header if printEmptyBlock is set and the block
2758     // is empty or if printEntryBlockArgs is set and there are arguments to
2759     // print.
2760     bool shouldAlwaysPrintBlockHeader =
2761         (printEmptyBlock && entryBlock->empty()) ||
2762         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2763     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2764     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2765       print(&b);
2766   }
2767   os.indent(currentIndent) << "}";
2768 }
2769 
2770 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2771                                               ValueRange operands) {
2772   AffineMap map = mapAttr.getValue();
2773   unsigned numDims = map.getNumDims();
2774   auto printValueName = [&](unsigned pos, bool isSymbol) {
2775     unsigned index = isSymbol ? numDims + pos : pos;
2776     assert(index < operands.size());
2777     if (isSymbol)
2778       os << "symbol(";
2779     printValueID(operands[index]);
2780     if (isSymbol)
2781       os << ')';
2782   };
2783 
2784   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2785     printAffineExpr(expr, printValueName);
2786   });
2787 }
2788 
2789 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2790                                                ValueRange dimOperands,
2791                                                ValueRange symOperands) {
2792   auto printValueName = [&](unsigned pos, bool isSymbol) {
2793     if (!isSymbol)
2794       return printValueID(dimOperands[pos]);
2795     os << "symbol(";
2796     printValueID(symOperands[pos]);
2797     os << ')';
2798   };
2799   printAffineExpr(expr, printValueName);
2800 }
2801 
2802 //===----------------------------------------------------------------------===//
2803 // print and dump methods
2804 //===----------------------------------------------------------------------===//
2805 
2806 void Attribute::print(raw_ostream &os) const {
2807   AsmPrinter::Impl(os).printAttribute(*this);
2808 }
2809 
2810 void Attribute::dump() const {
2811   print(llvm::errs());
2812   llvm::errs() << "\n";
2813 }
2814 
2815 void Type::print(raw_ostream &os) const {
2816   AsmPrinter::Impl(os).printType(*this);
2817 }
2818 
2819 void Type::dump() const { print(llvm::errs()); }
2820 
2821 void AffineMap::dump() const {
2822   print(llvm::errs());
2823   llvm::errs() << "\n";
2824 }
2825 
2826 void IntegerSet::dump() const {
2827   print(llvm::errs());
2828   llvm::errs() << "\n";
2829 }
2830 
2831 void AffineExpr::print(raw_ostream &os) const {
2832   if (!expr) {
2833     os << "<<NULL AFFINE EXPR>>";
2834     return;
2835   }
2836   AsmPrinter::Impl(os).printAffineExpr(*this);
2837 }
2838 
2839 void AffineExpr::dump() const {
2840   print(llvm::errs());
2841   llvm::errs() << "\n";
2842 }
2843 
2844 void AffineMap::print(raw_ostream &os) const {
2845   if (!map) {
2846     os << "<<NULL AFFINE MAP>>";
2847     return;
2848   }
2849   AsmPrinter::Impl(os).printAffineMap(*this);
2850 }
2851 
2852 void IntegerSet::print(raw_ostream &os) const {
2853   AsmPrinter::Impl(os).printIntegerSet(*this);
2854 }
2855 
2856 void Value::print(raw_ostream &os) {
2857   if (!impl) {
2858     os << "<<NULL VALUE>>";
2859     return;
2860   }
2861 
2862   if (auto *op = getDefiningOp())
2863     return op->print(os);
2864   // TODO: Improve BlockArgument print'ing.
2865   BlockArgument arg = this->cast<BlockArgument>();
2866   os << "<block argument> of type '" << arg.getType()
2867      << "' at index: " << arg.getArgNumber();
2868 }
2869 void Value::print(raw_ostream &os, AsmState &state) {
2870   if (!impl) {
2871     os << "<<NULL VALUE>>";
2872     return;
2873   }
2874 
2875   if (auto *op = getDefiningOp())
2876     return op->print(os, state);
2877 
2878   // TODO: Improve BlockArgument print'ing.
2879   BlockArgument arg = this->cast<BlockArgument>();
2880   os << "<block argument> of type '" << arg.getType()
2881      << "' at index: " << arg.getArgNumber();
2882 }
2883 
2884 void Value::dump() {
2885   print(llvm::errs());
2886   llvm::errs() << "\n";
2887 }
2888 
2889 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2890   // TODO: This doesn't necessarily capture all potential cases.
2891   // Currently, region arguments can be shadowed when printing the main
2892   // operation. If the IR hasn't been printed, this will produce the old SSA
2893   // name and not the shadowed name.
2894   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2895                                                  os);
2896 }
2897 
2898 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
2899   // If this is a top level operation, we also print aliases.
2900   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
2901     AsmState state(this, printerFlags);
2902     state.getImpl().initializeAliases(this);
2903     print(os, state);
2904     return;
2905   }
2906 
2907   // Find the operation to number from based upon the provided flags.
2908   Operation *op = this;
2909   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
2910   do {
2911     // If we are printing local scope, stop at the first operation that is
2912     // isolated from above.
2913     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2914       break;
2915 
2916     // Otherwise, traverse up to the next parent.
2917     Operation *parentOp = op->getParentOp();
2918     if (!parentOp)
2919       break;
2920     op = parentOp;
2921   } while (true);
2922 
2923   AsmState state(op, printerFlags);
2924   print(os, state);
2925 }
2926 void Operation::print(raw_ostream &os, AsmState &state) {
2927   OperationPrinter printer(os, state.getImpl());
2928   if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
2929     printer.printTopLevelOperation(this);
2930   else
2931     printer.print(this);
2932 }
2933 
2934 void Operation::dump() {
2935   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2936   llvm::errs() << "\n";
2937 }
2938 
2939 void Block::print(raw_ostream &os) {
2940   Operation *parentOp = getParentOp();
2941   if (!parentOp) {
2942     os << "<<UNLINKED BLOCK>>\n";
2943     return;
2944   }
2945   // Get the top-level op.
2946   while (auto *nextOp = parentOp->getParentOp())
2947     parentOp = nextOp;
2948 
2949   AsmState state(parentOp);
2950   print(os, state);
2951 }
2952 void Block::print(raw_ostream &os, AsmState &state) {
2953   OperationPrinter(os, state.getImpl()).print(this);
2954 }
2955 
2956 void Block::dump() { print(llvm::errs()); }
2957 
2958 /// Print out the name of the block without printing its body.
2959 void Block::printAsOperand(raw_ostream &os, bool printType) {
2960   Operation *parentOp = getParentOp();
2961   if (!parentOp) {
2962     os << "<<UNLINKED BLOCK>>\n";
2963     return;
2964   }
2965   AsmState state(parentOp);
2966   printAsOperand(os, state);
2967 }
2968 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2969   OperationPrinter printer(os, state.getImpl());
2970   printer.printBlockName(this);
2971 }
2972