1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/ScopeExit.h"
33 #include "llvm/ADT/ScopedHashTable.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/SmallString.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Endian.h"
41 #include "llvm/Support/Regex.h"
42 #include "llvm/Support/SaveAndRestore.h"
43 
44 #include <tuple>
45 
46 using namespace mlir;
47 using namespace mlir::detail;
48 
49 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
50 
51 void OperationName::dump() const { print(llvm::errs()); }
52 
53 //===--------------------------------------------------------------------===//
54 // AsmParser
55 //===--------------------------------------------------------------------===//
56 
57 AsmParser::~AsmParser() {}
58 DialectAsmParser::~DialectAsmParser() {}
59 OpAsmParser::~OpAsmParser() {}
60 
61 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
62 
63 //===----------------------------------------------------------------------===//
64 // DialectAsmPrinter
65 //===----------------------------------------------------------------------===//
66 
67 DialectAsmPrinter::~DialectAsmPrinter() {}
68 
69 //===----------------------------------------------------------------------===//
70 // OpAsmPrinter
71 //===----------------------------------------------------------------------===//
72 
73 OpAsmPrinter::~OpAsmPrinter() {}
74 
75 void OpAsmPrinter::printFunctionalType(Operation *op) {
76   auto &os = getStream();
77   os << '(';
78   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
79     // Print the types of null values as <<NULL TYPE>>.
80     *this << (operand ? operand.getType() : Type());
81   });
82   os << ") -> ";
83 
84   // Print the result list.  We don't parenthesize single result types unless
85   // it is a function (avoiding a grammar ambiguity).
86   bool wrapped = op->getNumResults() != 1;
87   if (!wrapped && op->getResult(0).getType() &&
88       op->getResult(0).getType().isa<FunctionType>())
89     wrapped = true;
90 
91   if (wrapped)
92     os << '(';
93 
94   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
95     // Print the types of null values as <<NULL TYPE>>.
96     *this << (result ? result.getType() : Type());
97   });
98 
99   if (wrapped)
100     os << ')';
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // Operation OpAsm interface.
105 //===----------------------------------------------------------------------===//
106 
107 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
108 #include "mlir/IR/OpAsmInterface.cpp.inc"
109 
110 //===----------------------------------------------------------------------===//
111 // OpPrintingFlags
112 //===----------------------------------------------------------------------===//
113 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
/// for global command line options.
struct AsmPrinterOptions {
  // Element-count threshold above which DenseElementsAttrs are printed as a
  // hex string; -1 disables hex printing. Only consulted when the option was
  // given on the command line (see shouldPrintElementsAttrWithHex).
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  // Element-count threshold above which ElementsAttrs are elided with "...".
  // Only copied into OpPrintingFlags when the option was given on the command
  // line.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  // Whether debug info is printed; off by default.
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  // Whether debug info uses the pretty form; off by default.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  // Use the generic op output form in the operation printer even if the custom
  // form is defined. Hidden option, off by default.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  // Whether the printer assumes local scope by default. Hidden option, off by
  // default.
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print assuming in local scope by default"),
      llvm::cl::Hidden};
};
} // end anonymous namespace

// Lazily constructed holder for the options above. Code in this file checks
// `isConstructed()` before reading it, so the options only take effect after
// something (e.g. registerAsmPrinterCLOptions) has forced construction.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
152 
153 /// Register a set of useful command-line options that can be used to configure
154 /// various flags within the AsmPrinter.
155 void mlir::registerAsmPrinterCLOptions() {
156   // Make sure that the options struct has been initialized.
157   *clOptions;
158 }
159 
160 /// Initialize the printing flags with default supplied by the cl::opts above.
161 OpPrintingFlags::OpPrintingFlags()
162     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
163       printGenericOpFormFlag(false), printLocalScope(false) {
164   // Initialize based upon command line options, if they are available.
165   if (!clOptions.isConstructed())
166     return;
167   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
168     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
169   printDebugInfoFlag = clOptions->printDebugInfoOpt;
170   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
171   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
172   printLocalScope = clOptions->printLocalScopeOpt;
173 }
174 
175 /// Enable the elision of large elements attributes, by printing a '...'
176 /// instead of the element data, when the number of elements is greater than
177 /// `largeElementLimit`. Note: The IR generated with this option is not
178 /// parsable.
179 OpPrintingFlags &
180 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
181   elementsAttrElementLimit = largeElementLimit;
182   return *this;
183 }
184 
185 /// Enable printing of debug information. If 'prettyForm' is set to true,
186 /// debug information is printed in a more readable 'pretty' form.
187 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
188   printDebugInfoFlag = true;
189   printDebugInfoPrettyFormFlag = prettyForm;
190   return *this;
191 }
192 
193 /// Always print operations in the generic form.
194 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
195   printGenericOpFormFlag = true;
196   return *this;
197 }
198 
199 /// Use local scope when printing the operation. This allows for using the
200 /// printer in a more localized and thread-safe setting, but may not necessarily
201 /// be identical of what the IR will look like when dumping the full module.
202 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
203   printLocalScope = true;
204   return *this;
205 }
206 
207 /// Return if the given ElementsAttr should be elided.
208 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
209   return elementsAttrElementLimit.hasValue() &&
210          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
211          !attr.isa<SplatElementsAttr>();
212 }
213 
214 /// Return the size limit for printing large ElementsAttr.
215 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
216   return elementsAttrElementLimit;
217 }
218 
219 /// Return if debug information should be printed.
220 bool OpPrintingFlags::shouldPrintDebugInfo() const {
221   return printDebugInfoFlag;
222 }
223 
224 /// Return if debug information should be printed in the pretty form.
225 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
226   return printDebugInfoPrettyFormFlag;
227 }
228 
229 /// Return if operations should be printed in the generic form.
230 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
231   return printGenericOpFormFlag;
232 }
233 
234 /// Return if the printer should use local scope when dumping the IR.
235 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
236 
237 /// Returns true if an ElementsAttr with the given number of elements should be
238 /// printed with hex.
239 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
240   // Check to see if a command line option was provided for the limit.
241   if (clOptions.isConstructed()) {
242     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
243       // -1 is used to disable hex printing.
244       if (clOptions->printElementsAttrWithHexIfLarger == -1)
245         return false;
246       return numElements > clOptions->printElementsAttrWithHexIfLarger;
247     }
248   }
249 
250   // Otherwise, default to printing with hex if the number of elements is >100.
251   return numElements > 100;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // NewLineCounter
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class is a simple formatter that emits a new line when inputted into a
260 /// stream, that enables counting the number of newlines emitted. This class
261 /// should be used whenever emitting newlines in the printer.
262 struct NewLineCounter {
263   unsigned curLine = 1;
264 };
265 
266 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
267   ++newLine.curLine;
268   return os << '\n';
269 }
270 } // end anonymous namespace
271 
272 //===----------------------------------------------------------------------===//
273 // AliasInitializer
274 //===----------------------------------------------------------------------===//
275 
276 namespace {
277 /// This class represents a specific instance of a symbol Alias.
278 class SymbolAlias {
279 public:
280   SymbolAlias(StringRef name, bool isDeferrable)
281       : name(name), suffixIndex(0), hasSuffixIndex(false),
282         isDeferrable(isDeferrable) {}
283   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
284       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
285         isDeferrable(isDeferrable) {}
286 
287   /// Print this alias to the given stream.
288   void print(raw_ostream &os) const {
289     os << name;
290     if (hasSuffixIndex)
291       os << suffixIndex;
292   }
293 
294   /// Returns true if this alias supports deferred resolution when parsing.
295   bool canBeDeferred() const { return isDeferrable; }
296 
297 private:
298   /// The main name of the alias.
299   StringRef name;
300   /// The optional suffix index of the alias, if multiple aliases had the same
301   /// name.
302   uint32_t suffixIndex : 30;
303   /// A flag indicating whether this alias has a suffix or not.
304   bool hasSuffixIndex : 1;
305   /// A flag indicating whether this alias may be deferred or not.
306   bool isDeferrable : 1;
307 };
308 
309 /// This class represents a utility that initializes the set of attribute and
310 /// type aliases, without the need to store the extra information within the
311 /// main AliasState class or pass it around via function arguments.
312 class AliasInitializer {
313 public:
314   AliasInitializer(
315       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
316       llvm::BumpPtrAllocator &aliasAllocator)
317       : interfaces(interfaces), aliasAllocator(aliasAllocator),
318         aliasOS(aliasBuffer) {}
319 
320   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
321                   llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
322                   llvm::MapVector<Type, SymbolAlias> &typeToAlias);
323 
324   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
325   /// set to true if the originator of this attribute can resolve the alias
326   /// after parsing has completed (e.g. in the case of operation locations).
327   void visit(Attribute attr, bool canBeDeferred = false);
328 
329   /// Visit the given type to see if it has an alias.
330   void visit(Type type);
331 
332 private:
333   /// Try to generate an alias for the provided symbol. If an alias is
334   /// generated, the provided alias mapping and reverse mapping are updated.
335   /// Returns success if an alias was generated, failure otherwise.
336   template <typename T>
337   LogicalResult
338   generateAlias(T symbol,
339                 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
340 
341   /// The set of asm interfaces within the context.
342   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
343 
344   /// Mapping between an alias and the set of symbols mapped to it.
345   llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
346   llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
347 
348   /// An allocator used for alias names.
349   llvm::BumpPtrAllocator &aliasAllocator;
350 
351   /// The set of visited attributes.
352   DenseSet<Attribute> visitedAttributes;
353 
354   /// The set of attributes that have aliases *and* can be deferred.
355   DenseSet<Attribute> deferrableAttributes;
356 
357   /// The set of visited types.
358   DenseSet<Type> visitedTypes;
359 
360   /// Storage and stream used when generating an alias.
361   SmallString<32> aliasBuffer;
362   llvm::raw_svector_ostream aliasOS;
363 };
364 
365 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
366 /// and merely collects the attributes and types that *would* be printed in a
367 /// normal print invocation so that we can generate proper aliases. This allows
368 /// for us to generate aliases only for the attributes and types that would be
369 /// in the output, and trims down unnecessary output.
370 class DummyAliasOperationPrinter : private OpAsmPrinter {
371 public:
372   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
373                                       AliasInitializer &initializer)
374       : printerFlags(printerFlags), initializer(initializer) {}
375 
376   /// Print the given operation.
377   void print(Operation *op) {
378     // Visit the operation location.
379     if (printerFlags.shouldPrintDebugInfo())
380       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
381 
382     // If requested, always print the generic form.
383     if (!printerFlags.shouldPrintGenericOpForm()) {
384       // Check to see if this is a known operation.  If so, use the registered
385       // custom printer hook.
386       if (auto opInfo = op->getRegisteredInfo()) {
387         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
388         return;
389       }
390     }
391 
392     // Otherwise print with the generic assembly form.
393     printGenericOp(op);
394   }
395 
396 private:
397   /// Print the given operation in the generic form.
398   void printGenericOp(Operation *op, bool printOpName = true) override {
399     // Consider nested operations for aliases.
400     if (op->getNumRegions() != 0) {
401       for (Region &region : op->getRegions())
402         printRegion(region, /*printEntryBlockArgs=*/true,
403                     /*printBlockTerminators=*/true);
404     }
405 
406     // Visit all the types used in the operation.
407     for (Type type : op->getOperandTypes())
408       printType(type);
409     for (Type type : op->getResultTypes())
410       printType(type);
411 
412     // Consider the attributes of the operation for aliases.
413     for (const NamedAttribute &attr : op->getAttrs())
414       printAttribute(attr.getValue());
415   }
416 
417   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
418   /// block are not printed. If 'printBlockTerminator' is false, the terminator
419   /// operation of the block is not printed.
420   void print(Block *block, bool printBlockArgs = true,
421              bool printBlockTerminator = true) {
422     // Consider the types of the block arguments for aliases if 'printBlockArgs'
423     // is set to true.
424     if (printBlockArgs) {
425       for (BlockArgument arg : block->getArguments()) {
426         printType(arg.getType());
427 
428         // Visit the argument location.
429         if (printerFlags.shouldPrintDebugInfo())
430           // TODO: Allow deferring argument locations.
431           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
432       }
433     }
434 
435     // Consider the operations within this block, ignoring the terminator if
436     // requested.
437     bool hasTerminator =
438         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
439     auto range = llvm::make_range(
440         block->begin(),
441         std::prev(block->end(),
442                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
443     for (Operation &op : range)
444       print(&op);
445   }
446 
447   /// Print the given region.
448   void printRegion(Region &region, bool printEntryBlockArgs,
449                    bool printBlockTerminators,
450                    bool printEmptyBlock = false) override {
451     if (region.empty())
452       return;
453 
454     auto *entryBlock = &region.front();
455     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
456     for (Block &b : llvm::drop_begin(region, 1))
457       print(&b);
458   }
459 
460   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
461                            bool omitType) override {
462     printType(arg.getType());
463     // Visit the argument location.
464     if (printerFlags.shouldPrintDebugInfo())
465       // TODO: Allow deferring argument locations.
466       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
467   }
468 
469   /// Consider the given type to be printed for an alias.
470   void printType(Type type) override { initializer.visit(type); }
471 
472   /// Consider the given attribute to be printed for an alias.
473   void printAttribute(Attribute attr) override { initializer.visit(attr); }
474   void printAttributeWithoutType(Attribute attr) override {
475     printAttribute(attr);
476   }
477 
478   /// Print the given set of attributes with names not included within
479   /// 'elidedAttrs'.
480   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
481                              ArrayRef<StringRef> elidedAttrs = {}) override {
482     if (attrs.empty())
483       return;
484     if (elidedAttrs.empty()) {
485       for (const NamedAttribute &attr : attrs)
486         printAttribute(attr.getValue());
487       return;
488     }
489     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
490                                                   elidedAttrs.end());
491     for (const NamedAttribute &attr : attrs)
492       if (!elidedAttrsSet.contains(attr.getName().strref()))
493         printAttribute(attr.getValue());
494   }
495   void printOptionalAttrDictWithKeyword(
496       ArrayRef<NamedAttribute> attrs,
497       ArrayRef<StringRef> elidedAttrs = {}) override {
498     printOptionalAttrDict(attrs, elidedAttrs);
499   }
500 
501   /// Return a null stream as the output stream, this will ignore any data fed
502   /// to it.
503   raw_ostream &getStream() const override { return os; }
504 
505   /// The following are hooks of `OpAsmPrinter` that are not necessary for
506   /// determining potential aliases.
507   void printFloat(const APFloat &value) override {}
508   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
509   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
510   void printNewline() override {}
511   void printOperand(Value) override {}
512   void printOperand(Value, raw_ostream &os) override {
513     // Users expect the output string to have at least the prefixed % to signal
514     // a value name. To maintain this invariant, emit a name even if it is
515     // guaranteed to go unused.
516     os << "%";
517   }
518   void printKeywordOrString(StringRef) override {}
519   void printSymbolName(StringRef) override {}
520   void printSuccessor(Block *) override {}
521   void printSuccessorAndUseList(Block *, ValueRange) override {}
522   void shadowRegionArgs(Region &, ValueRange) override {}
523 
524   /// The printer flags to use when determining potential aliases.
525   const OpPrintingFlags &printerFlags;
526 
527   /// The initializer to use when identifying aliases.
528   AliasInitializer &initializer;
529 
530   /// A dummy output stream.
531   mutable llvm::raw_null_ostream os;
532 };
533 } // end anonymous namespace
534 
535 /// Sanitize the given name such that it can be used as a valid identifier. If
536 /// the string needs to be modified in any way, the provided buffer is used to
537 /// store the new copy,
538 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
539                                     StringRef allowedPunctChars = "$._-",
540                                     bool allowTrailingDigit = true) {
541   assert(!name.empty() && "Shouldn't have an empty name here");
542 
543   auto copyNameToBuffer = [&] {
544     for (char ch : name) {
545       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
546         buffer.push_back(ch);
547       else if (ch == ' ')
548         buffer.push_back('_');
549       else
550         buffer.append(llvm::utohexstr((unsigned char)ch));
551     }
552   };
553 
554   // Check to see if this name is valid. If it starts with a digit, then it
555   // could conflict with the autogenerated numeric ID's, so add an underscore
556   // prefix to avoid problems.
557   if (isdigit(name[0])) {
558     buffer.push_back('_');
559     copyNameToBuffer();
560     return buffer;
561   }
562 
563   // If the name ends with a trailing digit, add a '_' to avoid potential
564   // conflicts with autogenerated ID's.
565   if (!allowTrailingDigit && isdigit(name.back())) {
566     copyNameToBuffer();
567     buffer.push_back('_');
568     return buffer;
569   }
570 
571   // Check to see that the name consists of only valid identifier characters.
572   for (char ch : name) {
573     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
574       copyNameToBuffer();
575       return buffer;
576     }
577   }
578 
579   // If there are no invalid characters, return the original name.
580   return name;
581 }
582 
583 /// Given a collection of aliases and symbols, initialize a mapping from a
584 /// symbol to a given alias.
585 template <typename T>
586 static void
587 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
588                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
589                   DenseSet<T> *deferrableAliases = nullptr) {
590   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
591       aliasToSymbol.takeVector();
592   llvm::array_pod_sort(aliases.begin(), aliases.end(),
593                        [](const auto *lhs, const auto *rhs) {
594                          return lhs->first.compare(rhs->first);
595                        });
596 
597   for (auto &it : aliases) {
598     // If there is only one instance for this alias, use the name directly.
599     if (it.second.size() == 1) {
600       T symbol = it.second.front();
601       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
602       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
603       continue;
604     }
605     // Otherwise, add the index to the name.
606     for (int i = 0, e = it.second.size(); i < e; ++i) {
607       T symbol = it.second[i];
608       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
609       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
610     }
611   }
612 }
613 
614 void AliasInitializer::initialize(
615     Operation *op, const OpPrintingFlags &printerFlags,
616     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
617     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
618   // Use a dummy printer when walking the IR so that we can collect the
619   // attributes/types that will actually be used during printing when
620   // considering aliases.
621   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
622   aliasPrinter.print(op);
623 
624   // Initialize the aliases sorted by name.
625   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
626   initializeAliases(aliasToType, typeToAlias);
627 }
628 
629 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
630   if (!visitedAttributes.insert(attr).second) {
631     // If this attribute already has an alias and this instance can't be
632     // deferred, make sure that the alias isn't deferred.
633     if (!canBeDeferred)
634       deferrableAttributes.erase(attr);
635     return;
636   }
637 
638   // Try to generate an alias for this attribute.
639   if (succeeded(generateAlias(attr, aliasToAttr))) {
640     if (canBeDeferred)
641       deferrableAttributes.insert(attr);
642     return;
643   }
644 
645   // Check for any sub elements.
646   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
647     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
648                                         [&](Type type) { visit(type); });
649   }
650 }
651 
652 void AliasInitializer::visit(Type type) {
653   if (!visitedTypes.insert(type).second)
654     return;
655 
656   // Try to generate an alias for this type.
657   if (succeeded(generateAlias(type, aliasToType)))
658     return;
659 
660   // Check for any sub elements.
661   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
662     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
663                                         [&](Type type) { visit(type); });
664   }
665 }
666 
667 template <typename T>
668 LogicalResult AliasInitializer::generateAlias(
669     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
670   SmallString<32> nameBuffer;
671   for (const auto &interface : interfaces) {
672     OpAsmDialectInterface::AliasResult result =
673         interface.getAlias(symbol, aliasOS);
674     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
675       continue;
676     nameBuffer = std::move(aliasBuffer);
677     assert(!nameBuffer.empty() && "expected valid alias name");
678     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
679       break;
680   }
681 
682   if (nameBuffer.empty())
683     return failure();
684 
685   SmallString<16> tempBuffer;
686   StringRef name =
687       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
688                          /*allowTrailingDigit=*/false);
689   name = name.copy(aliasAllocator);
690   aliasToSymbol[name].push_back(symbol);
691   return success();
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // AliasState
696 //===----------------------------------------------------------------------===//
697 
698 namespace {
699 /// This class manages the state for type and attribute aliases.
700 class AliasState {
701 public:
702   // Initialize the internal aliases.
703   void
704   initialize(Operation *op, const OpPrintingFlags &printerFlags,
705              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
706 
707   /// Get an alias for the given attribute if it has one and print it in `os`.
708   /// Returns success if an alias was printed, failure otherwise.
709   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
710 
711   /// Get an alias for the given type if it has one and print it in `os`.
712   /// Returns success if an alias was printed, failure otherwise.
713   LogicalResult getAlias(Type ty, raw_ostream &os) const;
714 
715   /// Print all of the referenced aliases that can not be resolved in a deferred
716   /// manner.
717   void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
718     printAliases(os, newLine, /*isDeferred=*/false);
719   }
720 
721   /// Print all of the referenced aliases that support deferred resolution.
722   void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
723     printAliases(os, newLine, /*isDeferred=*/true);
724   }
725 
726 private:
727   /// Print all of the referenced aliases that support the provided resolution
728   /// behavior.
729   void printAliases(raw_ostream &os, NewLineCounter &newLine,
730                     bool isDeferred) const;
731 
732   /// Mapping between attribute and alias.
733   llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
734   /// Mapping between type and alias.
735   llvm::MapVector<Type, SymbolAlias> typeToAlias;
736 
737   /// An allocator used for alias names.
738   llvm::BumpPtrAllocator aliasAllocator;
739 };
740 } // end anonymous namespace
741 
742 void AliasState::initialize(
743     Operation *op, const OpPrintingFlags &printerFlags,
744     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
745   AliasInitializer initializer(interfaces, aliasAllocator);
746   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
747 }
748 
749 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
750   auto it = attrToAlias.find(attr);
751   if (it == attrToAlias.end())
752     return failure();
753   it->second.print(os << '#');
754   return success();
755 }
756 
757 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
758   auto it = typeToAlias.find(ty);
759   if (it == typeToAlias.end())
760     return failure();
761 
762   it->second.print(os << '!');
763   return success();
764 }
765 
766 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
767                               bool isDeferred) const {
768   auto filterFn = [=](const auto &aliasIt) {
769     return aliasIt.second.canBeDeferred() == isDeferred;
770   };
771   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
772     it.second.print(os << '#');
773     os << " = " << it.first << newLine;
774   }
775   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
776     it.second.print(os << '!');
777     os << " = type " << it.first << newLine;
778   }
779 }
780 
781 //===----------------------------------------------------------------------===//
782 // SSANameState
783 //===----------------------------------------------------------------------===//
784 
785 namespace {
/// This class manages the state of SSA value names: it assigns either a
/// numeric ID or a unique textual name to every value it numbers, assigns IDs
/// to blocks, and prints value identifiers on request.
class SSANameState {
public:
  /// A sentinel value stored in `valueIDs` for values that have a textual name
  /// (held in `valueNames`) instead of a numeric ID. Also returned by
  /// `getBlockID` for blocks that were never numbered.
  enum : unsigned { NameSentinel = ~0U };

  /// Number the results of 'op' and the values in all regions nested under it.
  /// 'interfaces' is stored by reference and queried for custom names unless
  /// 'printerFlags' requests the generic op form.
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags,
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the ID for the given block, or NameSentinel if the block has not been
  /// numbered.
  unsigned getBlockID(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit. Numbering a region
  /// numbers each of its blocks, and numbering a block numbers its arguments
  /// and the results of its operations; regions nested under those operations
  /// are not visited by these helpers.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name assigns the
  /// next numeric ID instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  /// The textual name of every value whose ID is NameSentinel. The strings are
  /// allocated from `usedNameAllocator`.
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This is the block ID for each numbered block. IDs are assigned per
  /// region, starting from zero in each region.
  DenseMap<Block *, unsigned> blockIDs;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Backing storage for the strings held in `usedNames` and `valueNames`.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;

  /// The dialect asm interfaces queried for custom value names. Held by
  /// reference, so the collection must outlive this object.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
};
864 } // end anonymous namespace
865 
/// Number the results of 'op' and every value in the regions nested under it.
///
/// Regions are processed from an explicit worklist rather than by recursion.
/// Each worklist entry carries the naming context (the three ID counters and
/// the `usedNames` scope) that was live in its parent when it was queued, so
/// the counters and the set of visible names are inherited from the parent
/// region rather than from whichever region happened to be processed last.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces)
    : printerFlags(printerFlags), interfaces(interfaces) {
  // The counters are overwritten from worklist entries below; these guards put
  // the original values back when the constructor exits.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for the UsedNamesScopeTy objects. Scopes are placement-new'ed
  // into this storage and destroyed by explicit destructor calls below.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Queue the regions of 'op'. Note that the counters are captured before the
  // results of 'op' itself are numbered just below.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  // The worklist is used as a stack: the most recently queued region is
  // processed first.
  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Restore the counters that were live when this region was queued.
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes that are no
    // longer needed until the parent scope is the current one again. This
    // relies on the scope destructor making the enclosing scope current.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue every region nested directly under the operations of this region,
    // with the counters as they stand after numbering this region.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
926 
927 void SSANameState::printValueID(Value value, bool printResultNo,
928                                 raw_ostream &stream) const {
929   if (!value) {
930     stream << "<<NULL>>";
931     return;
932   }
933 
934   Optional<int> resultNo;
935   auto lookupValue = value;
936 
937   // If this is an operation result, collect the head lookup value of the result
938   // group and the result number of 'result' within that group.
939   if (OpResult result = value.dyn_cast<OpResult>())
940     getResultIDAndNumber(result, lookupValue, resultNo);
941 
942   auto it = valueIDs.find(lookupValue);
943   if (it == valueIDs.end()) {
944     stream << "<<UNKNOWN SSA VALUE>>";
945     return;
946   }
947 
948   stream << '%';
949   if (it->second != NameSentinel) {
950     stream << it->second;
951   } else {
952     auto nameIt = valueNames.find(lookupValue);
953     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
954     stream << nameIt->second;
955   }
956 
957   if (resultNo.hasValue() && printResultNo)
958     stream << '#' << resultNo;
959 }
960 
961 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
962   auto it = opResultGroups.find(op);
963   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
964 }
965 
966 unsigned SSANameState::getBlockID(Block *block) {
967   auto it = blockIDs.find(block);
968   return it != blockIDs.end() ? it->second : NameSentinel;
969 }
970 
971 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
972   assert(!region.empty() && "cannot shadow arguments of an empty region");
973   assert(region.getNumArguments() == namesToUse.size() &&
974          "incorrect number of names passed in");
975   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
976          "only KnownIsolatedFromAbove ops can shadow names");
977 
978   SmallVector<char, 16> nameStr;
979   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
980     auto nameToUse = namesToUse[i];
981     if (nameToUse == nullptr)
982       continue;
983     auto nameToReplace = region.getArgument(i);
984 
985     nameStr.clear();
986     llvm::raw_svector_ostream nameStream(nameStr);
987     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
988 
989     // Entry block arguments should already have a pretty "arg" name.
990     assert(valueIDs[nameToReplace] == NameSentinel);
991 
992     // Use the name without the leading %.
993     auto name = StringRef(nameStream.str()).drop_front();
994 
995     // Overwrite the name.
996     valueNames[nameToReplace] = name.copy(usedNameAllocator);
997   }
998 }
999 
1000 void SSANameState::numberValuesInRegion(Region &region) {
1001   // Number the values within this region in a breadth-first order.
1002   unsigned nextBlockID = 0;
1003   for (auto &block : region) {
1004     // Each block gets a unique ID, and all of the operations within it get
1005     // numbered as well.
1006     blockIDs[&block] = nextBlockID++;
1007     numberValuesInBlock(block);
1008   }
1009 }
1010 
1011 void SSANameState::numberValuesInBlock(Block &block) {
1012   auto setArgNameFn = [&](Value arg, StringRef name) {
1013     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1014     assert(arg.cast<BlockArgument>().getOwner() == &block &&
1015            "arg not defined in 'block'");
1016     setValueName(arg, name);
1017   };
1018 
1019   bool isEntryBlock = block.isEntryBlock();
1020   if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
1021     if (auto *op = block.getParentOp()) {
1022       if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
1023         asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
1024     }
1025   }
1026 
1027   // Number the block arguments. We give entry block arguments a special name
1028   // 'arg'.
1029   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1030   llvm::raw_svector_ostream specialName(specialNameBuffer);
1031   for (auto arg : block.getArguments()) {
1032     if (valueIDs.count(arg))
1033       continue;
1034     if (isEntryBlock) {
1035       specialNameBuffer.resize(strlen("arg"));
1036       specialName << nextArgumentID++;
1037     }
1038     setValueName(arg, specialName.str());
1039   }
1040 
1041   // Number the operations in this block.
1042   for (auto &op : block)
1043     numberValuesInOp(op);
1044 }
1045 
1046 void SSANameState::numberValuesInOp(Operation &op) {
1047   unsigned numResults = op.getNumResults();
1048   if (numResults == 0)
1049     return;
1050   Value resultBegin = op.getResult(0);
1051 
1052   // Function used to set the special result names for the operation.
1053   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1054   auto setResultNameFn = [&](Value result, StringRef name) {
1055     assert(!valueIDs.count(result) && "result numbered multiple times");
1056     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1057     setValueName(result, name);
1058 
1059     // Record the result number for groups not anchored at 0.
1060     if (int resultNo = result.cast<OpResult>().getResultNumber())
1061       resultGroups.push_back(resultNo);
1062   };
1063   if (!printerFlags.shouldPrintGenericOpForm()) {
1064     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
1065       asmInterface.getAsmResultNames(setResultNameFn);
1066     else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
1067       asmInterface->getAsmResultNames(&op, setResultNameFn);
1068   }
1069 
1070   // If the first result wasn't numbered, give it a default number.
1071   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1072     ++nextValueID;
1073 
1074   // If this operation has multiple result groups, mark it.
1075   if (resultGroups.size() != 1) {
1076     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1077     opResultGroups.try_emplace(&op, std::move(resultGroups));
1078   }
1079 }
1080 
1081 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1082                                         Optional<int> &lookupResultNo) const {
1083   Operation *owner = result.getOwner();
1084   if (owner->getNumResults() == 1)
1085     return;
1086   int resultNo = result.getResultNumber();
1087 
1088   // If this operation has multiple result groups, we will need to find the
1089   // one corresponding to this result.
1090   auto resultGroupIt = opResultGroups.find(owner);
1091   if (resultGroupIt == opResultGroups.end()) {
1092     // If not, just use the first result.
1093     lookupResultNo = resultNo;
1094     lookupValue = owner->getResult(0);
1095     return;
1096   }
1097 
1098   // Find the correct index using a binary search, as the groups are ordered.
1099   ArrayRef<int> resultGroups = resultGroupIt->second;
1100   auto it = llvm::upper_bound(resultGroups, resultNo);
1101   int groupResultNo = 0, groupSize = 0;
1102 
1103   // If there are no smaller elements, the last result group is the lookup.
1104   if (it == resultGroups.end()) {
1105     groupResultNo = resultGroups.back();
1106     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1107   } else {
1108     // Otherwise, the previous element is the lookup.
1109     groupResultNo = *std::prev(it);
1110     groupSize = *it - groupResultNo;
1111   }
1112 
1113   // We only record the result number for a group of size greater than 1.
1114   if (groupSize != 1)
1115     lookupResultNo = resultNo - groupResultNo;
1116   lookupValue = owner->getResult(groupResultNo);
1117 }
1118 
1119 void SSANameState::setValueName(Value value, StringRef name) {
1120   // If the name is empty, the value uses the default numbering.
1121   if (name.empty()) {
1122     valueIDs[value] = nextValueID++;
1123     return;
1124   }
1125 
1126   valueIDs[value] = NameSentinel;
1127   valueNames[value] = uniqueValueName(name);
1128 }
1129 
1130 StringRef SSANameState::uniqueValueName(StringRef name) {
1131   SmallString<16> tmpBuffer;
1132   name = sanitizeIdentifier(name, tmpBuffer);
1133 
1134   // Check to see if this name is already unique.
1135   if (!usedNames.count(name)) {
1136     name = name.copy(usedNameAllocator);
1137   } else {
1138     // Otherwise, we had a conflict - probe until we find a unique name. This
1139     // is guaranteed to terminate (and usually in a single iteration) because it
1140     // generates new names by incrementing nextConflictID.
1141     SmallString<64> probeName(name);
1142     probeName.push_back('_');
1143     while (true) {
1144       probeName += llvm::utostr(nextConflictID++);
1145       if (!usedNames.count(probeName)) {
1146         name = probeName.str().copy(usedNameAllocator);
1147         break;
1148       }
1149       probeName.resize(name.size() + 1);
1150     }
1151   }
1152 
1153   usedNames.insert(name, char());
1154   return name;
1155 }
1156 
1157 //===----------------------------------------------------------------------===//
1158 // AsmState
1159 //===----------------------------------------------------------------------===//
1160 
1161 namespace mlir {
1162 namespace detail {
/// The implementation behind AsmState: it bundles the dialect asm interfaces,
/// the alias state, the SSA name state, the printing flags and an optional
/// location map for one top-level operation.
class AsmStateImpl {
public:
  /// Collect the asm interfaces from the context of 'op' and number the SSA
  /// values of 'op'. Aliases are not computed here; see `initializeAliases`.
  /// 'locationMap' may be null.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags, interfaces),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This is a no-op when no location map was
  /// provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  /// NOTE: this must stay declared before `nameState`, which is constructed
  /// from it and keeps a reference to it.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated.
  AsmState::LocationMap *locationMap;
};
1210 } // end namespace detail
1211 } // end namespace mlir
1212 
1213 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1214                    LocationMap *locationMap)
1215     : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
1216 AsmState::~AsmState() {}
1217 
1218 //===----------------------------------------------------------------------===//
1219 // AsmPrinter::Impl
1220 //===----------------------------------------------------------------------===//
1221 
1222 namespace mlir {
/// The implementation object shared by the printers in this file. It holds the
/// output stream, the printing flags, an optional AsmStateImpl and a counter
/// of emitted newlines, and declares the routines used to print attributes,
/// types, locations and affine structures.
class AsmPrinter::Impl {
public:
  /// Create a printer writing to 'os'. 'state' may be null, in which case no
  /// shared printer state is available.
  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
       AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer that shares the stream, flags and state of 'other'. The
  /// newline counter is not copied; the new printer starts with a fresh one.
  explicit Impl(Impl &other)
      : Impl(other.os, other.printerFlags, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Invoke 'each_fn' on every element of 'c', writing a comma separator to
  /// the output stream between consecutive elements.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
    llvm::interleaveComma(c, os, each_fn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided,
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the given type.
  void printType(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print the given affine map, affine expression, affine constraint or
  /// integer set. 'printValueName', when provided, is forwarded to the
  /// expression printer.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  /// Print an attribute dictionary for 'attrs', leaving out the attributes
  /// named in 'elidedAttrs'.
  /// NOTE(review): 'withKeyword' presumably prefixes the dictionary with the
  /// `attributes` keyword; confirm against the definition.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  /// Print a single named attribute.
  void printNamedAttribute(NamedAttribute attr);
  /// Print ' ' followed by the location, but only when the printer flags
  /// request debug info.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print the body of a location. 'pretty' selects the pretty debug-info
  /// form over the parseable form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print an attribute or type that belongs to a non-builtin dialect.
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  /// Print 'expr' given the binding strength of the context it appears in.
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module.
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
1314 } // namespace mlir
1315 
1316 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1317   // Check to see if we are printing debug information.
1318   if (!printerFlags.shouldPrintDebugInfo())
1319     return;
1320 
1321   os << " ";
1322   printLocation(loc, /*allowAlias=*/allowAlias);
1323 }
1324 
1325 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1326   TypeSwitch<LocationAttr>(loc)
1327       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1328         printLocationInternal(loc.getFallbackLocation(), pretty);
1329       })
1330       .Case<UnknownLoc>([&](UnknownLoc loc) {
1331         if (pretty)
1332           os << "[unknown]";
1333         else
1334           os << "unknown";
1335       })
1336       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1337         if (pretty) {
1338           os << loc.getFilename().getValue();
1339         } else {
1340           os << "\"";
1341           printEscapedString(loc.getFilename(), os);
1342           os << "\"";
1343         }
1344         os << ':' << loc.getLine() << ':' << loc.getColumn();
1345       })
1346       .Case<NameLoc>([&](NameLoc loc) {
1347         os << '\"';
1348         printEscapedString(loc.getName(), os);
1349         os << '\"';
1350 
1351         // Print the child if it isn't unknown.
1352         auto childLoc = loc.getChildLoc();
1353         if (!childLoc.isa<UnknownLoc>()) {
1354           os << '(';
1355           printLocationInternal(childLoc, pretty);
1356           os << ')';
1357         }
1358       })
1359       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1360         Location caller = loc.getCaller();
1361         Location callee = loc.getCallee();
1362         if (!pretty)
1363           os << "callsite(";
1364         printLocationInternal(callee, pretty);
1365         if (pretty) {
1366           if (callee.isa<NameLoc>()) {
1367             if (caller.isa<FileLineColLoc>()) {
1368               os << " at ";
1369             } else {
1370               os << newLine << " at ";
1371             }
1372           } else {
1373             os << newLine << " at ";
1374           }
1375         } else {
1376           os << " at ";
1377         }
1378         printLocationInternal(caller, pretty);
1379         if (!pretty)
1380           os << ")";
1381       })
1382       .Case<FusedLoc>([&](FusedLoc loc) {
1383         if (!pretty)
1384           os << "fused";
1385         if (Attribute metadata = loc.getMetadata())
1386           os << '<' << metadata << '>';
1387         os << '[';
1388         interleave(
1389             loc.getLocations(),
1390             [&](Location loc) { printLocationInternal(loc, pretty); },
1391             [&]() { os << ", "; });
1392         os << ']';
1393       });
1394 }
1395 
1396 /// Print a floating point value in a way that the parser will be able to
1397 /// round-trip losslessly.
1398 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1399   // We would like to output the FP constant value in exponential notation,
1400   // but we cannot do this if doing so will lose precision.  Check here to
1401   // make sure that we only output it in exponential format if we can parse
1402   // the value back and get the same value.
1403   bool isInf = apValue.isInfinity();
1404   bool isNaN = apValue.isNaN();
1405   if (!isInf && !isNaN) {
1406     SmallString<128> strValue;
1407     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1408                      /*TruncateZero=*/false);
1409 
1410     // Check to make sure that the stringized number is not some string like
1411     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1412     // that the string matches the "[-+]?[0-9]" regex.
1413     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1414             ((strValue[0] == '-' || strValue[0] == '+') &&
1415              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1416            "[-+]?[0-9] regex does not match!");
1417 
1418     // Parse back the stringized version and check that the value is equal
1419     // (i.e., there is no precision loss).
1420     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1421       os << strValue;
1422       return;
1423     }
1424 
1425     // If it is not, use the default format of APFloat instead of the
1426     // exponential notation.
1427     strValue.clear();
1428     apValue.toString(strValue);
1429 
1430     // Make sure that we can parse the default form as a float.
1431     if (strValue.str().contains('.')) {
1432       os << strValue;
1433       return;
1434     }
1435   }
1436 
1437   // Print special values in hexadecimal format. The sign bit should be included
1438   // in the literal.
1439   SmallVector<char, 16> str;
1440   APInt apInt = apValue.bitcastToAPInt();
1441   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1442                  /*formatAsCLiteral=*/true);
1443   os << str;
1444 }
1445 
1446 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1447   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1448     return printLocationInternal(loc, /*pretty=*/true);
1449 
1450   os << "loc(";
1451   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1452     printLocationInternal(loc);
1453   os << ')';
1454 }
1455 
1456 /// Returns true if the given dialect symbol data is simple enough to print in
1457 /// the pretty form, i.e. without the enclosing "".
1458 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1459   // The name must start with an identifier.
1460   if (symName.empty() || !isalpha(symName.front()))
1461     return false;
1462 
1463   // Ignore all the characters that are valid in an identifier in the symbol
1464   // name.
1465   symName = symName.drop_while(
1466       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1467   if (symName.empty())
1468     return true;
1469 
1470   // If we got to an unexpected character, then it must be a <>.  Check those
1471   // recursively.
1472   if (symName.front() != '<' || symName.back() != '>')
1473     return false;
1474 
1475   SmallVector<char, 8> nestedPunctuation;
1476   do {
1477     // If we ran out of characters, then we had a punctuation mismatch.
1478     if (symName.empty())
1479       return false;
1480 
1481     auto c = symName.front();
1482     symName = symName.drop_front();
1483 
1484     switch (c) {
1485     // We never allow null characters. This is an EOF indicator for the lexer
1486     // which we could handle, but isn't important for any known dialect.
1487     case '\0':
1488       return false;
1489     case '<':
1490     case '[':
1491     case '(':
1492     case '{':
1493       nestedPunctuation.push_back(c);
1494       continue;
1495     case '-':
1496       // Treat `->` as a special token.
1497       if (!symName.empty() && symName.front() == '>') {
1498         symName = symName.drop_front();
1499         continue;
1500       }
1501       break;
1502     // Reject types with mismatched brackets.
1503     case '>':
1504       if (nestedPunctuation.pop_back_val() != '<')
1505         return false;
1506       break;
1507     case ']':
1508       if (nestedPunctuation.pop_back_val() != '[')
1509         return false;
1510       break;
1511     case ')':
1512       if (nestedPunctuation.pop_back_val() != '(')
1513         return false;
1514       break;
1515     case '}':
1516       if (nestedPunctuation.pop_back_val() != '{')
1517         return false;
1518       break;
1519     default:
1520       continue;
1521     }
1522 
1523     // We're done when the punctuation is fully matched.
1524   } while (!nestedPunctuation.empty());
1525 
1526   // If there were extra characters, then we failed.
1527   return symName.empty();
1528 }
1529 
1530 /// Print the given dialect symbol to the stream.
1531 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1532                                StringRef dialectName, StringRef symString) {
1533   os << symPrefix << dialectName;
1534 
1535   // If this symbol name is simple enough, print it directly in pretty form,
1536   // otherwise, we print it as an escaped string.
1537   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1538     os << '.' << symString;
1539     return;
1540   }
1541 
1542   os << "<\"";
1543   llvm::printEscapedString(symString, os);
1544   os << "\">";
1545 }
1546 
1547 /// Returns true if the given string can be represented as a bare identifier.
1548 static bool isBareIdentifier(StringRef name) {
1549   // By making this unsigned, the value passed in to isalnum will always be
1550   // in the range 0-255. This is important when building with MSVC because
1551   // its implementation will assert. This situation can arise when dealing
1552   // with UTF-8 multibyte characters.
1553   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1554     return false;
1555   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1556     return isalnum(c) || c == '_' || c == '$' || c == '.';
1557   });
1558 }
1559 
1560 /// Print the given string as a keyword, or a quoted and escaped string if it
1561 /// has any special or non-printable characters in it.
1562 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1563   // If it can be represented as a bare identifier, write it directly.
1564   if (isBareIdentifier(keyword)) {
1565     os << keyword;
1566     return;
1567   }
1568 
1569   // Otherwise, output the keyword wrapped in quotes with proper escaping.
1570   os << "\"";
1571   printEscapedString(keyword, os);
1572   os << '"';
1573 }
1574 
1575 /// Print the given string as a symbol reference. A symbol reference is
1576 /// represented as a string prefixed with '@'. The reference is surrounded with
1577 /// ""'s and escaped if it has any special or non-printable characters in it.
1578 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1579   assert(!symbolRef.empty() && "expected valid symbol reference");
1580   os << '@';
1581   printKeywordOrString(symbolRef, os);
1582 }
1583 
1584 // Print out a valid ElementsAttr that is succinct and can represent any
1585 // potential shape/type, for use when eliding a large ElementsAttr.
1586 //
1587 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1588 // hopefully alert readers to the fact that this has been elided.
1589 //
1590 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1591 // accept the string "elided". The first string must be a registered dialect
1592 // name and the latter must be a hex constant.
1593 static void printElidedElementsAttr(raw_ostream &os) {
1594   os << R"(opaque<"_", "0xDEADBEEF">)";
1595 }
1596 
1597 void AsmPrinter::Impl::printAttribute(Attribute attr,
1598                                       AttrTypeElision typeElision) {
1599   if (!attr) {
1600     os << "<<NULL ATTRIBUTE>>";
1601     return;
1602   }
1603 
1604   // Try to print an alias for this attribute.
1605   if (state && succeeded(state->getAliasState().getAlias(attr, os)))
1606     return;
1607 
1608   if (!isa<BuiltinDialect>(attr.getDialect()))
1609     return printDialectAttribute(attr);
1610 
1611   auto attrType = attr.getType();
1612   if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1613     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1614                        opaqueAttr.getAttrData());
1615   } else if (attr.isa<UnitAttr>()) {
1616     os << "unit";
1617     return;
1618   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1619     os << '{';
1620     interleaveComma(dictAttr.getValue(),
1621                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1622     os << '}';
1623 
1624   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1625     if (attrType.isSignlessInteger(1)) {
1626       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1627 
1628       // Boolean integer attributes always elides the type.
1629       return;
1630     }
1631 
1632     // Only print attributes as unsigned if they are explicitly unsigned or are
1633     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1634     // values print as signed.
1635     bool isUnsigned =
1636         attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1637     intAttr.getValue().print(os, !isUnsigned);
1638 
1639     // IntegerAttr elides the type if I64.
1640     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1641       return;
1642 
1643   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1644     printFloatValue(floatAttr.getValue(), os);
1645 
1646     // FloatAttr elides the type if F64.
1647     if (typeElision == AttrTypeElision::May && attrType.isF64())
1648       return;
1649 
1650   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1651     os << '"';
1652     printEscapedString(strAttr.getValue(), os);
1653     os << '"';
1654 
1655   } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1656     os << '[';
1657     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1658       printAttribute(attr, AttrTypeElision::May);
1659     });
1660     os << ']';
1661 
1662   } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1663     os << "affine_map<";
1664     affineMapAttr.getValue().print(os);
1665     os << '>';
1666 
1667     // AffineMap always elides the type.
1668     return;
1669 
1670   } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1671     os << "affine_set<";
1672     integerSetAttr.getValue().print(os);
1673     os << '>';
1674 
1675     // IntegerSet always elides the type.
1676     return;
1677 
1678   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1679     printType(typeAttr.getValue());
1680 
1681   } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1682     printSymbolReference(refAttr.getRootReference().getValue(), os);
1683     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1684       os << "::";
1685       printSymbolReference(nestedRef.getValue(), os);
1686     }
1687 
1688   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1689     if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1690       printElidedElementsAttr(os);
1691     } else {
1692       os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
1693          << llvm::toHex(opaqueAttr.getValue()) << "\">";
1694     }
1695 
1696   } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1697     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1698       printElidedElementsAttr(os);
1699     } else {
1700       os << "dense<";
1701       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1702       os << '>';
1703     }
1704 
1705   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1706     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1707       printElidedElementsAttr(os);
1708     } else {
1709       os << "dense<";
1710       printDenseStringElementsAttr(strEltAttr);
1711       os << '>';
1712     }
1713 
1714   } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1715     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1716         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1717       printElidedElementsAttr(os);
1718     } else {
1719       os << "sparse<";
1720       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1721       if (indices.getNumElements() != 0) {
1722         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1723         os << ", ";
1724         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1725       }
1726       os << '>';
1727     }
1728 
1729   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1730     printLocation(locAttr);
1731   }
1732   // Don't print the type if we must elide it, or if it is a None type.
1733   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1734     os << " : ";
1735     printType(attrType);
1736   }
1737 }
1738 
1739 /// Print the integer element of a DenseElementsAttr.
1740 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1741                                  bool isSigned) {
1742   if (value.getBitWidth() == 1)
1743     os << (value.getBoolValue() ? "true" : "false");
1744   else
1745     value.print(os, isSigned);
1746 }
1747 
1748 static void
1749 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1750                            function_ref<void(unsigned)> printEltFn) {
1751   // Special case for 0-d and splat tensors.
1752   if (isSplat)
1753     return printEltFn(0);
1754 
1755   // Special case for degenerate tensors.
1756   auto numElements = type.getNumElements();
1757   if (numElements == 0)
1758     return;
1759 
1760   // We use a mixed-radix counter to iterate through the shape. When we bump a
1761   // non-least-significant digit, we emit a close bracket. When we next emit an
1762   // element we re-open all closed brackets.
1763 
1764   // The mixed-radix counter, with radices in 'shape'.
1765   int64_t rank = type.getRank();
1766   SmallVector<unsigned, 4> counter(rank, 0);
1767   // The number of brackets that have been opened and not closed.
1768   unsigned openBrackets = 0;
1769 
1770   auto shape = type.getShape();
1771   auto bumpCounter = [&] {
1772     // Bump the least significant digit.
1773     ++counter[rank - 1];
1774     // Iterate backwards bubbling back the increment.
1775     for (unsigned i = rank - 1; i > 0; --i)
1776       if (counter[i] >= shape[i]) {
1777         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1778         counter[i] = 0;
1779         ++counter[i - 1];
1780         --openBrackets;
1781         os << ']';
1782       }
1783   };
1784 
1785   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1786     if (idx != 0)
1787       os << ", ";
1788     while (openBrackets++ < rank)
1789       os << '[';
1790     openBrackets = rank;
1791     printEltFn(idx);
1792     bumpCounter();
1793   }
1794   while (openBrackets-- > 0)
1795     os << ']';
1796 }
1797 
1798 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1799                                               bool allowHex) {
1800   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1801     return printDenseStringElementsAttr(stringAttr);
1802 
1803   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1804                                 allowHex);
1805 }
1806 
1807 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
1808     DenseIntOrFPElementsAttr attr, bool allowHex) {
1809   auto type = attr.getType();
1810   auto elementType = type.getElementType();
1811 
1812   // Check to see if we should format this attribute as a hex string.
1813   auto numElements = type.getNumElements();
1814   if (!attr.isSplat() && allowHex &&
1815       shouldPrintElementsAttrWithHex(numElements)) {
1816     ArrayRef<char> rawData = attr.getRawData();
1817     if (llvm::support::endian::system_endianness() ==
1818         llvm::support::endianness::big) {
1819       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1820       // machines. It is converted here to print in LE format.
1821       SmallVector<char, 64> outDataVec(rawData.size());
1822       MutableArrayRef<char> convRawData(outDataVec);
1823       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1824           rawData, convRawData, type);
1825       os << '"' << "0x"
1826          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1827          << "\"";
1828     } else {
1829       os << '"' << "0x"
1830          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1831     }
1832 
1833     return;
1834   }
1835 
1836   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1837     Type complexElementType = complexTy.getElementType();
1838     // Note: The if and else below had a common lambda function which invoked
1839     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1840     // and hence was replaced.
1841     if (complexElementType.isa<IntegerType>()) {
1842       bool isSigned = !complexElementType.isUnsignedInteger();
1843       auto valueIt = attr.value_begin<std::complex<APInt>>();
1844       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1845         auto complexValue = *(valueIt + index);
1846         os << "(";
1847         printDenseIntElement(complexValue.real(), os, isSigned);
1848         os << ",";
1849         printDenseIntElement(complexValue.imag(), os, isSigned);
1850         os << ")";
1851       });
1852     } else {
1853       auto valueIt = attr.value_begin<std::complex<APFloat>>();
1854       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1855         auto complexValue = *(valueIt + index);
1856         os << "(";
1857         printFloatValue(complexValue.real(), os);
1858         os << ",";
1859         printFloatValue(complexValue.imag(), os);
1860         os << ")";
1861       });
1862     }
1863   } else if (elementType.isIntOrIndex()) {
1864     bool isSigned = !elementType.isUnsignedInteger();
1865     auto valueIt = attr.value_begin<APInt>();
1866     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1867       printDenseIntElement(*(valueIt + index), os, isSigned);
1868     });
1869   } else {
1870     assert(elementType.isa<FloatType>() && "unexpected element type");
1871     auto valueIt = attr.value_begin<APFloat>();
1872     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1873       printFloatValue(*(valueIt + index), os);
1874     });
1875   }
1876 }
1877 
1878 void AsmPrinter::Impl::printDenseStringElementsAttr(
1879     DenseStringElementsAttr attr) {
1880   ArrayRef<StringRef> data = attr.getRawStringData();
1881   auto printFn = [&](unsigned index) {
1882     os << "\"";
1883     printEscapedString(data[index], os);
1884     os << "\"";
1885   };
1886   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1887 }
1888 
1889 void AsmPrinter::Impl::printType(Type type) {
1890   if (!type) {
1891     os << "<<NULL TYPE>>";
1892     return;
1893   }
1894 
1895   // Try to print an alias for this type.
1896   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1897     return;
1898 
1899   TypeSwitch<Type>(type)
1900       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1901         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1902                            opaqueTy.getTypeData());
1903       })
1904       .Case<IndexType>([&](Type) { os << "index"; })
1905       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1906       .Case<Float16Type>([&](Type) { os << "f16"; })
1907       .Case<Float32Type>([&](Type) { os << "f32"; })
1908       .Case<Float64Type>([&](Type) { os << "f64"; })
1909       .Case<Float80Type>([&](Type) { os << "f80"; })
1910       .Case<Float128Type>([&](Type) { os << "f128"; })
1911       .Case<IntegerType>([&](IntegerType integerTy) {
1912         if (integerTy.isSigned())
1913           os << 's';
1914         else if (integerTy.isUnsigned())
1915           os << 'u';
1916         os << 'i' << integerTy.getWidth();
1917       })
1918       .Case<FunctionType>([&](FunctionType funcTy) {
1919         os << '(';
1920         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1921         os << ") -> ";
1922         ArrayRef<Type> results = funcTy.getResults();
1923         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1924           printType(results[0]);
1925         } else {
1926           os << '(';
1927           interleaveComma(results, [&](Type ty) { printType(ty); });
1928           os << ')';
1929         }
1930       })
1931       .Case<VectorType>([&](VectorType vectorTy) {
1932         os << "vector<";
1933         for (int64_t dim : vectorTy.getShape())
1934           os << dim << 'x';
1935         printType(vectorTy.getElementType());
1936         os << '>';
1937       })
1938       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1939         os << "tensor<";
1940         for (int64_t dim : tensorTy.getShape()) {
1941           if (ShapedType::isDynamic(dim))
1942             os << '?';
1943           else
1944             os << dim;
1945           os << 'x';
1946         }
1947         printType(tensorTy.getElementType());
1948         // Only print the encoding attribute value if set.
1949         if (tensorTy.getEncoding()) {
1950           os << ", ";
1951           printAttribute(tensorTy.getEncoding());
1952         }
1953         os << '>';
1954       })
1955       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1956         os << "tensor<*x";
1957         printType(tensorTy.getElementType());
1958         os << '>';
1959       })
1960       .Case<MemRefType>([&](MemRefType memrefTy) {
1961         os << "memref<";
1962         for (int64_t dim : memrefTy.getShape()) {
1963           if (ShapedType::isDynamic(dim))
1964             os << '?';
1965           else
1966             os << dim;
1967           os << 'x';
1968         }
1969         printType(memrefTy.getElementType());
1970         if (!memrefTy.getLayout().isIdentity()) {
1971           os << ", ";
1972           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
1973         }
1974         // Only print the memory space if it is the non-default one.
1975         if (memrefTy.getMemorySpace()) {
1976           os << ", ";
1977           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1978         }
1979         os << '>';
1980       })
1981       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1982         os << "memref<*x";
1983         printType(memrefTy.getElementType());
1984         // Only print the memory space if it is the non-default one.
1985         if (memrefTy.getMemorySpace()) {
1986           os << ", ";
1987           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1988         }
1989         os << '>';
1990       })
1991       .Case<ComplexType>([&](ComplexType complexTy) {
1992         os << "complex<";
1993         printType(complexTy.getElementType());
1994         os << '>';
1995       })
1996       .Case<TupleType>([&](TupleType tupleTy) {
1997         os << "tuple<";
1998         interleaveComma(tupleTy.getTypes(),
1999                         [&](Type type) { printType(type); });
2000         os << '>';
2001       })
2002       .Case<NoneType>([&](Type) { os << "none"; })
2003       .Default([&](Type type) { return printDialectType(type); });
2004 }
2005 
2006 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2007                                              ArrayRef<StringRef> elidedAttrs,
2008                                              bool withKeyword) {
2009   // If there are no attributes, then there is nothing to be done.
2010   if (attrs.empty())
2011     return;
2012 
2013   // Functor used to print a filtered attribute list.
2014   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2015     // Print the 'attributes' keyword if necessary.
2016     if (withKeyword)
2017       os << " attributes";
2018 
2019     // Otherwise, print them all out in braces.
2020     os << " {";
2021     interleaveComma(filteredAttrs,
2022                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2023     os << '}';
2024   };
2025 
2026   // If no attributes are elided, we can directly print with no filtering.
2027   if (elidedAttrs.empty())
2028     return printFilteredAttributesFn(attrs);
2029 
2030   // Otherwise, filter out any attributes that shouldn't be included.
2031   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2032                                                 elidedAttrs.end());
2033   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2034     return !elidedAttrsSet.contains(attr.getName().strref());
2035   });
2036   if (!filteredAttrs.empty())
2037     printFilteredAttributesFn(filteredAttrs);
2038 }
2039 
2040 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2041   // Print the name without quotes if possible.
2042   ::printKeywordOrString(attr.getName().strref(), os);
2043 
2044   // Pretty printing elides the attribute value for unit attributes.
2045   if (attr.getValue().isa<UnitAttr>())
2046     return;
2047 
2048   os << " = ";
2049   printAttribute(attr.getValue());
2050 }
2051 
2052 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2053   auto &dialect = attr.getDialect();
2054 
2055   // Ask the dialect to serialize the attribute to a string.
2056   std::string attrName;
2057   {
2058     llvm::raw_string_ostream attrNameStr(attrName);
2059     Impl subPrinter(attrNameStr, printerFlags, state);
2060     DialectAsmPrinter printer(subPrinter);
2061     dialect.printAttribute(attr, printer);
2062   }
2063   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2064 }
2065 
2066 void AsmPrinter::Impl::printDialectType(Type type) {
2067   auto &dialect = type.getDialect();
2068 
2069   // Ask the dialect to serialize the type to a string.
2070   std::string typeName;
2071   {
2072     llvm::raw_string_ostream typeNameStr(typeName);
2073     Impl subPrinter(typeNameStr, printerFlags, state);
2074     DialectAsmPrinter printer(subPrinter);
2075     dialect.printType(type, printer);
2076   }
2077   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2078 }
2079 
2080 //===--------------------------------------------------------------------===//
2081 // AsmPrinter
2082 //===--------------------------------------------------------------------===//
2083 
2084 AsmPrinter::~AsmPrinter() {}
2085 
2086 raw_ostream &AsmPrinter::getStream() const {
2087   assert(impl && "expected AsmPrinter::getStream to be overriden");
2088   return impl->getStream();
2089 }
2090 
2091 /// Print the given floating point value in a stablized form.
2092 void AsmPrinter::printFloat(const APFloat &value) {
2093   assert(impl && "expected AsmPrinter::printFloat to be overriden");
2094   printFloatValue(value, impl->getStream());
2095 }
2096 
2097 void AsmPrinter::printType(Type type) {
2098   assert(impl && "expected AsmPrinter::printType to be overriden");
2099   impl->printType(type);
2100 }
2101 
2102 void AsmPrinter::printAttribute(Attribute attr) {
2103   assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2104   impl->printAttribute(attr);
2105 }
2106 
2107 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2108   assert(impl &&
2109          "expected AsmPrinter::printAttributeWithoutType to be overriden");
2110   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2111 }
2112 
2113 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2114   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2115   ::printKeywordOrString(keyword, impl->getStream());
2116 }
2117 
2118 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2119   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2120   ::printSymbolReference(symbolRef, impl->getStream());
2121 }
2122 
2123 //===----------------------------------------------------------------------===//
2124 // Affine expressions and maps
2125 //===----------------------------------------------------------------------===//
2126 
2127 void AsmPrinter::Impl::printAffineExpr(
2128     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2129   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2130 }
2131 
2132 void AsmPrinter::Impl::printAffineExprInternal(
2133     AffineExpr expr, BindingStrength enclosingTightness,
2134     function_ref<void(unsigned, bool)> printValueName) {
2135   const char *binopSpelling = nullptr;
2136   switch (expr.getKind()) {
2137   case AffineExprKind::SymbolId: {
2138     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
2139     if (printValueName)
2140       printValueName(pos, /*isSymbol=*/true);
2141     else
2142       os << 's' << pos;
2143     return;
2144   }
2145   case AffineExprKind::DimId: {
2146     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2147     if (printValueName)
2148       printValueName(pos, /*isSymbol=*/false);
2149     else
2150       os << 'd' << pos;
2151     return;
2152   }
2153   case AffineExprKind::Constant:
2154     os << expr.cast<AffineConstantExpr>().getValue();
2155     return;
2156   case AffineExprKind::Add:
2157     binopSpelling = " + ";
2158     break;
2159   case AffineExprKind::Mul:
2160     binopSpelling = " * ";
2161     break;
2162   case AffineExprKind::FloorDiv:
2163     binopSpelling = " floordiv ";
2164     break;
2165   case AffineExprKind::CeilDiv:
2166     binopSpelling = " ceildiv ";
2167     break;
2168   case AffineExprKind::Mod:
2169     binopSpelling = " mod ";
2170     break;
2171   }
2172 
2173   auto binOp = expr.cast<AffineBinaryOpExpr>();
2174   AffineExpr lhsExpr = binOp.getLHS();
2175   AffineExpr rhsExpr = binOp.getRHS();
2176 
2177   // Handle tightly binding binary operators.
2178   if (binOp.getKind() != AffineExprKind::Add) {
2179     if (enclosingTightness == BindingStrength::Strong)
2180       os << '(';
2181 
2182     // Pretty print multiplication with -1.
2183     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2184     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2185         rhsConst.getValue() == -1) {
2186       os << "-";
2187       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2188       if (enclosingTightness == BindingStrength::Strong)
2189         os << ')';
2190       return;
2191     }
2192 
2193     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2194 
2195     os << binopSpelling;
2196     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2197 
2198     if (enclosingTightness == BindingStrength::Strong)
2199       os << ')';
2200     return;
2201   }
2202 
2203   // Print out special "pretty" forms for add.
2204   if (enclosingTightness == BindingStrength::Strong)
2205     os << '(';
2206 
2207   // Pretty print addition to a product that has a negative operand as a
2208   // subtraction.
2209   if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2210     if (rhs.getKind() == AffineExprKind::Mul) {
2211       AffineExpr rrhsExpr = rhs.getRHS();
2212       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2213         if (rrhs.getValue() == -1) {
2214           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2215                                   printValueName);
2216           os << " - ";
2217           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2218             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2219                                     printValueName);
2220           } else {
2221             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2222                                     printValueName);
2223           }
2224 
2225           if (enclosingTightness == BindingStrength::Strong)
2226             os << ')';
2227           return;
2228         }
2229 
2230         if (rrhs.getValue() < -1) {
2231           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2232                                   printValueName);
2233           os << " - ";
2234           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2235                                   printValueName);
2236           os << " * " << -rrhs.getValue();
2237           if (enclosingTightness == BindingStrength::Strong)
2238             os << ')';
2239           return;
2240         }
2241       }
2242     }
2243   }
2244 
2245   // Pretty print addition to a negative number as a subtraction.
2246   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2247     if (rhsConst.getValue() < 0) {
2248       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2249       os << " - " << -rhsConst.getValue();
2250       if (enclosingTightness == BindingStrength::Strong)
2251         os << ')';
2252       return;
2253     }
2254   }
2255 
2256   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2257 
2258   os << " + ";
2259   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2260 
2261   if (enclosingTightness == BindingStrength::Strong)
2262     os << ')';
2263 }
2264 
2265 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2266   printAffineExprInternal(expr, BindingStrength::Weak);
2267   isEq ? os << " == 0" : os << " >= 0";
2268 }
2269 
2270 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2271   // Dimension identifiers.
2272   os << '(';
2273   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2274     os << 'd' << i << ", ";
2275   if (map.getNumDims() >= 1)
2276     os << 'd' << map.getNumDims() - 1;
2277   os << ')';
2278 
2279   // Symbolic identifiers.
2280   if (map.getNumSymbols() != 0) {
2281     os << '[';
2282     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2283       os << 's' << i << ", ";
2284     if (map.getNumSymbols() >= 1)
2285       os << 's' << map.getNumSymbols() - 1;
2286     os << ']';
2287   }
2288 
2289   // Result affine expressions.
2290   os << " -> (";
2291   interleaveComma(map.getResults(),
2292                   [&](AffineExpr expr) { printAffineExpr(expr); });
2293   os << ')';
2294 }
2295 
2296 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2297   // Dimension identifiers.
2298   os << '(';
2299   for (unsigned i = 1; i < set.getNumDims(); ++i)
2300     os << 'd' << i - 1 << ", ";
2301   if (set.getNumDims() >= 1)
2302     os << 'd' << set.getNumDims() - 1;
2303   os << ')';
2304 
2305   // Symbolic identifiers.
2306   if (set.getNumSymbols() != 0) {
2307     os << '[';
2308     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2309       os << 's' << i << ", ";
2310     if (set.getNumSymbols() >= 1)
2311       os << 's' << set.getNumSymbols() - 1;
2312     os << ']';
2313   }
2314 
2315   // Print constraints.
2316   os << " : (";
2317   int numConstraints = set.getNumConstraints();
2318   for (int i = 1; i < numConstraints; ++i) {
2319     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2320     os << ", ";
2321   }
2322   if (numConstraints >= 1)
2323     printAffineConstraint(set.getConstraint(numConstraints - 1),
2324                           set.isEq(numConstraints - 1));
2325   os << ')';
2326 }
2327 
2328 //===----------------------------------------------------------------------===//
2329 // OperationPrinter
2330 //===----------------------------------------------------------------------===//
2331 
2332 namespace {
2333 /// This class contains the logic for printing operations, regions, and blocks.
2334 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
2335 public:
2336   using Impl = AsmPrinter::Impl;
2337   using Impl::printType;
2338 
2339   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2340                             AsmStateImpl &state)
2341       : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
2342 
2343   /// Print the given top-level operation.
2344   void printTopLevelOperation(Operation *op);
2345 
2346   /// Print the given operation with its indent and location.
2347   void print(Operation *op);
2348   /// Print the bare location, not including indentation/location/etc.
2349   void printOperation(Operation *op);
2350   /// Print the given operation in the generic form.
2351   void printGenericOp(Operation *op, bool printOpName) override;
2352 
2353   /// Print the name of the given block.
2354   void printBlockName(Block *block);
2355 
2356   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2357   /// block are not printed. If 'printBlockTerminator' is false, the terminator
2358   /// operation of the block is not printed.
2359   void print(Block *block, bool printBlockArgs = true,
2360              bool printBlockTerminator = true);
2361 
2362   /// Print the ID of the given value, optionally with its result number.
2363   void printValueID(Value value, bool printResultNo = true,
2364                     raw_ostream *streamOverride = nullptr) const;
2365 
2366   //===--------------------------------------------------------------------===//
2367   // OpAsmPrinter methods
2368   //===--------------------------------------------------------------------===//
2369 
2370   /// Print a newline and indent the printer to the start of the current
2371   /// operation.
2372   void printNewline() override {
2373     os << newLine;
2374     os.indent(currentIndent);
2375   }
2376 
2377   /// Print a block argument in the usual format of:
2378   ///   %ssaName : type {attr1=42} loc("here")
2379   /// where location printing is controlled by the standard internal option.
2380   /// You may pass omitType=true to not print a type, and pass an empty
2381   /// attribute list if you don't care for attributes.
2382   void printRegionArgument(BlockArgument arg,
2383                            ArrayRef<NamedAttribute> argAttrs = {},
2384                            bool omitType = false) override;
2385 
2386   /// Print the ID for the given value.
2387   void printOperand(Value value) override { printValueID(value); }
2388   void printOperand(Value value, raw_ostream &os) override {
2389     printValueID(value, /*printResultNo=*/true, &os);
2390   }
2391 
  /// Print an optional attribute dictionary with a given set of elided values.
  /// This simply forwards to the `Impl` method of the same name.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Print an optional attribute dictionary with a given set of elided values,
  /// requesting the keyword form from `Impl::printOptionalAttrDict`.
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    // Same entry point as printOptionalAttrDict; only the keyword flag differs.
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }
2403 
  /// Print the given successor. Only the name of the block is printed; use
  /// printSuccessorAndUseList to also print successor operands.
  void printSuccessor(Block *successor) override;
2406 
  /// Print an operation successor with the operands used for the block
  /// arguments. The operands, if any, are printed as a parenthesized list of
  /// SSA IDs followed by " : " and the list of their types; nothing is printed
  /// after the block name when 'succOperands' is empty.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;
2411 
  /// Print the given region, enclosed in braces. The header of the entry block
  /// is printed only when 'printEntryBlockArgs' is set and the entry block has
  /// arguments, or when 'printEmptyBlock' is set and the entry block is empty.
  /// 'printBlockTerminators' controls whether the terminator of the entry
  /// block is printed; all other blocks are always printed in full.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;
2415 
2416   /// Renumber the arguments for the specified region to the same names as the
2417   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2418   /// operations. If any entry in namesToUse is null, the corresponding
2419   /// argument name is left alone.
2420   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
2421     state->getSSANameState().shadowRegionArgs(region, namesToUse);
2422   }
2423 
  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map. 'operands' holds the dimension operands first,
  /// followed by the symbol operands; symbol operands are printed wrapped in
  /// `symbol(...)`.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;
2428 
  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression. Dimension positions index into
  /// 'dimOperands' and symbol positions index into 'symOperands'; symbol
  /// operands are printed wrapped in `symbol(...)`.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;
2433 
2434 private:
  // Contains the stack of default dialects to use when printing regions.
  // An entry is pushed to the stack before printing the blocks of a non-empty
  // region, and popped when done: the default dialect reported by the parent
  // operation if it implements `OpAsmOpInterface`, or an empty string
  // otherwise. At the top-level we start with "builtin" as the default, so
  // that the top-level `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};
2441 
  /// The number of spaces used for indenting nested operations. The current
  /// indentation is increased by this amount while the operations of a block
  /// are being printed.
  const static unsigned indentWidth = 2;

  /// This is the current indentation level, in spaces, for nested structures.
  unsigned currentIndent = 0;
2447 };
2448 } // end anonymous namespace
2449 
2450 void OperationPrinter::printTopLevelOperation(Operation *op) {
2451   // Output the aliases at the top level that can't be deferred.
2452   state->getAliasState().printNonDeferredAliases(os, newLine);
2453 
2454   // Print the module.
2455   print(op);
2456   os << newLine;
2457 
2458   // Output the aliases at the top level that can be deferred.
2459   state->getAliasState().printDeferredAliases(os, newLine);
2460 }
2461 
2462 /// Print a block argument in the usual format of:
2463 ///   %ssaName : type {attr1=42} loc("here")
2464 /// where location printing is controlled by the standard internal option.
2465 /// You may pass omitType=true to not print a type, and pass an empty
2466 /// attribute list if you don't care for attributes.
2467 void OperationPrinter::printRegionArgument(BlockArgument arg,
2468                                            ArrayRef<NamedAttribute> argAttrs,
2469                                            bool omitType) {
2470   printOperand(arg);
2471   if (!omitType) {
2472     os << ": ";
2473     printType(arg.getType());
2474   }
2475   printOptionalAttrDict(argAttrs);
2476   // TODO: We should allow location aliases on block arguments.
2477   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2478 }
2479 
2480 void OperationPrinter::print(Operation *op) {
2481   // Track the location of this operation.
2482   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2483 
2484   os.indent(currentIndent);
2485   printOperation(op);
2486   printTrailingLocation(op->getLoc());
2487 }
2488 
2489 void OperationPrinter::printOperation(Operation *op) {
2490   if (size_t numResults = op->getNumResults()) {
2491     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2492       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2493       if (resultCount > 1)
2494         os << ':' << resultCount;
2495     };
2496 
2497     // Check to see if this operation has multiple result groups.
2498     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2499     if (!resultGroups.empty()) {
2500       // Interleave the groups excluding the last one, this one will be handled
2501       // separately.
2502       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2503         printResultGroup(resultGroups[i],
2504                          resultGroups[i + 1] - resultGroups[i]);
2505       });
2506       os << ", ";
2507       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2508 
2509     } else {
2510       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2511     }
2512 
2513     os << " = ";
2514   }
2515 
2516   // If requested, always print the generic form.
2517   if (!printerFlags.shouldPrintGenericOpForm()) {
2518     // Check to see if this is a known operation. If so, use the registered
2519     // custom printer hook.
2520     if (auto opInfo = op->getRegisteredInfo()) {
2521       opInfo->printAssembly(op, *this, defaultDialectStack.back());
2522       return;
2523     }
2524     // Otherwise try to dispatch to the dialect, if available.
2525     if (Dialect *dialect = op->getDialect()) {
2526       if (auto opPrinter = dialect->getOperationPrinter(op)) {
2527         // Print the op name first.
2528         StringRef name = op->getName().getStringRef();
2529         name.consume_front((defaultDialectStack.back() + ".").str());
2530         printEscapedString(name, os);
2531         // Print the rest of the op now.
2532         opPrinter(op, *this);
2533         return;
2534       }
2535     }
2536   }
2537 
2538   // Otherwise print with the generic assembly form.
2539   printGenericOp(op, /*printOpName=*/true);
2540 }
2541 
2542 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2543   if (printOpName) {
2544     os << '"';
2545     printEscapedString(op->getName().getStringRef(), os);
2546     os << '"';
2547   }
2548   os << '(';
2549   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2550   os << ')';
2551 
2552   // For terminators, print the list of successors and their operands.
2553   if (op->getNumSuccessors() != 0) {
2554     os << '[';
2555     interleaveComma(op->getSuccessors(),
2556                     [&](Block *successor) { printBlockName(successor); });
2557     os << ']';
2558   }
2559 
2560   // Print regions.
2561   if (op->getNumRegions() != 0) {
2562     os << " (";
2563     interleaveComma(op->getRegions(), [&](Region &region) {
2564       printRegion(region, /*printEntryBlockArgs=*/true,
2565                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2566     });
2567     os << ')';
2568   }
2569 
2570   auto attrs = op->getAttrs();
2571   printOptionalAttrDict(attrs);
2572 
2573   // Print the type signature of the operation.
2574   os << " : ";
2575   printFunctionalType(op);
2576 }
2577 
2578 void OperationPrinter::printBlockName(Block *block) {
2579   auto id = state->getSSANameState().getBlockID(block);
2580   if (id != SSANameState::NameSentinel)
2581     os << "^bb" << id;
2582   else
2583     os << "^INVALIDBLOCK";
2584 }
2585 
2586 void OperationPrinter::print(Block *block, bool printBlockArgs,
2587                              bool printBlockTerminator) {
2588   // Print the block label and argument list if requested.
2589   if (printBlockArgs) {
2590     os.indent(currentIndent);
2591     printBlockName(block);
2592 
2593     // Print the argument list if non-empty.
2594     if (!block->args_empty()) {
2595       os << '(';
2596       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2597         printValueID(arg);
2598         os << ": ";
2599         printType(arg.getType());
2600         // TODO: We should allow location aliases on block arguments.
2601         printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2602       });
2603       os << ')';
2604     }
2605     os << ':';
2606 
2607     // Print out some context information about the predecessors of this block.
2608     if (!block->getParent()) {
2609       os << "  // block is not in a region!";
2610     } else if (block->hasNoPredecessors()) {
2611       os << "  // no predecessors";
2612     } else if (auto *pred = block->getSinglePredecessor()) {
2613       os << "  // pred: ";
2614       printBlockName(pred);
2615     } else {
2616       // We want to print the predecessors in increasing numeric order, not in
2617       // whatever order the use-list is in, so gather and sort them.
2618       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2619       for (auto *pred : block->getPredecessors())
2620         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2621       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2622 
2623       os << "  // " << predIDs.size() << " preds: ";
2624 
2625       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2626         printBlockName(pred.second);
2627       });
2628     }
2629     os << newLine;
2630   }
2631 
2632   currentIndent += indentWidth;
2633   bool hasTerminator =
2634       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
2635   auto range = llvm::make_range(
2636       block->begin(),
2637       std::prev(block->end(),
2638                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
2639   for (auto &op : range) {
2640     print(&op);
2641     os << newLine;
2642   }
2643   currentIndent -= indentWidth;
2644 }
2645 
2646 void OperationPrinter::printValueID(Value value, bool printResultNo,
2647                                     raw_ostream *streamOverride) const {
2648   state->getSSANameState().printValueID(value, printResultNo,
2649                                         streamOverride ? *streamOverride : os);
2650 }
2651 
/// Print the given successor. Only the name of the block is printed, without
/// any successor operands.
void OperationPrinter::printSuccessor(Block *successor) {
  printBlockName(successor);
}
2655 
2656 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2657                                                 ValueRange succOperands) {
2658   printBlockName(successor);
2659   if (succOperands.empty())
2660     return;
2661 
2662   os << '(';
2663   interleaveComma(succOperands,
2664                   [this](Value operand) { printValueID(operand); });
2665   os << " : ";
2666   interleaveComma(succOperands,
2667                   [this](Value operand) { printType(operand.getType()); });
2668   os << ')';
2669 }
2670 
2671 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2672                                    bool printBlockTerminators,
2673                                    bool printEmptyBlock) {
2674   os << " {" << newLine;
2675   if (!region.empty()) {
2676     auto restoreDefaultDialect =
2677         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
2678     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
2679       defaultDialectStack.push_back(iface.getDefaultDialect());
2680     else
2681       defaultDialectStack.push_back("");
2682 
2683     auto *entryBlock = &region.front();
2684     // Force printing the block header if printEmptyBlock is set and the block
2685     // is empty or if printEntryBlockArgs is set and there are arguments to
2686     // print.
2687     bool shouldAlwaysPrintBlockHeader =
2688         (printEmptyBlock && entryBlock->empty()) ||
2689         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2690     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2691     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2692       print(&b);
2693   }
2694   os.indent(currentIndent) << "}";
2695 }
2696 
2697 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2698                                               ValueRange operands) {
2699   AffineMap map = mapAttr.getValue();
2700   unsigned numDims = map.getNumDims();
2701   auto printValueName = [&](unsigned pos, bool isSymbol) {
2702     unsigned index = isSymbol ? numDims + pos : pos;
2703     assert(index < operands.size());
2704     if (isSymbol)
2705       os << "symbol(";
2706     printValueID(operands[index]);
2707     if (isSymbol)
2708       os << ')';
2709   };
2710 
2711   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2712     printAffineExpr(expr, printValueName);
2713   });
2714 }
2715 
2716 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2717                                                ValueRange dimOperands,
2718                                                ValueRange symOperands) {
2719   auto printValueName = [&](unsigned pos, bool isSymbol) {
2720     if (!isSymbol)
2721       return printValueID(dimOperands[pos]);
2722     os << "symbol(";
2723     printValueID(symOperands[pos]);
2724     os << ')';
2725   };
2726   printAffineExpr(expr, printValueName);
2727 }
2728 
2729 //===----------------------------------------------------------------------===//
2730 // print and dump methods
2731 //===----------------------------------------------------------------------===//
2732 
2733 void Attribute::print(raw_ostream &os) const {
2734   AsmPrinter::Impl(os).printAttribute(*this);
2735 }
2736 
2737 void Attribute::dump() const {
2738   print(llvm::errs());
2739   llvm::errs() << "\n";
2740 }
2741 
2742 void Type::print(raw_ostream &os) const {
2743   AsmPrinter::Impl(os).printType(*this);
2744 }
2745 
2746 void Type::dump() const { print(llvm::errs()); }
2747 
2748 void AffineMap::dump() const {
2749   print(llvm::errs());
2750   llvm::errs() << "\n";
2751 }
2752 
2753 void IntegerSet::dump() const {
2754   print(llvm::errs());
2755   llvm::errs() << "\n";
2756 }
2757 
2758 void AffineExpr::print(raw_ostream &os) const {
2759   if (!expr) {
2760     os << "<<NULL AFFINE EXPR>>";
2761     return;
2762   }
2763   AsmPrinter::Impl(os).printAffineExpr(*this);
2764 }
2765 
2766 void AffineExpr::dump() const {
2767   print(llvm::errs());
2768   llvm::errs() << "\n";
2769 }
2770 
2771 void AffineMap::print(raw_ostream &os) const {
2772   if (!map) {
2773     os << "<<NULL AFFINE MAP>>";
2774     return;
2775   }
2776   AsmPrinter::Impl(os).printAffineMap(*this);
2777 }
2778 
2779 void IntegerSet::print(raw_ostream &os) const {
2780   AsmPrinter::Impl(os).printIntegerSet(*this);
2781 }
2782 
2783 void Value::print(raw_ostream &os) {
2784   if (auto *op = getDefiningOp())
2785     return op->print(os);
2786   // TODO: Improve BlockArgument print'ing.
2787   BlockArgument arg = this->cast<BlockArgument>();
2788   os << "<block argument> of type '" << arg.getType()
2789      << "' at index: " << arg.getArgNumber();
2790 }
2791 void Value::print(raw_ostream &os, AsmState &state) {
2792   if (auto *op = getDefiningOp())
2793     return op->print(os, state);
2794 
2795   // TODO: Improve BlockArgument print'ing.
2796   BlockArgument arg = this->cast<BlockArgument>();
2797   os << "<block argument> of type '" << arg.getType()
2798      << "' at index: " << arg.getArgNumber();
2799 }
2800 
2801 void Value::dump() {
2802   print(llvm::errs());
2803   llvm::errs() << "\n";
2804 }
2805 
2806 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2807   // TODO: This doesn't necessarily capture all potential cases.
2808   // Currently, region arguments can be shadowed when printing the main
2809   // operation. If the IR hasn't been printed, this will produce the old SSA
2810   // name and not the shadowed name.
2811   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2812                                                  os);
2813 }
2814 
2815 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
2816   // If this is a top level operation, we also print aliases.
2817   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
2818     AsmState state(this, printerFlags);
2819     state.getImpl().initializeAliases(this);
2820     print(os, state, printerFlags);
2821     return;
2822   }
2823 
2824   // Find the operation to number from based upon the provided flags.
2825   Operation *op = this;
2826   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
2827   do {
2828     // If we are printing local scope, stop at the first operation that is
2829     // isolated from above.
2830     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2831       break;
2832 
2833     // Otherwise, traverse up to the next parent.
2834     Operation *parentOp = op->getParentOp();
2835     if (!parentOp)
2836       break;
2837     op = parentOp;
2838   } while (true);
2839 
2840   AsmState state(op, printerFlags);
2841   print(os, state, printerFlags);
2842 }
2843 void Operation::print(raw_ostream &os, AsmState &state,
2844                       const OpPrintingFlags &flags) {
2845   OperationPrinter printer(os, flags, state.getImpl());
2846   if (!getParent() && !flags.shouldUseLocalScope())
2847     printer.printTopLevelOperation(this);
2848   else
2849     printer.print(this);
2850 }
2851 
2852 void Operation::dump() {
2853   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2854   llvm::errs() << "\n";
2855 }
2856 
2857 void Block::print(raw_ostream &os) {
2858   Operation *parentOp = getParentOp();
2859   if (!parentOp) {
2860     os << "<<UNLINKED BLOCK>>\n";
2861     return;
2862   }
2863   // Get the top-level op.
2864   while (auto *nextOp = parentOp->getParentOp())
2865     parentOp = nextOp;
2866 
2867   AsmState state(parentOp);
2868   print(os, state);
2869 }
2870 void Block::print(raw_ostream &os, AsmState &state) {
2871   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2872 }
2873 
/// Print this block to llvm::errs().
void Block::dump() { print(llvm::errs()); }
2875 
2876 /// Print out the name of the block without printing its body.
2877 void Block::printAsOperand(raw_ostream &os, bool printType) {
2878   Operation *parentOp = getParentOp();
2879   if (!parentOp) {
2880     os << "<<UNLINKED BLOCK>>\n";
2881     return;
2882   }
2883   AsmState state(parentOp);
2884   printAsOperand(os, state);
2885 }
2886 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2887   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2888   printer.printBlockName(this);
2889 }
2890