1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/ScopeExit.h"
33 #include "llvm/ADT/ScopedHashTable.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/SmallString.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Endian.h"
41 #include "llvm/Support/Regex.h"
42 #include "llvm/Support/SaveAndRestore.h"
43 
44 #include <tuple>
45 
46 using namespace mlir;
47 using namespace mlir::detail;
48 
49 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
50 
51 void OperationName::dump() const { print(llvm::errs()); }
52 
53 //===--------------------------------------------------------------------===//
54 // AsmParser
55 //===--------------------------------------------------------------------===//
56 
// Out-of-line defaulted destructors for the parser interfaces.
// NOTE(review): presumably these are virtual and anchor each class's vtable in
// this file; confirm against the class declarations.
AsmParser::~AsmParser() = default;
DialectAsmParser::~DialectAsmParser() = default;
OpAsmParser::~OpAsmParser() = default;
60 
61 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
62 
63 //===----------------------------------------------------------------------===//
64 // DialectAsmPrinter
65 //===----------------------------------------------------------------------===//
66 
// Out-of-line defaulted destructor. NOTE(review): presumably virtual and
// anchoring the class's vtable in this file; confirm against the declaration.
DialectAsmPrinter::~DialectAsmPrinter() = default;
68 
69 //===----------------------------------------------------------------------===//
70 // OpAsmPrinter
71 //===----------------------------------------------------------------------===//
72 
// Out-of-line defaulted destructor. NOTE(review): presumably virtual and
// anchoring the class's vtable in this file; confirm against the declaration.
OpAsmPrinter::~OpAsmPrinter() = default;
74 
75 void OpAsmPrinter::printFunctionalType(Operation *op) {
76   auto &os = getStream();
77   os << '(';
78   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
79     // Print the types of null values as <<NULL TYPE>>.
80     *this << (operand ? operand.getType() : Type());
81   });
82   os << ") -> ";
83 
84   // Print the result list.  We don't parenthesize single result types unless
85   // it is a function (avoiding a grammar ambiguity).
86   bool wrapped = op->getNumResults() != 1;
87   if (!wrapped && op->getResult(0).getType() &&
88       op->getResult(0).getType().isa<FunctionType>())
89     wrapped = true;
90 
91   if (wrapped)
92     os << '(';
93 
94   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
95     // Print the types of null values as <<NULL TYPE>>.
96     *this << (result ? result.getType() : Type());
97   });
98 
99   if (wrapped)
100     os << ')';
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // Operation OpAsm interface.
105 //===----------------------------------------------------------------------===//
106 
107 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
108 #include "mlir/IR/OpAsmInterface.cpp.inc"
109 
110 //===----------------------------------------------------------------------===//
111 // OpPrintingFlags
112 //===----------------------------------------------------------------------===//
113 
114 namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. Grouping them in a struct, which is
/// constructed lazily through the `clOptions` ManagedStatic, avoids the need
/// for global command line options.
struct AsmPrinterOptions {
  // Print DenseElementsAttrs as a hex string once their element count exceeds
  // this limit; -1 disables hex printing. Only consulted when the option is
  // explicitly given (see shouldPrintElementsAttrWithHex).
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  // When explicitly given, becomes the default elements-attr elision limit of
  // OpPrintingFlags.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  // Default for OpPrintingFlags::shouldPrintDebugInfo().
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  // Default for OpPrintingFlags::shouldPrintDebugInfoPrettyForm().
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  // Use the generic op output form in the operation printer even if the custom
  // form is defined. Default for OpPrintingFlags::shouldPrintGenericOpForm().
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  // Default for OpPrintingFlags::shouldUseLocalScope().
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print assuming in local scope by default"),
      llvm::cl::Hidden};
};
149 } // namespace
150 
151 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
152 
153 /// Register a set of useful command-line options that can be used to configure
154 /// various flags within the AsmPrinter.
155 void mlir::registerAsmPrinterCLOptions() {
156   // Make sure that the options struct has been initialized.
157   *clOptions;
158 }
159 
160 /// Initialize the printing flags with default supplied by the cl::opts above.
161 OpPrintingFlags::OpPrintingFlags()
162     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
163       printGenericOpFormFlag(false), printLocalScope(false) {
164   // Initialize based upon command line options, if they are available.
165   if (!clOptions.isConstructed())
166     return;
167   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
168     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
169   printDebugInfoFlag = clOptions->printDebugInfoOpt;
170   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
171   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
172   printLocalScope = clOptions->printLocalScopeOpt;
173 }
174 
175 /// Enable the elision of large elements attributes, by printing a '...'
176 /// instead of the element data, when the number of elements is greater than
177 /// `largeElementLimit`. Note: The IR generated with this option is not
178 /// parsable.
179 OpPrintingFlags &
180 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
181   elementsAttrElementLimit = largeElementLimit;
182   return *this;
183 }
184 
185 /// Enable printing of debug information. If 'prettyForm' is set to true,
186 /// debug information is printed in a more readable 'pretty' form.
187 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
188   printDebugInfoFlag = true;
189   printDebugInfoPrettyFormFlag = prettyForm;
190   return *this;
191 }
192 
193 /// Always print operations in the generic form.
194 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
195   printGenericOpFormFlag = true;
196   return *this;
197 }
198 
199 /// Use local scope when printing the operation. This allows for using the
200 /// printer in a more localized and thread-safe setting, but may not necessarily
201 /// be identical of what the IR will look like when dumping the full module.
202 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
203   printLocalScope = true;
204   return *this;
205 }
206 
207 /// Return if the given ElementsAttr should be elided.
208 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
209   return elementsAttrElementLimit.hasValue() &&
210          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
211          !attr.isa<SplatElementsAttr>();
212 }
213 
214 /// Return the size limit for printing large ElementsAttr.
215 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
216   return elementsAttrElementLimit;
217 }
218 
219 /// Return if debug information should be printed.
220 bool OpPrintingFlags::shouldPrintDebugInfo() const {
221   return printDebugInfoFlag;
222 }
223 
224 /// Return if debug information should be printed in the pretty form.
225 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
226   return printDebugInfoPrettyFormFlag;
227 }
228 
229 /// Return if operations should be printed in the generic form.
230 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
231   return printGenericOpFormFlag;
232 }
233 
234 /// Return if the printer should use local scope when dumping the IR.
235 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
236 
237 /// Returns true if an ElementsAttr with the given number of elements should be
238 /// printed with hex.
239 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
240   // Check to see if a command line option was provided for the limit.
241   if (clOptions.isConstructed()) {
242     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
243       // -1 is used to disable hex printing.
244       if (clOptions->printElementsAttrWithHexIfLarger == -1)
245         return false;
246       return numElements > clOptions->printElementsAttrWithHexIfLarger;
247     }
248   }
249 
250   // Otherwise, default to printing with hex if the number of elements is >100.
251   return numElements > 100;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // NewLineCounter
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class is a simple formatter that emits a new line when inputted into a
260 /// stream, that enables counting the number of newlines emitted. This class
261 /// should be used whenever emitting newlines in the printer.
262 struct NewLineCounter {
263   unsigned curLine = 1;
264 };
265 
266 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
267   ++newLine.curLine;
268   return os << '\n';
269 }
270 } // namespace
271 
272 //===----------------------------------------------------------------------===//
273 // AliasInitializer
274 //===----------------------------------------------------------------------===//
275 
276 namespace {
277 /// This class represents a specific instance of a symbol Alias.
278 class SymbolAlias {
279 public:
280   SymbolAlias(StringRef name, bool isDeferrable)
281       : name(name), suffixIndex(0), hasSuffixIndex(false),
282         isDeferrable(isDeferrable) {}
283   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
284       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
285         isDeferrable(isDeferrable) {}
286 
287   /// Print this alias to the given stream.
288   void print(raw_ostream &os) const {
289     os << name;
290     if (hasSuffixIndex)
291       os << suffixIndex;
292   }
293 
294   /// Returns true if this alias supports deferred resolution when parsing.
295   bool canBeDeferred() const { return isDeferrable; }
296 
297 private:
298   /// The main name of the alias.
299   StringRef name;
300   /// The optional suffix index of the alias, if multiple aliases had the same
301   /// name.
302   uint32_t suffixIndex : 30;
303   /// A flag indicating whether this alias has a suffix or not.
304   bool hasSuffixIndex : 1;
305   /// A flag indicating whether this alias may be deferred or not.
306   bool isDeferrable : 1;
307 };
308 
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  // Generated alias names are copied into `aliasAllocator`, which must
  // therefore outlive the alias maps that reference them.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` the way a printer configured with `printerFlags` would. Then
  /// fill `attrToAlias` and `typeToAlias` with the aliases of the attributes
  /// and types encountered, ordered by alias name.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol by querying the dialect
  /// interfaces. If an alias is generated, the symbol is appended to the list
  /// of symbols sharing that alias name in `aliasToSymbol`.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias name and the symbols mapped to it, in the order
  /// they were visited (this order determines their numeric suffixes).
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias. `aliasOS` writes into
  /// `aliasBuffer`, so `aliasBuffer` must stay declared (and thus constructed)
  /// before `aliasOS`.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
364 
365 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
366 /// and merely collects the attributes and types that *would* be printed in a
367 /// normal print invocation so that we can generate proper aliases. This allows
368 /// for us to generate aliases only for the attributes and types that would be
369 /// in the output, and trims down unnecessary output.
370 class DummyAliasOperationPrinter : private OpAsmPrinter {
371 public:
372   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
373                                       AliasInitializer &initializer)
374       : printerFlags(printerFlags), initializer(initializer) {}
375 
376   /// Print the given operation.
377   void print(Operation *op) {
378     // Visit the operation location.
379     if (printerFlags.shouldPrintDebugInfo())
380       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
381 
382     // If requested, always print the generic form.
383     if (!printerFlags.shouldPrintGenericOpForm()) {
384       // Check to see if this is a known operation.  If so, use the registered
385       // custom printer hook.
386       if (auto opInfo = op->getRegisteredInfo()) {
387         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
388         return;
389       }
390     }
391 
392     // Otherwise print with the generic assembly form.
393     printGenericOp(op);
394   }
395 
396 private:
397   /// Print the given operation in the generic form.
398   void printGenericOp(Operation *op, bool printOpName = true) override {
399     // Consider nested operations for aliases.
400     if (op->getNumRegions() != 0) {
401       for (Region &region : op->getRegions())
402         printRegion(region, /*printEntryBlockArgs=*/true,
403                     /*printBlockTerminators=*/true);
404     }
405 
406     // Visit all the types used in the operation.
407     for (Type type : op->getOperandTypes())
408       printType(type);
409     for (Type type : op->getResultTypes())
410       printType(type);
411 
412     // Consider the attributes of the operation for aliases.
413     for (const NamedAttribute &attr : op->getAttrs())
414       printAttribute(attr.getValue());
415   }
416 
417   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
418   /// block are not printed. If 'printBlockTerminator' is false, the terminator
419   /// operation of the block is not printed.
420   void print(Block *block, bool printBlockArgs = true,
421              bool printBlockTerminator = true) {
422     // Consider the types of the block arguments for aliases if 'printBlockArgs'
423     // is set to true.
424     if (printBlockArgs) {
425       for (BlockArgument arg : block->getArguments()) {
426         printType(arg.getType());
427 
428         // Visit the argument location.
429         if (printerFlags.shouldPrintDebugInfo())
430           // TODO: Allow deferring argument locations.
431           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
432       }
433     }
434 
435     // Consider the operations within this block, ignoring the terminator if
436     // requested.
437     bool hasTerminator =
438         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
439     auto range = llvm::make_range(
440         block->begin(),
441         std::prev(block->end(),
442                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
443     for (Operation &op : range)
444       print(&op);
445   }
446 
447   /// Print the given region.
448   void printRegion(Region &region, bool printEntryBlockArgs,
449                    bool printBlockTerminators,
450                    bool printEmptyBlock = false) override {
451     if (region.empty())
452       return;
453 
454     auto *entryBlock = &region.front();
455     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
456     for (Block &b : llvm::drop_begin(region, 1))
457       print(&b);
458   }
459 
460   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
461                            bool omitType) override {
462     printType(arg.getType());
463     // Visit the argument location.
464     if (printerFlags.shouldPrintDebugInfo())
465       // TODO: Allow deferring argument locations.
466       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
467   }
468 
469   /// Consider the given type to be printed for an alias.
470   void printType(Type type) override { initializer.visit(type); }
471 
472   /// Consider the given attribute to be printed for an alias.
473   void printAttribute(Attribute attr) override { initializer.visit(attr); }
474   void printAttributeWithoutType(Attribute attr) override {
475     printAttribute(attr);
476   }
477   LogicalResult printAlias(Attribute attr) override {
478     initializer.visit(attr);
479     return success();
480   }
481   LogicalResult printAlias(Type type) override {
482     initializer.visit(type);
483     return success();
484   }
485 
486   /// Print the given set of attributes with names not included within
487   /// 'elidedAttrs'.
488   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
489                              ArrayRef<StringRef> elidedAttrs = {}) override {
490     if (attrs.empty())
491       return;
492     if (elidedAttrs.empty()) {
493       for (const NamedAttribute &attr : attrs)
494         printAttribute(attr.getValue());
495       return;
496     }
497     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
498                                                   elidedAttrs.end());
499     for (const NamedAttribute &attr : attrs)
500       if (!elidedAttrsSet.contains(attr.getName().strref()))
501         printAttribute(attr.getValue());
502   }
503   void printOptionalAttrDictWithKeyword(
504       ArrayRef<NamedAttribute> attrs,
505       ArrayRef<StringRef> elidedAttrs = {}) override {
506     printOptionalAttrDict(attrs, elidedAttrs);
507   }
508 
509   /// Return a null stream as the output stream, this will ignore any data fed
510   /// to it.
511   raw_ostream &getStream() const override { return os; }
512 
513   /// The following are hooks of `OpAsmPrinter` that are not necessary for
514   /// determining potential aliases.
515   void printFloat(const APFloat &value) override {}
516   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
517   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
518   void printNewline() override {}
519   void printOperand(Value) override {}
520   void printOperand(Value, raw_ostream &os) override {
521     // Users expect the output string to have at least the prefixed % to signal
522     // a value name. To maintain this invariant, emit a name even if it is
523     // guaranteed to go unused.
524     os << "%";
525   }
526   void printKeywordOrString(StringRef) override {}
527   void printSymbolName(StringRef) override {}
528   void printSuccessor(Block *) override {}
529   void printSuccessorAndUseList(Block *, ValueRange) override {}
530   void shadowRegionArgs(Region &, ValueRange) override {}
531 
532   /// The printer flags to use when determining potential aliases.
533   const OpPrintingFlags &printerFlags;
534 
535   /// The initializer to use when identifying aliases.
536   AliasInitializer &initializer;
537 
538   /// A dummy output stream.
539   mutable llvm::raw_null_ostream os;
540 };
541 } // namespace
542 
543 /// Sanitize the given name such that it can be used as a valid identifier. If
544 /// the string needs to be modified in any way, the provided buffer is used to
545 /// store the new copy,
546 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
547                                     StringRef allowedPunctChars = "$._-",
548                                     bool allowTrailingDigit = true) {
549   assert(!name.empty() && "Shouldn't have an empty name here");
550 
551   auto copyNameToBuffer = [&] {
552     for (char ch : name) {
553       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
554         buffer.push_back(ch);
555       else if (ch == ' ')
556         buffer.push_back('_');
557       else
558         buffer.append(llvm::utohexstr((unsigned char)ch));
559     }
560   };
561 
562   // Check to see if this name is valid. If it starts with a digit, then it
563   // could conflict with the autogenerated numeric ID's, so add an underscore
564   // prefix to avoid problems.
565   if (isdigit(name[0])) {
566     buffer.push_back('_');
567     copyNameToBuffer();
568     return buffer;
569   }
570 
571   // If the name ends with a trailing digit, add a '_' to avoid potential
572   // conflicts with autogenerated ID's.
573   if (!allowTrailingDigit && isdigit(name.back())) {
574     copyNameToBuffer();
575     buffer.push_back('_');
576     return buffer;
577   }
578 
579   // Check to see that the name consists of only valid identifier characters.
580   for (char ch : name) {
581     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
582       copyNameToBuffer();
583       return buffer;
584     }
585   }
586 
587   // If there are no invalid characters, return the original name.
588   return name;
589 }
590 
591 /// Given a collection of aliases and symbols, initialize a mapping from a
592 /// symbol to a given alias.
593 template <typename T>
594 static void
595 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
596                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
597                   DenseSet<T> *deferrableAliases = nullptr) {
598   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
599       aliasToSymbol.takeVector();
600   llvm::array_pod_sort(aliases.begin(), aliases.end(),
601                        [](const auto *lhs, const auto *rhs) {
602                          return lhs->first.compare(rhs->first);
603                        });
604 
605   for (auto &it : aliases) {
606     // If there is only one instance for this alias, use the name directly.
607     if (it.second.size() == 1) {
608       T symbol = it.second.front();
609       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
610       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
611       continue;
612     }
613     // Otherwise, add the index to the name.
614     for (int i = 0, e = it.second.size(); i < e; ++i) {
615       T symbol = it.second[i];
616       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
617       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
618     }
619   }
620 }
621 
622 void AliasInitializer::initialize(
623     Operation *op, const OpPrintingFlags &printerFlags,
624     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
625     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
626   // Use a dummy printer when walking the IR so that we can collect the
627   // attributes/types that will actually be used during printing when
628   // considering aliases.
629   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
630   aliasPrinter.print(op);
631 
632   // Initialize the aliases sorted by name.
633   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
634   initializeAliases(aliasToType, typeToAlias);
635 }
636 
637 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
638   if (!visitedAttributes.insert(attr).second) {
639     // If this attribute already has an alias and this instance can't be
640     // deferred, make sure that the alias isn't deferred.
641     if (!canBeDeferred)
642       deferrableAttributes.erase(attr);
643     return;
644   }
645 
646   // Try to generate an alias for this attribute.
647   if (succeeded(generateAlias(attr, aliasToAttr))) {
648     if (canBeDeferred)
649       deferrableAttributes.insert(attr);
650     return;
651   }
652 
653   // Check for any sub elements.
654   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
655     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
656                                         [&](Type type) { visit(type); });
657   }
658 }
659 
660 void AliasInitializer::visit(Type type) {
661   if (!visitedTypes.insert(type).second)
662     return;
663 
664   // Try to generate an alias for this type.
665   if (succeeded(generateAlias(type, aliasToType)))
666     return;
667 
668   // Check for any sub elements.
669   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
670     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
671                                         [&](Type type) { visit(type); });
672   }
673 }
674 
675 template <typename T>
676 LogicalResult AliasInitializer::generateAlias(
677     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
678   SmallString<32> nameBuffer;
679   for (const auto &interface : interfaces) {
680     OpAsmDialectInterface::AliasResult result =
681         interface.getAlias(symbol, aliasOS);
682     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
683       continue;
684     nameBuffer = std::move(aliasBuffer);
685     assert(!nameBuffer.empty() && "expected valid alias name");
686     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
687       break;
688   }
689 
690   if (nameBuffer.empty())
691     return failure();
692 
693   SmallString<16> tempBuffer;
694   StringRef name =
695       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
696                          /*allowTrailingDigit=*/false);
697   name = name.copy(aliasAllocator);
698   aliasToSymbol[name].push_back(symbol);
699   return success();
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // AliasState
704 //===----------------------------------------------------------------------===//
705 
706 namespace {
707 /// This class manages the state for type and attribute aliases.
708 class AliasState {
709 public:
710   // Initialize the internal aliases.
711   void
712   initialize(Operation *op, const OpPrintingFlags &printerFlags,
713              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
714 
715   /// Get an alias for the given attribute if it has one and print it in `os`.
716   /// Returns success if an alias was printed, failure otherwise.
717   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
718 
719   /// Get an alias for the given type if it has one and print it in `os`.
720   /// Returns success if an alias was printed, failure otherwise.
721   LogicalResult getAlias(Type ty, raw_ostream &os) const;
722 
723   /// Print all of the referenced aliases that can not be resolved in a deferred
724   /// manner.
725   void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
726     printAliases(os, newLine, /*isDeferred=*/false);
727   }
728 
729   /// Print all of the referenced aliases that support deferred resolution.
730   void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
731     printAliases(os, newLine, /*isDeferred=*/true);
732   }
733 
734 private:
735   /// Print all of the referenced aliases that support the provided resolution
736   /// behavior.
737   void printAliases(raw_ostream &os, NewLineCounter &newLine,
738                     bool isDeferred) const;
739 
740   /// Mapping between attribute and alias.
741   llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
742   /// Mapping between type and alias.
743   llvm::MapVector<Type, SymbolAlias> typeToAlias;
744 
745   /// An allocator used for alias names.
746   llvm::BumpPtrAllocator aliasAllocator;
747 };
748 } // namespace
749 
750 void AliasState::initialize(
751     Operation *op, const OpPrintingFlags &printerFlags,
752     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
753   AliasInitializer initializer(interfaces, aliasAllocator);
754   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
755 }
756 
757 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
758   auto it = attrToAlias.find(attr);
759   if (it == attrToAlias.end())
760     return failure();
761   it->second.print(os << '#');
762   return success();
763 }
764 
765 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
766   auto it = typeToAlias.find(ty);
767   if (it == typeToAlias.end())
768     return failure();
769 
770   it->second.print(os << '!');
771   return success();
772 }
773 
774 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
775                               bool isDeferred) const {
776   auto filterFn = [=](const auto &aliasIt) {
777     return aliasIt.second.canBeDeferred() == isDeferred;
778   };
779   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
780     it.second.print(os << '#');
781     os << " = " << it.first << newLine;
782   }
783   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
784     it.second.print(os << '!');
785     os << " = type " << it.first << newLine;
786   }
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // SSANameState
791 //===----------------------------------------------------------------------===//
792 
793 namespace {
794 /// This class manages the state of SSA value names.
class SSANameState {
public:
  /// A sentinel value used for values with names set. A value whose ID is
  /// NameSentinel prints using its entry in `valueNames` rather than a number.
  /// getBlockID also returns it for blocks that were never numbered.
  enum : unsigned { NameSentinel = ~0U };

  /// Eagerly number (or name) every SSA value and block nested under `op`.
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags,
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the ID for the given block, or NameSentinel if it was not numbered.
  unsigned getBlockID(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name falls back to
  /// the next sequential number.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  /// The (uniqued) names of values whose ID is NameSentinel.
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This is the block ID for each block; IDs restart at 0 in every region.
  DenseMap<Block *, unsigned> blockIDs;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Backing storage for the name strings referenced by `usedNames` and
  /// `valueNames`.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;

  /// The OpAsm dialect interfaces consulted for result names; owned by the
  /// caller and must outlive this object.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
};
872 } // namespace
873 
/// Number every SSA value and block nested under `op`.
///
/// Regions are processed with an explicit worklist, not recursion. Each
/// worklist entry snapshots the naming counters and the `usedNames` scope of
/// its parent region. Sibling regions therefore start from the same counters
/// and see only the names visible from their parent.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces)
    : printerFlags(printerFlags), interfaces(interfaces) {
  // The counters are mutated freely while walking the worklist below. These
  // guards restore their entry values when the constructor returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are destroyed manually below; their
  // memory is released when this allocator goes out of scope.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Note: the regions of `op` capture the counters *before* the results of
  // `op` itself are numbered by the call that follows.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the no-longer-needed
    // scopes until the parent scope is current again. Destroying a scope makes
    // its predecessor the table's current scope, which is what advances this
    // loop.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue nested regions with the counters/scope as they stand after this
    // region was numbered. (The inner `region` shadows the outer one.)
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
934 
935 void SSANameState::printValueID(Value value, bool printResultNo,
936                                 raw_ostream &stream) const {
937   if (!value) {
938     stream << "<<NULL VALUE>>";
939     return;
940   }
941 
942   Optional<int> resultNo;
943   auto lookupValue = value;
944 
945   // If this is an operation result, collect the head lookup value of the result
946   // group and the result number of 'result' within that group.
947   if (OpResult result = value.dyn_cast<OpResult>())
948     getResultIDAndNumber(result, lookupValue, resultNo);
949 
950   auto it = valueIDs.find(lookupValue);
951   if (it == valueIDs.end()) {
952     stream << "<<UNKNOWN SSA VALUE>>";
953     return;
954   }
955 
956   stream << '%';
957   if (it->second != NameSentinel) {
958     stream << it->second;
959   } else {
960     auto nameIt = valueNames.find(lookupValue);
961     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
962     stream << nameIt->second;
963   }
964 
965   if (resultNo.hasValue() && printResultNo)
966     stream << '#' << resultNo;
967 }
968 
969 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
970   auto it = opResultGroups.find(op);
971   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
972 }
973 
974 unsigned SSANameState::getBlockID(Block *block) {
975   auto it = blockIDs.find(block);
976   return it != blockIDs.end() ? it->second : NameSentinel;
977 }
978 
979 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
980   assert(!region.empty() && "cannot shadow arguments of an empty region");
981   assert(region.getNumArguments() == namesToUse.size() &&
982          "incorrect number of names passed in");
983   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
984          "only KnownIsolatedFromAbove ops can shadow names");
985 
986   SmallVector<char, 16> nameStr;
987   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
988     auto nameToUse = namesToUse[i];
989     if (nameToUse == nullptr)
990       continue;
991     auto nameToReplace = region.getArgument(i);
992 
993     nameStr.clear();
994     llvm::raw_svector_ostream nameStream(nameStr);
995     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
996 
997     // Entry block arguments should already have a pretty "arg" name.
998     assert(valueIDs[nameToReplace] == NameSentinel);
999 
1000     // Use the name without the leading %.
1001     auto name = StringRef(nameStream.str()).drop_front();
1002 
1003     // Overwrite the name.
1004     valueNames[nameToReplace] = name.copy(usedNameAllocator);
1005   }
1006 }
1007 
1008 void SSANameState::numberValuesInRegion(Region &region) {
1009   auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1010     assert(!valueIDs.count(arg) && "arg numbered multiple times");
1011     assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
1012            "arg not defined in current region");
1013     setValueName(arg, name);
1014   };
1015 
1016   if (!printerFlags.shouldPrintGenericOpForm()) {
1017     if (Operation *op = region.getParentOp()) {
1018       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1019         asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1020     }
1021   }
1022 
1023   // Number the values within this region in a breadth-first order.
1024   unsigned nextBlockID = 0;
1025   for (auto &block : region) {
1026     // Each block gets a unique ID, and all of the operations within it get
1027     // numbered as well.
1028     blockIDs[&block] = nextBlockID++;
1029     numberValuesInBlock(block);
1030   }
1031 }
1032 
1033 void SSANameState::numberValuesInBlock(Block &block) {
1034   // Number the block arguments. We give entry block arguments a special name
1035   // 'arg'.
1036   bool isEntryBlock = block.isEntryBlock();
1037   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1038   llvm::raw_svector_ostream specialName(specialNameBuffer);
1039   for (auto arg : block.getArguments()) {
1040     if (valueIDs.count(arg))
1041       continue;
1042     if (isEntryBlock) {
1043       specialNameBuffer.resize(strlen("arg"));
1044       specialName << nextArgumentID++;
1045     }
1046     setValueName(arg, specialName.str());
1047   }
1048 
1049   // Number the operations in this block.
1050   for (auto &op : block)
1051     numberValuesInOp(op);
1052 }
1053 
1054 void SSANameState::numberValuesInOp(Operation &op) {
1055   unsigned numResults = op.getNumResults();
1056   if (numResults == 0)
1057     return;
1058   Value resultBegin = op.getResult(0);
1059 
1060   // Function used to set the special result names for the operation.
1061   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1062   auto setResultNameFn = [&](Value result, StringRef name) {
1063     assert(!valueIDs.count(result) && "result numbered multiple times");
1064     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1065     setValueName(result, name);
1066 
1067     // Record the result number for groups not anchored at 0.
1068     if (int resultNo = result.cast<OpResult>().getResultNumber())
1069       resultGroups.push_back(resultNo);
1070   };
1071   if (!printerFlags.shouldPrintGenericOpForm()) {
1072     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
1073       asmInterface.getAsmResultNames(setResultNameFn);
1074     else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
1075       asmInterface->getAsmResultNames(&op, setResultNameFn);
1076   }
1077 
1078   // If the first result wasn't numbered, give it a default number.
1079   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1080     ++nextValueID;
1081 
1082   // If this operation has multiple result groups, mark it.
1083   if (resultGroups.size() != 1) {
1084     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1085     opResultGroups.try_emplace(&op, std::move(resultGroups));
1086   }
1087 }
1088 
/// Map `result` to the head value of its result group (`lookupValue`) and,
/// for groups with more than one member, its offset inside that group
/// (`lookupResultNo`). Both outputs are left untouched for single-result ops.
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
                                        Optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // If this operation has multiple result groups, we will need to find the
  // one corresponding to this result.
  auto resultGroupIt = opResultGroups.find(owner);
  if (resultGroupIt == opResultGroups.end()) {
    // If not, just use the first result.
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // Find the correct index using a binary search, as the groups are ordered.
  // `it` is the first group start strictly greater than `resultNo`. Because
  // the first group always starts at 0 (see numberValuesInOp), `it` is never
  // `begin()`, so `std::prev(it)` below is valid.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  const auto *it = llvm::upper_bound(resultGroups, resultNo);
  int groupResultNo = 0, groupSize = 0;

  // If no group starts after `resultNo`, the result belongs to the last group,
  // which extends to the op's final result.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    // Otherwise, the previous element is the start of the enclosing group.
    groupResultNo = *std::prev(it);
    groupSize = *it - groupResultNo;
  }

  // We only record the result number for a group of size greater than 1.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(groupResultNo);
}
1126 
1127 void SSANameState::setValueName(Value value, StringRef name) {
1128   // If the name is empty, the value uses the default numbering.
1129   if (name.empty()) {
1130     valueIDs[value] = nextValueID++;
1131     return;
1132   }
1133 
1134   valueIDs[value] = NameSentinel;
1135   valueNames[value] = uniqueValueName(name);
1136 }
1137 
1138 StringRef SSANameState::uniqueValueName(StringRef name) {
1139   SmallString<16> tmpBuffer;
1140   name = sanitizeIdentifier(name, tmpBuffer);
1141 
1142   // Check to see if this name is already unique.
1143   if (!usedNames.count(name)) {
1144     name = name.copy(usedNameAllocator);
1145   } else {
1146     // Otherwise, we had a conflict - probe until we find a unique name. This
1147     // is guaranteed to terminate (and usually in a single iteration) because it
1148     // generates new names by incrementing nextConflictID.
1149     SmallString<64> probeName(name);
1150     probeName.push_back('_');
1151     while (true) {
1152       probeName += llvm::utostr(nextConflictID++);
1153       if (!usedNames.count(probeName)) {
1154         name = probeName.str().copy(usedNameAllocator);
1155         break;
1156       }
1157       probeName.resize(name.size() + 1);
1158     }
1159   }
1160 
1161   usedNames.insert(name, char());
1162   return name;
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // AsmState
1167 //===----------------------------------------------------------------------===//
1168 
namespace mlir {
namespace detail {
/// Shared printer state: the OpAsm dialect interfaces, alias state, eagerly
/// computed SSA names, printing flags, and an optional location map.
class AsmStateImpl {
public:
  /// Note on initialization: members are initialized in declaration order, so
  /// `interfaces` is ready before `nameState` uses it. Inside the
  /// mem-initializer of `nameState`, `printerFlags` names the constructor
  /// parameter, not the member (which is declared, and initialized, later).
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags, interfaces),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. Does nothing when no location map was given.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  /// Must stay declared before `nameState` (see the constructor note).
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated (may be null).
  AsmState::LocationMap *locationMap;
};
} // namespace detail
} // namespace mlir
1220 
/// Build the printing state for `op`. All state lives in AsmStateImpl, whose
/// constructor computes the SSA names for `op` immediately.
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap)
    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
// Defined here, where AsmStateImpl is a complete type (presumably so the
// unique_ptr<AsmStateImpl> destructor can be instantiated; confirm against
// the header).
AsmState::~AsmState() = default;
1225 
1226 //===----------------------------------------------------------------------===//
1227 // AsmPrinter::Impl
1228 //===----------------------------------------------------------------------===//
1229 
namespace mlir {
/// Core printing logic for attributes, types, locations, and affine
/// constructs. Writes to `os`. If a non-null AsmStateImpl is attached, it is
/// used for aliases.
class AsmPrinter::Impl {
public:
  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
       AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing `other`'s stream, flags and state. The
  /// new-line counter is not copied; it starts fresh.
  explicit Impl(Impl &other)
      : Impl(other.os, other.printerFlags, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of `c` separated by ", ", invoking `eachFn` on each.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type.
  void printType(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  /// Print " <location>" after an operation, only if debug info is enabled.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print `loc` without the surrounding `loc(...)`; `pretty` selects the
  /// human-readable form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense int-or-float elements attribute. If 'allowHex' is true, a
  /// hex string is used instead of individual elements when the elements attr
  /// is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module (may be null).
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // namespace mlir
1331 
1332 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1333   // Check to see if we are printing debug information.
1334   if (!printerFlags.shouldPrintDebugInfo())
1335     return;
1336 
1337   os << " ";
1338   printLocation(loc, /*allowAlias=*/allowAlias);
1339 }
1340 
1341 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1342   TypeSwitch<LocationAttr>(loc)
1343       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1344         printLocationInternal(loc.getFallbackLocation(), pretty);
1345       })
1346       .Case<UnknownLoc>([&](UnknownLoc loc) {
1347         if (pretty)
1348           os << "[unknown]";
1349         else
1350           os << "unknown";
1351       })
1352       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1353         if (pretty) {
1354           os << loc.getFilename().getValue();
1355         } else {
1356           os << "\"";
1357           printEscapedString(loc.getFilename(), os);
1358           os << "\"";
1359         }
1360         os << ':' << loc.getLine() << ':' << loc.getColumn();
1361       })
1362       .Case<NameLoc>([&](NameLoc loc) {
1363         os << '\"';
1364         printEscapedString(loc.getName(), os);
1365         os << '\"';
1366 
1367         // Print the child if it isn't unknown.
1368         auto childLoc = loc.getChildLoc();
1369         if (!childLoc.isa<UnknownLoc>()) {
1370           os << '(';
1371           printLocationInternal(childLoc, pretty);
1372           os << ')';
1373         }
1374       })
1375       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1376         Location caller = loc.getCaller();
1377         Location callee = loc.getCallee();
1378         if (!pretty)
1379           os << "callsite(";
1380         printLocationInternal(callee, pretty);
1381         if (pretty) {
1382           if (callee.isa<NameLoc>()) {
1383             if (caller.isa<FileLineColLoc>()) {
1384               os << " at ";
1385             } else {
1386               os << newLine << " at ";
1387             }
1388           } else {
1389             os << newLine << " at ";
1390           }
1391         } else {
1392           os << " at ";
1393         }
1394         printLocationInternal(caller, pretty);
1395         if (!pretty)
1396           os << ")";
1397       })
1398       .Case<FusedLoc>([&](FusedLoc loc) {
1399         if (!pretty)
1400           os << "fused";
1401         if (Attribute metadata = loc.getMetadata())
1402           os << '<' << metadata << '>';
1403         os << '[';
1404         interleave(
1405             loc.getLocations(),
1406             [&](Location loc) { printLocationInternal(loc, pretty); },
1407             [&]() { os << ", "; });
1408         os << ']';
1409       });
1410 }
1411 
1412 /// Print a floating point value in a way that the parser will be able to
1413 /// round-trip losslessly.
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
  // Output strategy, in order of preference:
  //   1. A short form (6 significant digits) if it parses back bit-exactly.
  //   2. APFloat's default string form if it contains a '.'.
  //   3. The raw bit pattern as a hexadecimal literal (always used for
  //      infinities and NaNs).
  // We would like to output the FP constant value in exponential notation,
  // but we cannot do this if doing so will lose precision.  Check here to
  // make sure that we only output it in exponential format if we can parse
  // the value back and get the same value.
  bool isInf = apValue.isInfinity();
  bool isNaN = apValue.isNaN();
  if (!isInf && !isNaN) {
    SmallString<128> strValue;
    apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
                     /*TruncateZero=*/false);

    // Check to make sure that the stringized number is not some string like
    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
    // that the string matches the "[-+]?[0-9]" regex.
    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
            ((strValue[0] == '-' || strValue[0] == '+') &&
             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
           "[-+]?[0-9] regex does not match!");

    // Parse back the stringized version and check that the value is equal
    // (i.e., there is no precision loss).
    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
      os << strValue;
      return;
    }

    // If it is not, use the default format of APFloat instead of the
    // exponential notation.
    strValue.clear();
    apValue.toString(strValue);

    // Make sure that we can parse the default form as a float: without a '.'
    // it would not be read back as a float, so fall through to hex.
    if (strValue.str().contains('.')) {
      os << strValue;
      return;
    }
  }

  // Print special values (inf/NaN), and finite values whose default form has
  // no '.', as a hexadecimal bit pattern. The sign bit is part of the pattern,
  // so it is included in the literal.
  SmallVector<char, 16> str;
  APInt apInt = apValue.bitcastToAPInt();
  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
                 /*formatAsCLiteral=*/true);
  os << str;
}
1461 
1462 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1463   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1464     return printLocationInternal(loc, /*pretty=*/true);
1465 
1466   os << "loc(";
1467   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1468     printLocationInternal(loc);
1469   os << ')';
1470 }
1471 
1472 /// Returns true if the given dialect symbol data is simple enough to print in
1473 /// the pretty form, i.e. without the enclosing "".
1474 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1475   // The name must start with an identifier.
1476   if (symName.empty() || !isalpha(symName.front()))
1477     return false;
1478 
1479   // Ignore all the characters that are valid in an identifier in the symbol
1480   // name.
1481   symName = symName.drop_while(
1482       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1483   if (symName.empty())
1484     return true;
1485 
1486   // If we got to an unexpected character, then it must be a <>.  Check those
1487   // recursively.
1488   if (symName.front() != '<' || symName.back() != '>')
1489     return false;
1490 
1491   SmallVector<char, 8> nestedPunctuation;
1492   do {
1493     // If we ran out of characters, then we had a punctuation mismatch.
1494     if (symName.empty())
1495       return false;
1496 
1497     auto c = symName.front();
1498     symName = symName.drop_front();
1499 
1500     switch (c) {
1501     // We never allow null characters. This is an EOF indicator for the lexer
1502     // which we could handle, but isn't important for any known dialect.
1503     case '\0':
1504       return false;
1505     case '<':
1506     case '[':
1507     case '(':
1508     case '{':
1509       nestedPunctuation.push_back(c);
1510       continue;
1511     case '-':
1512       // Treat `->` as a special token.
1513       if (!symName.empty() && symName.front() == '>') {
1514         symName = symName.drop_front();
1515         continue;
1516       }
1517       break;
1518     // Reject types with mismatched brackets.
1519     case '>':
1520       if (nestedPunctuation.pop_back_val() != '<')
1521         return false;
1522       break;
1523     case ']':
1524       if (nestedPunctuation.pop_back_val() != '[')
1525         return false;
1526       break;
1527     case ')':
1528       if (nestedPunctuation.pop_back_val() != '(')
1529         return false;
1530       break;
1531     case '}':
1532       if (nestedPunctuation.pop_back_val() != '{')
1533         return false;
1534       break;
1535     default:
1536       continue;
1537     }
1538 
1539     // We're done when the punctuation is fully matched.
1540   } while (!nestedPunctuation.empty());
1541 
1542   // If there were extra characters, then we failed.
1543   return symName.empty();
1544 }
1545 
1546 /// Print the given dialect symbol to the stream.
1547 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1548                                StringRef dialectName, StringRef symString) {
1549   os << symPrefix << dialectName;
1550 
1551   // If this symbol name is simple enough, print it directly in pretty form,
1552   // otherwise, we print it as an escaped string.
1553   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1554     os << '.' << symString;
1555     return;
1556   }
1557 
1558   os << "<\"";
1559   llvm::printEscapedString(symString, os);
1560   os << "\">";
1561 }
1562 
1563 /// Returns true if the given string can be represented as a bare identifier.
1564 static bool isBareIdentifier(StringRef name) {
1565   // By making this unsigned, the value passed in to isalnum will always be
1566   // in the range 0-255. This is important when building with MSVC because
1567   // its implementation will assert. This situation can arise when dealing
1568   // with UTF-8 multibyte characters.
1569   if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1570     return false;
1571   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1572     return isalnum(c) || c == '_' || c == '$' || c == '.';
1573   });
1574 }
1575 
1576 /// Print the given string as a keyword, or a quoted and escaped string if it
1577 /// has any special or non-printable characters in it.
1578 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1579   // If it can be represented as a bare identifier, write it directly.
1580   if (isBareIdentifier(keyword)) {
1581     os << keyword;
1582     return;
1583   }
1584 
1585   // Otherwise, output the keyword wrapped in quotes with proper escaping.
1586   os << "\"";
1587   printEscapedString(keyword, os);
1588   os << '"';
1589 }
1590 
1591 /// Print the given string as a symbol reference. A symbol reference is
1592 /// represented as a string prefixed with '@'. The reference is surrounded with
1593 /// ""'s and escaped if it has any special or non-printable characters in it.
1594 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1595   assert(!symbolRef.empty() && "expected valid symbol reference");
1596   os << '@';
1597   printKeywordOrString(symbolRef, os);
1598 }
1599 
1600 // Print out a valid ElementsAttr that is succinct and can represent any
1601 // potential shape/type, for use when eliding a large ElementsAttr.
1602 //
1603 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1604 // hopefully alert readers to the fact that this has been elided.
1605 //
1606 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1607 // accept the string "elided". The first string must be a registered dialect
1608 // name and the latter must be a hex constant.
1609 static void printElidedElementsAttr(raw_ostream &os) {
1610   os << R"(opaque<"_", "0xDEADBEEF">)";
1611 }
1612 
1613 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
1614   return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
1615 }
1616 
1617 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
1618   return success(state && succeeded(state->getAliasState().getAlias(type, os)));
1619 }
1620 
1621 void AsmPrinter::Impl::printAttribute(Attribute attr,
1622                                       AttrTypeElision typeElision) {
1623   if (!attr) {
1624     os << "<<NULL ATTRIBUTE>>";
1625     return;
1626   }
1627 
1628   // Try to print an alias for this attribute.
1629   if (succeeded(printAlias(attr)))
1630     return;
1631 
1632   if (!isa<BuiltinDialect>(attr.getDialect()))
1633     return printDialectAttribute(attr);
1634 
1635   auto attrType = attr.getType();
1636   if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1637     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1638                        opaqueAttr.getAttrData());
1639   } else if (attr.isa<UnitAttr>()) {
1640     os << "unit";
1641     return;
1642   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1643     os << '{';
1644     interleaveComma(dictAttr.getValue(),
1645                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1646     os << '}';
1647 
1648   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1649     if (attrType.isSignlessInteger(1)) {
1650       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1651 
1652       // Boolean integer attributes always elides the type.
1653       return;
1654     }
1655 
1656     // Only print attributes as unsigned if they are explicitly unsigned or are
1657     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1658     // values print as signed.
1659     bool isUnsigned =
1660         attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1661     intAttr.getValue().print(os, !isUnsigned);
1662 
1663     // IntegerAttr elides the type if I64.
1664     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1665       return;
1666 
1667   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1668     printFloatValue(floatAttr.getValue(), os);
1669 
1670     // FloatAttr elides the type if F64.
1671     if (typeElision == AttrTypeElision::May && attrType.isF64())
1672       return;
1673 
1674   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1675     os << '"';
1676     printEscapedString(strAttr.getValue(), os);
1677     os << '"';
1678 
1679   } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1680     os << '[';
1681     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1682       printAttribute(attr, AttrTypeElision::May);
1683     });
1684     os << ']';
1685 
1686   } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1687     os << "affine_map<";
1688     affineMapAttr.getValue().print(os);
1689     os << '>';
1690 
1691     // AffineMap always elides the type.
1692     return;
1693 
1694   } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1695     os << "affine_set<";
1696     integerSetAttr.getValue().print(os);
1697     os << '>';
1698 
1699     // IntegerSet always elides the type.
1700     return;
1701 
1702   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1703     printType(typeAttr.getValue());
1704 
1705   } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1706     printSymbolReference(refAttr.getRootReference().getValue(), os);
1707     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1708       os << "::";
1709       printSymbolReference(nestedRef.getValue(), os);
1710     }
1711 
1712   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1713     if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1714       printElidedElementsAttr(os);
1715     } else {
1716       os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
1717          << llvm::toHex(opaqueAttr.getValue()) << "\">";
1718     }
1719 
1720   } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1721     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1722       printElidedElementsAttr(os);
1723     } else {
1724       os << "dense<";
1725       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1726       os << '>';
1727     }
1728 
1729   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1730     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1731       printElidedElementsAttr(os);
1732     } else {
1733       os << "dense<";
1734       printDenseStringElementsAttr(strEltAttr);
1735       os << '>';
1736     }
1737 
1738   } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1739     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1740         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1741       printElidedElementsAttr(os);
1742     } else {
1743       os << "sparse<";
1744       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1745       if (indices.getNumElements() != 0) {
1746         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1747         os << ", ";
1748         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1749       }
1750       os << '>';
1751     }
1752 
1753   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1754     printLocation(locAttr);
1755   }
1756   // Don't print the type if we must elide it, or if it is a None type.
1757   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1758     os << " : ";
1759     printType(attrType);
1760   }
1761 }
1762 
1763 /// Print the integer element of a DenseElementsAttr.
1764 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1765                                  bool isSigned) {
1766   if (value.getBitWidth() == 1)
1767     os << (value.getBoolValue() ? "true" : "false");
1768   else
1769     value.print(os, isSigned);
1770 }
1771 
1772 static void
1773 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1774                            function_ref<void(unsigned)> printEltFn) {
1775   // Special case for 0-d and splat tensors.
1776   if (isSplat)
1777     return printEltFn(0);
1778 
1779   // Special case for degenerate tensors.
1780   auto numElements = type.getNumElements();
1781   if (numElements == 0)
1782     return;
1783 
1784   // We use a mixed-radix counter to iterate through the shape. When we bump a
1785   // non-least-significant digit, we emit a close bracket. When we next emit an
1786   // element we re-open all closed brackets.
1787 
1788   // The mixed-radix counter, with radices in 'shape'.
1789   int64_t rank = type.getRank();
1790   SmallVector<unsigned, 4> counter(rank, 0);
1791   // The number of brackets that have been opened and not closed.
1792   unsigned openBrackets = 0;
1793 
1794   auto shape = type.getShape();
1795   auto bumpCounter = [&] {
1796     // Bump the least significant digit.
1797     ++counter[rank - 1];
1798     // Iterate backwards bubbling back the increment.
1799     for (unsigned i = rank - 1; i > 0; --i)
1800       if (counter[i] >= shape[i]) {
1801         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1802         counter[i] = 0;
1803         ++counter[i - 1];
1804         --openBrackets;
1805         os << ']';
1806       }
1807   };
1808 
1809   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1810     if (idx != 0)
1811       os << ", ";
1812     while (openBrackets++ < rank)
1813       os << '[';
1814     openBrackets = rank;
1815     printEltFn(idx);
1816     bumpCounter();
1817   }
1818   while (openBrackets-- > 0)
1819     os << ']';
1820 }
1821 
1822 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1823                                               bool allowHex) {
1824   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1825     return printDenseStringElementsAttr(stringAttr);
1826 
1827   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1828                                 allowHex);
1829 }
1830 
/// Print the elements of `attr`, without the surrounding `dense<...>`.
///
/// When `allowHex` is set, the attribute is not a splat, and
/// shouldPrintElementsAttrWithHex(numElements) returns true, the raw data
/// buffer is printed as a quoted `"0x..."` hex string in little-endian byte
/// order. Otherwise the elements are printed one by one:
///  - complex elements as `(real,imag)`,
///  - integer/index elements via printDenseIntElement (signed unless the
///    element type is unsigned),
///  - float elements via printFloatValue.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines, so it is converted here and the hex string is always
      // printed in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      os << '"' << "0x"
         << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
         << "\"";
    } else {
      os << '"' << "0x"
         << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced. Keep the two branches separate.
    if (complexElementType.isa<IntegerType>()) {
      // Integer complex parts are printed as `(real,imag)`.
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      // Floating-point complex parts are printed as `(real,imag)`.
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Integers and indexes print as signed unless explicitly unsigned.
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    // The only remaining legal element kind is floating point.
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
1901 
1902 void AsmPrinter::Impl::printDenseStringElementsAttr(
1903     DenseStringElementsAttr attr) {
1904   ArrayRef<StringRef> data = attr.getRawStringData();
1905   auto printFn = [&](unsigned index) {
1906     os << "\"";
1907     printEscapedString(data[index], os);
1908     os << "\"";
1909   };
1910   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1911 }
1912 
1913 void AsmPrinter::Impl::printType(Type type) {
1914   if (!type) {
1915     os << "<<NULL TYPE>>";
1916     return;
1917   }
1918 
1919   // Try to print an alias for this type.
1920   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1921     return;
1922 
1923   TypeSwitch<Type>(type)
1924       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1925         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1926                            opaqueTy.getTypeData());
1927       })
1928       .Case<IndexType>([&](Type) { os << "index"; })
1929       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1930       .Case<Float16Type>([&](Type) { os << "f16"; })
1931       .Case<Float32Type>([&](Type) { os << "f32"; })
1932       .Case<Float64Type>([&](Type) { os << "f64"; })
1933       .Case<Float80Type>([&](Type) { os << "f80"; })
1934       .Case<Float128Type>([&](Type) { os << "f128"; })
1935       .Case<IntegerType>([&](IntegerType integerTy) {
1936         if (integerTy.isSigned())
1937           os << 's';
1938         else if (integerTy.isUnsigned())
1939           os << 'u';
1940         os << 'i' << integerTy.getWidth();
1941       })
1942       .Case<FunctionType>([&](FunctionType funcTy) {
1943         os << '(';
1944         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1945         os << ") -> ";
1946         ArrayRef<Type> results = funcTy.getResults();
1947         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1948           printType(results[0]);
1949         } else {
1950           os << '(';
1951           interleaveComma(results, [&](Type ty) { printType(ty); });
1952           os << ')';
1953         }
1954       })
1955       .Case<VectorType>([&](VectorType vectorTy) {
1956         os << "vector<";
1957         auto vShape = vectorTy.getShape();
1958         unsigned lastDim = vShape.size();
1959         unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
1960         unsigned dimIdx = 0;
1961         for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
1962           os << vShape[dimIdx] << 'x';
1963         if (vectorTy.isScalable()) {
1964           os << '[';
1965           unsigned secondToLastDim = lastDim - 1;
1966           for (; dimIdx < secondToLastDim; dimIdx++)
1967             os << vShape[dimIdx] << 'x';
1968           os << vShape[dimIdx] << "]x";
1969         }
1970         printType(vectorTy.getElementType());
1971         os << '>';
1972       })
1973       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1974         os << "tensor<";
1975         for (int64_t dim : tensorTy.getShape()) {
1976           if (ShapedType::isDynamic(dim))
1977             os << '?';
1978           else
1979             os << dim;
1980           os << 'x';
1981         }
1982         printType(tensorTy.getElementType());
1983         // Only print the encoding attribute value if set.
1984         if (tensorTy.getEncoding()) {
1985           os << ", ";
1986           printAttribute(tensorTy.getEncoding());
1987         }
1988         os << '>';
1989       })
1990       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1991         os << "tensor<*x";
1992         printType(tensorTy.getElementType());
1993         os << '>';
1994       })
1995       .Case<MemRefType>([&](MemRefType memrefTy) {
1996         os << "memref<";
1997         for (int64_t dim : memrefTy.getShape()) {
1998           if (ShapedType::isDynamic(dim))
1999             os << '?';
2000           else
2001             os << dim;
2002           os << 'x';
2003         }
2004         printType(memrefTy.getElementType());
2005         if (!memrefTy.getLayout().isIdentity()) {
2006           os << ", ";
2007           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2008         }
2009         // Only print the memory space if it is the non-default one.
2010         if (memrefTy.getMemorySpace()) {
2011           os << ", ";
2012           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2013         }
2014         os << '>';
2015       })
2016       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2017         os << "memref<*x";
2018         printType(memrefTy.getElementType());
2019         // Only print the memory space if it is the non-default one.
2020         if (memrefTy.getMemorySpace()) {
2021           os << ", ";
2022           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2023         }
2024         os << '>';
2025       })
2026       .Case<ComplexType>([&](ComplexType complexTy) {
2027         os << "complex<";
2028         printType(complexTy.getElementType());
2029         os << '>';
2030       })
2031       .Case<TupleType>([&](TupleType tupleTy) {
2032         os << "tuple<";
2033         interleaveComma(tupleTy.getTypes(),
2034                         [&](Type type) { printType(type); });
2035         os << '>';
2036       })
2037       .Case<NoneType>([&](Type) { os << "none"; })
2038       .Default([&](Type type) { return printDialectType(type); });
2039 }
2040 
2041 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2042                                              ArrayRef<StringRef> elidedAttrs,
2043                                              bool withKeyword) {
2044   // If there are no attributes, then there is nothing to be done.
2045   if (attrs.empty())
2046     return;
2047 
2048   // Functor used to print a filtered attribute list.
2049   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2050     // Print the 'attributes' keyword if necessary.
2051     if (withKeyword)
2052       os << " attributes";
2053 
2054     // Otherwise, print them all out in braces.
2055     os << " {";
2056     interleaveComma(filteredAttrs,
2057                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
2058     os << '}';
2059   };
2060 
2061   // If no attributes are elided, we can directly print with no filtering.
2062   if (elidedAttrs.empty())
2063     return printFilteredAttributesFn(attrs);
2064 
2065   // Otherwise, filter out any attributes that shouldn't be included.
2066   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2067                                                 elidedAttrs.end());
2068   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2069     return !elidedAttrsSet.contains(attr.getName().strref());
2070   });
2071   if (!filteredAttrs.empty())
2072     printFilteredAttributesFn(filteredAttrs);
2073 }
2074 
2075 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2076   // Print the name without quotes if possible.
2077   ::printKeywordOrString(attr.getName().strref(), os);
2078 
2079   // Pretty printing elides the attribute value for unit attributes.
2080   if (attr.getValue().isa<UnitAttr>())
2081     return;
2082 
2083   os << " = ";
2084   printAttribute(attr.getValue());
2085 }
2086 
2087 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2088   auto &dialect = attr.getDialect();
2089 
2090   // Ask the dialect to serialize the attribute to a string.
2091   std::string attrName;
2092   {
2093     llvm::raw_string_ostream attrNameStr(attrName);
2094     Impl subPrinter(attrNameStr, printerFlags, state);
2095     DialectAsmPrinter printer(subPrinter);
2096     dialect.printAttribute(attr, printer);
2097   }
2098   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2099 }
2100 
2101 void AsmPrinter::Impl::printDialectType(Type type) {
2102   auto &dialect = type.getDialect();
2103 
2104   // Ask the dialect to serialize the type to a string.
2105   std::string typeName;
2106   {
2107     llvm::raw_string_ostream typeNameStr(typeName);
2108     Impl subPrinter(typeNameStr, printerFlags, state);
2109     DialectAsmPrinter printer(subPrinter);
2110     dialect.printType(type, printer);
2111   }
2112   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2113 }
2114 
2115 //===--------------------------------------------------------------------===//
2116 // AsmPrinter
2117 //===--------------------------------------------------------------------===//
2118 
// The destructor is defaulted, but defined out of line in this file.
AsmPrinter::~AsmPrinter() = default;
2120 
2121 raw_ostream &AsmPrinter::getStream() const {
2122   assert(impl && "expected AsmPrinter::getStream to be overriden");
2123   return impl->getStream();
2124 }
2125 
2126 /// Print the given floating point value in a stablized form.
2127 void AsmPrinter::printFloat(const APFloat &value) {
2128   assert(impl && "expected AsmPrinter::printFloat to be overriden");
2129   printFloatValue(value, impl->getStream());
2130 }
2131 
2132 void AsmPrinter::printType(Type type) {
2133   assert(impl && "expected AsmPrinter::printType to be overriden");
2134   impl->printType(type);
2135 }
2136 
2137 void AsmPrinter::printAttribute(Attribute attr) {
2138   assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2139   impl->printAttribute(attr);
2140 }
2141 
2142 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2143   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2144   return impl->printAlias(attr);
2145 }
2146 
2147 LogicalResult AsmPrinter::printAlias(Type type) {
2148   assert(impl && "expected AsmPrinter::printAlias to be overriden");
2149   return impl->printAlias(type);
2150 }
2151 
2152 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2153   assert(impl &&
2154          "expected AsmPrinter::printAttributeWithoutType to be overriden");
2155   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2156 }
2157 
2158 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2159   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2160   ::printKeywordOrString(keyword, impl->getStream());
2161 }
2162 
2163 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2164   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2165   ::printSymbolReference(symbolRef, impl->getStream());
2166 }
2167 
2168 //===----------------------------------------------------------------------===//
2169 // Affine expressions and maps
2170 //===----------------------------------------------------------------------===//
2171 
2172 void AsmPrinter::Impl::printAffineExpr(
2173     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2174   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2175 }
2176 
2177 void AsmPrinter::Impl::printAffineExprInternal(
2178     AffineExpr expr, BindingStrength enclosingTightness,
2179     function_ref<void(unsigned, bool)> printValueName) {
2180   const char *binopSpelling = nullptr;
2181   switch (expr.getKind()) {
2182   case AffineExprKind::SymbolId: {
2183     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
2184     if (printValueName)
2185       printValueName(pos, /*isSymbol=*/true);
2186     else
2187       os << 's' << pos;
2188     return;
2189   }
2190   case AffineExprKind::DimId: {
2191     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2192     if (printValueName)
2193       printValueName(pos, /*isSymbol=*/false);
2194     else
2195       os << 'd' << pos;
2196     return;
2197   }
2198   case AffineExprKind::Constant:
2199     os << expr.cast<AffineConstantExpr>().getValue();
2200     return;
2201   case AffineExprKind::Add:
2202     binopSpelling = " + ";
2203     break;
2204   case AffineExprKind::Mul:
2205     binopSpelling = " * ";
2206     break;
2207   case AffineExprKind::FloorDiv:
2208     binopSpelling = " floordiv ";
2209     break;
2210   case AffineExprKind::CeilDiv:
2211     binopSpelling = " ceildiv ";
2212     break;
2213   case AffineExprKind::Mod:
2214     binopSpelling = " mod ";
2215     break;
2216   }
2217 
2218   auto binOp = expr.cast<AffineBinaryOpExpr>();
2219   AffineExpr lhsExpr = binOp.getLHS();
2220   AffineExpr rhsExpr = binOp.getRHS();
2221 
2222   // Handle tightly binding binary operators.
2223   if (binOp.getKind() != AffineExprKind::Add) {
2224     if (enclosingTightness == BindingStrength::Strong)
2225       os << '(';
2226 
2227     // Pretty print multiplication with -1.
2228     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2229     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2230         rhsConst.getValue() == -1) {
2231       os << "-";
2232       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2233       if (enclosingTightness == BindingStrength::Strong)
2234         os << ')';
2235       return;
2236     }
2237 
2238     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2239 
2240     os << binopSpelling;
2241     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2242 
2243     if (enclosingTightness == BindingStrength::Strong)
2244       os << ')';
2245     return;
2246   }
2247 
2248   // Print out special "pretty" forms for add.
2249   if (enclosingTightness == BindingStrength::Strong)
2250     os << '(';
2251 
2252   // Pretty print addition to a product that has a negative operand as a
2253   // subtraction.
2254   if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2255     if (rhs.getKind() == AffineExprKind::Mul) {
2256       AffineExpr rrhsExpr = rhs.getRHS();
2257       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2258         if (rrhs.getValue() == -1) {
2259           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2260                                   printValueName);
2261           os << " - ";
2262           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2263             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2264                                     printValueName);
2265           } else {
2266             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2267                                     printValueName);
2268           }
2269 
2270           if (enclosingTightness == BindingStrength::Strong)
2271             os << ')';
2272           return;
2273         }
2274 
2275         if (rrhs.getValue() < -1) {
2276           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2277                                   printValueName);
2278           os << " - ";
2279           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2280                                   printValueName);
2281           os << " * " << -rrhs.getValue();
2282           if (enclosingTightness == BindingStrength::Strong)
2283             os << ')';
2284           return;
2285         }
2286       }
2287     }
2288   }
2289 
2290   // Pretty print addition to a negative number as a subtraction.
2291   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2292     if (rhsConst.getValue() < 0) {
2293       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2294       os << " - " << -rhsConst.getValue();
2295       if (enclosingTightness == BindingStrength::Strong)
2296         os << ')';
2297       return;
2298     }
2299   }
2300 
2301   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2302 
2303   os << " + ";
2304   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2305 
2306   if (enclosingTightness == BindingStrength::Strong)
2307     os << ')';
2308 }
2309 
2310 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2311   printAffineExprInternal(expr, BindingStrength::Weak);
2312   isEq ? os << " == 0" : os << " >= 0";
2313 }
2314 
2315 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2316   // Dimension identifiers.
2317   os << '(';
2318   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2319     os << 'd' << i << ", ";
2320   if (map.getNumDims() >= 1)
2321     os << 'd' << map.getNumDims() - 1;
2322   os << ')';
2323 
2324   // Symbolic identifiers.
2325   if (map.getNumSymbols() != 0) {
2326     os << '[';
2327     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2328       os << 's' << i << ", ";
2329     if (map.getNumSymbols() >= 1)
2330       os << 's' << map.getNumSymbols() - 1;
2331     os << ']';
2332   }
2333 
2334   // Result affine expressions.
2335   os << " -> (";
2336   interleaveComma(map.getResults(),
2337                   [&](AffineExpr expr) { printAffineExpr(expr); });
2338   os << ')';
2339 }
2340 
2341 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2342   // Dimension identifiers.
2343   os << '(';
2344   for (unsigned i = 1; i < set.getNumDims(); ++i)
2345     os << 'd' << i - 1 << ", ";
2346   if (set.getNumDims() >= 1)
2347     os << 'd' << set.getNumDims() - 1;
2348   os << ')';
2349 
2350   // Symbolic identifiers.
2351   if (set.getNumSymbols() != 0) {
2352     os << '[';
2353     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2354       os << 's' << i << ", ";
2355     if (set.getNumSymbols() >= 1)
2356       os << 's' << set.getNumSymbols() - 1;
2357     os << ']';
2358   }
2359 
2360   // Print constraints.
2361   os << " : (";
2362   int numConstraints = set.getNumConstraints();
2363   for (int i = 1; i < numConstraints; ++i) {
2364     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2365     os << ", ";
2366   }
2367   if (numConstraints >= 1)
2368     printAffineConstraint(set.getConstraint(numConstraints - 1),
2369                           set.isEq(numConstraints - 1));
2370   os << ')';
2371 }
2372 
2373 //===----------------------------------------------------------------------===//
2374 // OperationPrinter
2375 //===----------------------------------------------------------------------===//
2376 
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
/// It combines the generic AsmPrinter::Impl (types, attributes, locations,
/// aliases) with the OpAsmPrinter hooks that custom operation printers call.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  // Make Impl's printType overloads visible in this class scope so that
  // unqualified calls inside OperationPrinter resolve to them.
  using Impl::printType;

  /// Create a printer writing to `os` with `flags`, backed by the shared
  /// printing `state`. The OpAsmPrinter base is bound to this object's Impl
  /// part.
  explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
                            AsmStateImpl &state)
      : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print the given top-level operation.
  void printTopLevelOperation(Operation *op);

  /// Print the given operation with its indent and location.
  void print(Operation *op);
  /// Print the bare operation, not including indentation/location/etc.
  void printOperation(Operation *op);
  /// Print the given operation in the generic form.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number.
  /// Output goes to `streamOverride` when non-null, otherwise to `os`.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }

  /// Print a block argument in the usual format of:
  ///   %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;

  /// Print the ID for the given value.
  void printOperand(Value value) override { printValueID(value); }
  /// Print the ID for the given value to the provided stream instead of `os`.
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as printOptionalAttrDict, but forwards `withKeyword=true` to Impl.
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

private:
  // Contains the stack of default dialects to use when printing regions.
  // A new dialect is pushed to the stack before printing regions nested under
  // an operation implementing `OpAsmOpInterface`, and popped when done. At the
  // top-level we start with "builtin" as the default, so that the top-level
  // `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  unsigned currentIndent = 0;
};
} // namespace
2494 
2495 void OperationPrinter::printTopLevelOperation(Operation *op) {
2496   // Output the aliases at the top level that can't be deferred.
2497   state->getAliasState().printNonDeferredAliases(os, newLine);
2498 
2499   // Print the module.
2500   print(op);
2501   os << newLine;
2502 
2503   // Output the aliases at the top level that can be deferred.
2504   state->getAliasState().printDeferredAliases(os, newLine);
2505 }
2506 
2507 /// Print a block argument in the usual format of:
2508 ///   %ssaName : type {attr1=42} loc("here")
2509 /// where location printing is controlled by the standard internal option.
2510 /// You may pass omitType=true to not print a type, and pass an empty
2511 /// attribute list if you don't care for attributes.
2512 void OperationPrinter::printRegionArgument(BlockArgument arg,
2513                                            ArrayRef<NamedAttribute> argAttrs,
2514                                            bool omitType) {
2515   printOperand(arg);
2516   if (!omitType) {
2517     os << ": ";
2518     printType(arg.getType());
2519   }
2520   printOptionalAttrDict(argAttrs);
2521   // TODO: We should allow location aliases on block arguments.
2522   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2523 }
2524 
2525 void OperationPrinter::print(Operation *op) {
2526   // Track the location of this operation.
2527   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2528 
2529   os.indent(currentIndent);
2530   printOperation(op);
2531   printTrailingLocation(op->getLoc());
2532 }
2533 
2534 void OperationPrinter::printOperation(Operation *op) {
2535   if (size_t numResults = op->getNumResults()) {
2536     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2537       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2538       if (resultCount > 1)
2539         os << ':' << resultCount;
2540     };
2541 
2542     // Check to see if this operation has multiple result groups.
2543     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2544     if (!resultGroups.empty()) {
2545       // Interleave the groups excluding the last one, this one will be handled
2546       // separately.
2547       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2548         printResultGroup(resultGroups[i],
2549                          resultGroups[i + 1] - resultGroups[i]);
2550       });
2551       os << ", ";
2552       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2553 
2554     } else {
2555       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2556     }
2557 
2558     os << " = ";
2559   }
2560 
2561   // If requested, always print the generic form.
2562   if (!printerFlags.shouldPrintGenericOpForm()) {
2563     // Check to see if this is a known operation. If so, use the registered
2564     // custom printer hook.
2565     if (auto opInfo = op->getRegisteredInfo()) {
2566       opInfo->printAssembly(op, *this, defaultDialectStack.back());
2567       return;
2568     }
2569     // Otherwise try to dispatch to the dialect, if available.
2570     if (Dialect *dialect = op->getDialect()) {
2571       if (auto opPrinter = dialect->getOperationPrinter(op)) {
2572         // Print the op name first.
2573         StringRef name = op->getName().getStringRef();
2574         name.consume_front((defaultDialectStack.back() + ".").str());
2575         printEscapedString(name, os);
2576         // Print the rest of the op now.
2577         opPrinter(op, *this);
2578         return;
2579       }
2580     }
2581   }
2582 
2583   // Otherwise print with the generic assembly form.
2584   printGenericOp(op, /*printOpName=*/true);
2585 }
2586 
2587 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2588   if (printOpName) {
2589     os << '"';
2590     printEscapedString(op->getName().getStringRef(), os);
2591     os << '"';
2592   }
2593   os << '(';
2594   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2595   os << ')';
2596 
2597   // For terminators, print the list of successors and their operands.
2598   if (op->getNumSuccessors() != 0) {
2599     os << '[';
2600     interleaveComma(op->getSuccessors(),
2601                     [&](Block *successor) { printBlockName(successor); });
2602     os << ']';
2603   }
2604 
2605   // Print regions.
2606   if (op->getNumRegions() != 0) {
2607     os << " (";
2608     interleaveComma(op->getRegions(), [&](Region &region) {
2609       printRegion(region, /*printEntryBlockArgs=*/true,
2610                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2611     });
2612     os << ')';
2613   }
2614 
2615   auto attrs = op->getAttrs();
2616   printOptionalAttrDict(attrs);
2617 
2618   // Print the type signature of the operation.
2619   os << " : ";
2620   printFunctionalType(op);
2621 }
2622 
2623 void OperationPrinter::printBlockName(Block *block) {
2624   auto id = state->getSSANameState().getBlockID(block);
2625   if (id != SSANameState::NameSentinel)
2626     os << "^bb" << id;
2627   else
2628     os << "^INVALIDBLOCK";
2629 }
2630 
2631 void OperationPrinter::print(Block *block, bool printBlockArgs,
2632                              bool printBlockTerminator) {
2633   // Print the block label and argument list if requested.
2634   if (printBlockArgs) {
2635     os.indent(currentIndent);
2636     printBlockName(block);
2637 
2638     // Print the argument list if non-empty.
2639     if (!block->args_empty()) {
2640       os << '(';
2641       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2642         printValueID(arg);
2643         os << ": ";
2644         printType(arg.getType());
2645         // TODO: We should allow location aliases on block arguments.
2646         printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2647       });
2648       os << ')';
2649     }
2650     os << ':';
2651 
2652     // Print out some context information about the predecessors of this block.
2653     if (!block->getParent()) {
2654       os << "  // block is not in a region!";
2655     } else if (block->hasNoPredecessors()) {
2656       os << "  // no predecessors";
2657     } else if (auto *pred = block->getSinglePredecessor()) {
2658       os << "  // pred: ";
2659       printBlockName(pred);
2660     } else {
2661       // We want to print the predecessors in increasing numeric order, not in
2662       // whatever order the use-list is in, so gather and sort them.
2663       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2664       for (auto *pred : block->getPredecessors())
2665         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2666       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2667 
2668       os << "  // " << predIDs.size() << " preds: ";
2669 
2670       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2671         printBlockName(pred.second);
2672       });
2673     }
2674     os << newLine;
2675   }
2676 
2677   currentIndent += indentWidth;
2678   bool hasTerminator =
2679       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
2680   auto range = llvm::make_range(
2681       block->begin(),
2682       std::prev(block->end(),
2683                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
2684   for (auto &op : range) {
2685     print(&op);
2686     os << newLine;
2687   }
2688   currentIndent -= indentWidth;
2689 }
2690 
2691 void OperationPrinter::printValueID(Value value, bool printResultNo,
2692                                     raw_ostream *streamOverride) const {
2693   state->getSSANameState().printValueID(value, printResultNo,
2694                                         streamOverride ? *streamOverride : os);
2695 }
2696 
2697 void OperationPrinter::printSuccessor(Block *successor) {
2698   printBlockName(successor);
2699 }
2700 
2701 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2702                                                 ValueRange succOperands) {
2703   printBlockName(successor);
2704   if (succOperands.empty())
2705     return;
2706 
2707   os << '(';
2708   interleaveComma(succOperands,
2709                   [this](Value operand) { printValueID(operand); });
2710   os << " : ";
2711   interleaveComma(succOperands,
2712                   [this](Value operand) { printType(operand.getType()); });
2713   os << ')';
2714 }
2715 
2716 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2717                                    bool printBlockTerminators,
2718                                    bool printEmptyBlock) {
2719   os << "{" << newLine;
2720   if (!region.empty()) {
2721     auto restoreDefaultDialect =
2722         llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
2723     if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
2724       defaultDialectStack.push_back(iface.getDefaultDialect());
2725     else
2726       defaultDialectStack.push_back("");
2727 
2728     auto *entryBlock = &region.front();
2729     // Force printing the block header if printEmptyBlock is set and the block
2730     // is empty or if printEntryBlockArgs is set and there are arguments to
2731     // print.
2732     bool shouldAlwaysPrintBlockHeader =
2733         (printEmptyBlock && entryBlock->empty()) ||
2734         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2735     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2736     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2737       print(&b);
2738   }
2739   os.indent(currentIndent) << "}";
2740 }
2741 
2742 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2743                                               ValueRange operands) {
2744   AffineMap map = mapAttr.getValue();
2745   unsigned numDims = map.getNumDims();
2746   auto printValueName = [&](unsigned pos, bool isSymbol) {
2747     unsigned index = isSymbol ? numDims + pos : pos;
2748     assert(index < operands.size());
2749     if (isSymbol)
2750       os << "symbol(";
2751     printValueID(operands[index]);
2752     if (isSymbol)
2753       os << ')';
2754   };
2755 
2756   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2757     printAffineExpr(expr, printValueName);
2758   });
2759 }
2760 
2761 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2762                                                ValueRange dimOperands,
2763                                                ValueRange symOperands) {
2764   auto printValueName = [&](unsigned pos, bool isSymbol) {
2765     if (!isSymbol)
2766       return printValueID(dimOperands[pos]);
2767     os << "symbol(";
2768     printValueID(symOperands[pos]);
2769     os << ')';
2770   };
2771   printAffineExpr(expr, printValueName);
2772 }
2773 
2774 //===----------------------------------------------------------------------===//
2775 // print and dump methods
2776 //===----------------------------------------------------------------------===//
2777 
2778 void Attribute::print(raw_ostream &os) const {
2779   AsmPrinter::Impl(os).printAttribute(*this);
2780 }
2781 
2782 void Attribute::dump() const {
2783   print(llvm::errs());
2784   llvm::errs() << "\n";
2785 }
2786 
2787 void Type::print(raw_ostream &os) const {
2788   AsmPrinter::Impl(os).printType(*this);
2789 }
2790 
2791 void Type::dump() const { print(llvm::errs()); }
2792 
2793 void AffineMap::dump() const {
2794   print(llvm::errs());
2795   llvm::errs() << "\n";
2796 }
2797 
2798 void IntegerSet::dump() const {
2799   print(llvm::errs());
2800   llvm::errs() << "\n";
2801 }
2802 
2803 void AffineExpr::print(raw_ostream &os) const {
2804   if (!expr) {
2805     os << "<<NULL AFFINE EXPR>>";
2806     return;
2807   }
2808   AsmPrinter::Impl(os).printAffineExpr(*this);
2809 }
2810 
2811 void AffineExpr::dump() const {
2812   print(llvm::errs());
2813   llvm::errs() << "\n";
2814 }
2815 
2816 void AffineMap::print(raw_ostream &os) const {
2817   if (!map) {
2818     os << "<<NULL AFFINE MAP>>";
2819     return;
2820   }
2821   AsmPrinter::Impl(os).printAffineMap(*this);
2822 }
2823 
2824 void IntegerSet::print(raw_ostream &os) const {
2825   AsmPrinter::Impl(os).printIntegerSet(*this);
2826 }
2827 
2828 void Value::print(raw_ostream &os) {
2829   if (!impl) {
2830     os << "<<NULL VALUE>>";
2831     return;
2832   }
2833 
2834   if (auto *op = getDefiningOp())
2835     return op->print(os);
2836   // TODO: Improve BlockArgument print'ing.
2837   BlockArgument arg = this->cast<BlockArgument>();
2838   os << "<block argument> of type '" << arg.getType()
2839      << "' at index: " << arg.getArgNumber();
2840 }
2841 void Value::print(raw_ostream &os, AsmState &state) {
2842   if (!impl) {
2843     os << "<<NULL VALUE>>";
2844     return;
2845   }
2846 
2847   if (auto *op = getDefiningOp())
2848     return op->print(os, state);
2849 
2850   // TODO: Improve BlockArgument print'ing.
2851   BlockArgument arg = this->cast<BlockArgument>();
2852   os << "<block argument> of type '" << arg.getType()
2853      << "' at index: " << arg.getArgNumber();
2854 }
2855 
2856 void Value::dump() {
2857   print(llvm::errs());
2858   llvm::errs() << "\n";
2859 }
2860 
2861 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2862   // TODO: This doesn't necessarily capture all potential cases.
2863   // Currently, region arguments can be shadowed when printing the main
2864   // operation. If the IR hasn't been printed, this will produce the old SSA
2865   // name and not the shadowed name.
2866   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2867                                                  os);
2868 }
2869 
2870 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
2871   // If this is a top level operation, we also print aliases.
2872   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
2873     AsmState state(this, printerFlags);
2874     state.getImpl().initializeAliases(this);
2875     print(os, state, printerFlags);
2876     return;
2877   }
2878 
2879   // Find the operation to number from based upon the provided flags.
2880   Operation *op = this;
2881   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
2882   do {
2883     // If we are printing local scope, stop at the first operation that is
2884     // isolated from above.
2885     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2886       break;
2887 
2888     // Otherwise, traverse up to the next parent.
2889     Operation *parentOp = op->getParentOp();
2890     if (!parentOp)
2891       break;
2892     op = parentOp;
2893   } while (true);
2894 
2895   AsmState state(op, printerFlags);
2896   print(os, state, printerFlags);
2897 }
2898 void Operation::print(raw_ostream &os, AsmState &state,
2899                       const OpPrintingFlags &flags) {
2900   OperationPrinter printer(os, flags, state.getImpl());
2901   if (!getParent() && !flags.shouldUseLocalScope())
2902     printer.printTopLevelOperation(this);
2903   else
2904     printer.print(this);
2905 }
2906 
2907 void Operation::dump() {
2908   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2909   llvm::errs() << "\n";
2910 }
2911 
2912 void Block::print(raw_ostream &os) {
2913   Operation *parentOp = getParentOp();
2914   if (!parentOp) {
2915     os << "<<UNLINKED BLOCK>>\n";
2916     return;
2917   }
2918   // Get the top-level op.
2919   while (auto *nextOp = parentOp->getParentOp())
2920     parentOp = nextOp;
2921 
2922   AsmState state(parentOp);
2923   print(os, state);
2924 }
2925 void Block::print(raw_ostream &os, AsmState &state) {
2926   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2927 }
2928 
2929 void Block::dump() { print(llvm::errs()); }
2930 
2931 /// Print out the name of the block without printing its body.
2932 void Block::printAsOperand(raw_ostream &os, bool printType) {
2933   Operation *parentOp = getParentOp();
2934   if (!parentOp) {
2935     os << "<<UNLINKED BLOCK>>\n";
2936     return;
2937   }
2938   AsmState state(parentOp);
2939   printAsOperand(os, state);
2940 }
2941 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2942   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2943   printer.printBlockName(this);
2944 }
2945