1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Operation.h"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/ScopedHashTable.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Regex.h"
37 #include "llvm/Support/SaveAndRestore.h"
38 using namespace mlir;
39 using namespace mlir::detail;
40 
41 void Identifier::print(raw_ostream &os) const { os << str(); }
42 
43 void Identifier::dump() const { print(llvm::errs()); }
44 
45 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
46 
47 void OperationName::dump() const { print(llvm::errs()); }
48 
49 DialectAsmPrinter::~DialectAsmPrinter() {}
50 
51 //===--------------------------------------------------------------------===//
52 // OpAsmPrinter
53 //===--------------------------------------------------------------------===//
54 
55 OpAsmPrinter::~OpAsmPrinter() {}
56 
57 void OpAsmPrinter::printFunctionalType(Operation *op) {
58   auto &os = getStream();
59   os << '(';
60   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
61     // Print the types of null values as <<NULL TYPE>>.
62     *this << (operand ? operand.getType() : Type());
63   });
64   os << ") -> ";
65 
66   // Print the result list.  We don't parenthesize single result types unless
67   // it is a function (avoiding a grammar ambiguity).
68   bool wrapped = op->getNumResults() != 1;
69   if (!wrapped && op->getResult(0).getType() &&
70       op->getResult(0).getType().isa<FunctionType>())
71     wrapped = true;
72 
73   if (wrapped)
74     os << '(';
75 
76   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
77     // Print the types of null values as <<NULL TYPE>>.
78     *this << (result ? result.getType() : Type());
79   });
80 
81   if (wrapped)
82     os << ')';
83 }
84 
85 //===--------------------------------------------------------------------===//
86 // Operation OpAsm interface.
87 //===--------------------------------------------------------------------===//
88 
89 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
90 #include "mlir/IR/OpAsmInterface.cpp.inc"
91 
92 //===----------------------------------------------------------------------===//
93 // OpPrintingFlags
94 //===----------------------------------------------------------------------===//
95 
96 namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
/// for global command line options.
///
/// The options are registered when this struct is constructed (see
/// `registerAsmPrinterCLOptions`), and are read e.g. by the `OpPrintingFlags`
/// constructor and by `shouldPrintElementsAttrWithHex`.
struct AsmPrinterOptions {
  /// Element-count threshold above which DenseElementsAttrs are printed as a
  /// hex string; -1 disables hex printing. Only consulted when the option was
  /// explicitly given; otherwise a built-in default threshold applies.
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  /// Element-count threshold above which ElementsAttrs are elided. Only
  /// applied to OpPrintingFlags when explicitly given on the command line.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  /// Print debug information (locations) in the output.
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  /// Print debug information in the more readable "pretty" form.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  /// Use the generic op output form in the operation printer even if the
  /// custom form is defined.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  /// Default the printer to local scope (see OpPrintingFlags::useLocalScope).
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print assuming in local scope by default"),
      llvm::cl::Hidden};
};
131 } // end anonymous namespace
132 
/// Lazily constructed storage for the AsmPrinter command-line options. The
/// options only exist once this has been constructed (which
/// `registerAsmPrinterCLOptions` forces), so readers check `isConstructed()`
/// before accessing it.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
134 
135 /// Register a set of useful command-line options that can be used to configure
136 /// various flags within the AsmPrinter.
137 void mlir::registerAsmPrinterCLOptions() {
138   // Make sure that the options struct has been initialized.
139   *clOptions;
140 }
141 
142 /// Initialize the printing flags with default supplied by the cl::opts above.
143 OpPrintingFlags::OpPrintingFlags()
144     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
145       printGenericOpFormFlag(false), printLocalScope(false) {
146   // Initialize based upon command line options, if they are available.
147   if (!clOptions.isConstructed())
148     return;
149   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
150     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
151   printDebugInfoFlag = clOptions->printDebugInfoOpt;
152   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
153   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
154   printLocalScope = clOptions->printLocalScopeOpt;
155 }
156 
157 /// Enable the elision of large elements attributes, by printing a '...'
158 /// instead of the element data, when the number of elements is greater than
159 /// `largeElementLimit`. Note: The IR generated with this option is not
160 /// parsable.
161 OpPrintingFlags &
162 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
163   elementsAttrElementLimit = largeElementLimit;
164   return *this;
165 }
166 
167 /// Enable printing of debug information. If 'prettyForm' is set to true,
168 /// debug information is printed in a more readable 'pretty' form.
169 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
170   printDebugInfoFlag = true;
171   printDebugInfoPrettyFormFlag = prettyForm;
172   return *this;
173 }
174 
175 /// Always print operations in the generic form.
176 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
177   printGenericOpFormFlag = true;
178   return *this;
179 }
180 
181 /// Use local scope when printing the operation. This allows for using the
182 /// printer in a more localized and thread-safe setting, but may not necessarily
183 /// be identical of what the IR will look like when dumping the full module.
184 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
185   printLocalScope = true;
186   return *this;
187 }
188 
189 /// Return if the given ElementsAttr should be elided.
190 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
191   return elementsAttrElementLimit.hasValue() &&
192          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
193          !attr.isa<SplatElementsAttr>();
194 }
195 
196 /// Return the size limit for printing large ElementsAttr.
197 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
198   return elementsAttrElementLimit;
199 }
200 
201 /// Return if debug information should be printed.
202 bool OpPrintingFlags::shouldPrintDebugInfo() const {
203   return printDebugInfoFlag;
204 }
205 
206 /// Return if debug information should be printed in the pretty form.
207 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
208   return printDebugInfoPrettyFormFlag;
209 }
210 
211 /// Return if operations should be printed in the generic form.
212 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
213   return printGenericOpFormFlag;
214 }
215 
216 /// Return if the printer should use local scope when dumping the IR.
217 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
218 
219 /// Returns true if an ElementsAttr with the given number of elements should be
220 /// printed with hex.
221 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
222   // Check to see if a command line option was provided for the limit.
223   if (clOptions.isConstructed()) {
224     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
225       // -1 is used to disable hex printing.
226       if (clOptions->printElementsAttrWithHexIfLarger == -1)
227         return false;
228       return numElements > clOptions->printElementsAttrWithHexIfLarger;
229     }
230   }
231 
232   // Otherwise, default to printing with hex if the number of elements is >100.
233   return numElements > 100;
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // NewLineCounter
238 //===----------------------------------------------------------------------===//
239 
240 namespace {
/// A newline "manipulator" that keeps count of the lines it emits: streaming
/// an instance into a raw_ostream writes '\n' and increments `curLine`. The
/// printer should emit every newline through this so the count stays exact.
struct NewLineCounter {
  /// The current line number, starting at 1.
  unsigned curLine{1};
};
247 } // end anonymous namespace
248 
249 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
250   ++newLine.curLine;
251   return os << '\n';
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // AliasInitializer
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class represents a specific instance of a symbol Alias.
260 class SymbolAlias {
261 public:
262   SymbolAlias(StringRef name, bool isDeferrable)
263       : name(name), suffixIndex(0), hasSuffixIndex(false),
264         isDeferrable(isDeferrable) {}
265   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
266       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
267         isDeferrable(isDeferrable) {}
268 
269   /// Print this alias to the given stream.
270   void print(raw_ostream &os) const {
271     os << name;
272     if (hasSuffixIndex)
273       os << suffixIndex;
274   }
275 
276   /// Returns true if this alias supports deferred resolution when parsing.
277   bool canBeDeferred() const { return isDeferrable; }
278 
279 private:
280   /// The main name of the alias.
281   StringRef name;
282   /// The optional suffix index of the alias, if multiple aliases had the same
283   /// name.
284   uint32_t suffixIndex : 30;
285   /// A flag indicating whether this alias has a suffix or not.
286   bool hasSuffixIndex : 1;
287   /// A flag indicating whether this alias may be deferred or not.
288   bool isDeferrable : 1;
289 };
290 
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
///
/// Alias names are requested from the OpAsmDialectInterfaces in `interfaces`
/// and copied into the caller-provided `aliasAllocator`, so the names remain
/// valid after this helper is destroyed.
class AliasInitializer {
public:
  /// `interfaces` are queried for alias names; generated names are allocated
  /// in `aliasAllocator`, which must outlive the produced alias maps.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` as it would be printed with `printerFlags`, collect aliases for
  /// the attributes and types encountered, and populate `attrToAlias` and
  /// `typeToAlias` with the results, ordered by alias name.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias. `aliasOS` writes into
  /// `aliasBuffer`, so `aliasBuffer` must stay declared (and thus constructed)
  /// before `aliasOS`.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
346 
347 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
348 /// and merely collects the attributes and types that *would* be printed in a
349 /// normal print invocation so that we can generate proper aliases. This allows
350 /// for us to generate aliases only for the attributes and types that would be
351 /// in the output, and trims down unnecessary output.
352 class DummyAliasOperationPrinter : private OpAsmPrinter {
353 public:
354   explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags,
355                                       AliasInitializer &initializer)
356       : printerFlags(flags), initializer(initializer) {}
357 
358   /// Print the given operation.
359   void print(Operation *op) {
360     // Visit the operation location.
361     if (printerFlags.shouldPrintDebugInfo())
362       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
363 
364     // If requested, always print the generic form.
365     if (!printerFlags.shouldPrintGenericOpForm()) {
366       // Check to see if this is a known operation.  If so, use the registered
367       // custom printer hook.
368       if (auto *opInfo = op->getAbstractOperation()) {
369         opInfo->printAssembly(op, *this);
370         return;
371       }
372     }
373 
374     // Otherwise print with the generic assembly form.
375     printGenericOp(op);
376   }
377 
378 private:
379   /// Print the given operation in the generic form.
380   void printGenericOp(Operation *op) override {
381     // Consider nested operations for aliases.
382     if (op->getNumRegions() != 0) {
383       for (Region &region : op->getRegions())
384         printRegion(region, /*printEntryBlockArgs=*/true,
385                     /*printBlockTerminators=*/true);
386     }
387 
388     // Visit all the types used in the operation.
389     for (Type type : op->getOperandTypes())
390       printType(type);
391     for (Type type : op->getResultTypes())
392       printType(type);
393 
394     // Consider the attributes of the operation for aliases.
395     for (const NamedAttribute &attr : op->getAttrs())
396       printAttribute(attr.second);
397   }
398 
399   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
400   /// block are not printed. If 'printBlockTerminator' is false, the terminator
401   /// operation of the block is not printed.
402   void print(Block *block, bool printBlockArgs = true,
403              bool printBlockTerminator = true) {
404     // Consider the types of the block arguments for aliases if 'printBlockArgs'
405     // is set to true.
406     if (printBlockArgs) {
407       for (Type type : block->getArgumentTypes())
408         printType(type);
409     }
410 
411     // Consider the operations within this block, ignoring the terminator if
412     // requested.
413     auto range = llvm::make_range(
414         block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
415     for (Operation &op : range)
416       print(&op);
417   }
418 
419   /// Print the given region.
420   void printRegion(Region &region, bool printEntryBlockArgs,
421                    bool printBlockTerminators) override {
422     if (region.empty())
423       return;
424 
425     auto *entryBlock = &region.front();
426     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
427     for (Block &b : llvm::drop_begin(region, 1))
428       print(&b);
429   }
430 
431   /// Consider the given type to be printed for an alias.
432   void printType(Type type) override { initializer.visit(type); }
433 
434   /// Consider the given attribute to be printed for an alias.
435   void printAttribute(Attribute attr) override { initializer.visit(attr); }
436   void printAttributeWithoutType(Attribute attr) override {
437     printAttribute(attr);
438   }
439 
440   /// Print the given set of attributes with names not included within
441   /// 'elidedAttrs'.
442   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
443                              ArrayRef<StringRef> elidedAttrs = {}) override {
444     if (attrs.empty())
445       return;
446     if (elidedAttrs.empty()) {
447       for (const NamedAttribute &attr : attrs)
448         printAttribute(attr.second);
449       return;
450     }
451     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
452                                                   elidedAttrs.end());
453     for (const NamedAttribute &attr : attrs)
454       if (!elidedAttrsSet.contains(attr.first.strref()))
455         printAttribute(attr.second);
456   }
457   void printOptionalAttrDictWithKeyword(
458       ArrayRef<NamedAttribute> attrs,
459       ArrayRef<StringRef> elidedAttrs = {}) override {
460     printOptionalAttrDict(attrs, elidedAttrs);
461   }
462 
463   /// Return 'nulls' as the output stream, this will ignore any data fed to it.
464   raw_ostream &getStream() const override { return llvm::nulls(); }
465 
466   /// The following are hooks of `OpAsmPrinter` that are not necessary for
467   /// determining potential aliases.
468   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
469   void printNewline() override {}
470   void printOperand(Value) override {}
471   void printOperand(Value, raw_ostream &os) override {
472     // Users expect the output string to have at least the prefixed % to signal
473     // a value name. To maintain this invariant, emit a name even if it is
474     // guaranteed to go unused.
475     os << "%";
476   }
477   void printSymbolName(StringRef) override {}
478   void printSuccessor(Block *) override {}
479   void printSuccessorAndUseList(Block *, ValueRange) override {}
480   void shadowRegionArgs(Region &, ValueRange) override {}
481 
482   /// The printer flags to use when determining potential aliases.
483   const OpPrintingFlags &printerFlags;
484 
485   /// The initializer to use when identifying aliases.
486   AliasInitializer &initializer;
487 };
488 } // end anonymous namespace
489 
490 /// Sanitize the given name such that it can be used as a valid identifier. If
491 /// the string needs to be modified in any way, the provided buffer is used to
492 /// store the new copy,
493 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
494                                     StringRef allowedPunctChars = "$._-",
495                                     bool allowTrailingDigit = true) {
496   assert(!name.empty() && "Shouldn't have an empty name here");
497 
498   auto copyNameToBuffer = [&] {
499     for (char ch : name) {
500       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
501         buffer.push_back(ch);
502       else if (ch == ' ')
503         buffer.push_back('_');
504       else
505         buffer.append(llvm::utohexstr((unsigned char)ch));
506     }
507   };
508 
509   // Check to see if this name is valid. If it starts with a digit, then it
510   // could conflict with the autogenerated numeric ID's, so add an underscore
511   // prefix to avoid problems.
512   if (isdigit(name[0])) {
513     buffer.push_back('_');
514     copyNameToBuffer();
515     return buffer;
516   }
517 
518   // If the name ends with a trailing digit, add a '_' to avoid potential
519   // conflicts with autogenerated ID's.
520   if (!allowTrailingDigit && isdigit(name.back())) {
521     copyNameToBuffer();
522     buffer.push_back('_');
523     return buffer;
524   }
525 
526   // Check to see that the name consists of only valid identifier characters.
527   for (char ch : name) {
528     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
529       copyNameToBuffer();
530       return buffer;
531     }
532   }
533 
534   // If there are no invalid characters, return the original name.
535   return name;
536 }
537 
538 /// Given a collection of aliases and symbols, initialize a mapping from a
539 /// symbol to a given alias.
540 template <typename T>
541 static void
542 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
543                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
544                   DenseSet<T> *deferrableAliases = nullptr) {
545   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
546       aliasToSymbol.takeVector();
547   llvm::array_pod_sort(aliases.begin(), aliases.end(),
548                        [](const auto *lhs, const auto *rhs) {
549                          return lhs->first.compare(rhs->first);
550                        });
551 
552   for (auto &it : aliases) {
553     // If there is only one instance for this alias, use the name directly.
554     if (it.second.size() == 1) {
555       T symbol = it.second.front();
556       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
557       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
558       continue;
559     }
560     // Otherwise, add the index to the name.
561     for (int i = 0, e = it.second.size(); i < e; ++i) {
562       T symbol = it.second[i];
563       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
564       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
565     }
566   }
567 }
568 
569 void AliasInitializer::initialize(
570     Operation *op, const OpPrintingFlags &printerFlags,
571     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
572     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
573   // Use a dummy printer when walking the IR so that we can collect the
574   // attributes/types that will actually be used during printing when
575   // considering aliases.
576   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
577   aliasPrinter.print(op);
578 
579   // Initialize the aliases sorted by name.
580   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
581   initializeAliases(aliasToType, typeToAlias);
582 }
583 
584 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
585   if (!visitedAttributes.insert(attr).second) {
586     // If this attribute already has an alias and this instance can't be
587     // deferred, make sure that the alias isn't deferred.
588     if (!canBeDeferred)
589       deferrableAttributes.erase(attr);
590     return;
591   }
592 
593   // Try to generate an alias for this attribute.
594   if (succeeded(generateAlias(attr, aliasToAttr))) {
595     if (canBeDeferred)
596       deferrableAttributes.insert(attr);
597     return;
598   }
599 
600   if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
601     for (Attribute element : arrayAttr.getValue())
602       visit(element);
603   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
604     for (const NamedAttribute &attr : dictAttr)
605       visit(attr.second);
606   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
607     visit(typeAttr.getValue());
608   }
609 }
610 
611 void AliasInitializer::visit(Type type) {
612   if (!visitedTypes.insert(type).second)
613     return;
614 
615   // Try to generate an alias for this type.
616   if (succeeded(generateAlias(type, aliasToType)))
617     return;
618 
619   // Visit several subtypes that contain types or attributes.
620   if (auto funcType = type.dyn_cast<FunctionType>()) {
621     // Visit input and result types for functions.
622     for (auto input : funcType.getInputs())
623       visit(input);
624     for (auto result : funcType.getResults())
625       visit(result);
626   } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
627     visit(shapedType.getElementType());
628 
629     // Visit affine maps in memref type.
630     if (auto memref = type.dyn_cast<MemRefType>())
631       for (auto map : memref.getAffineMaps())
632         visit(AffineMapAttr::get(map));
633   }
634 }
635 
636 template <typename T>
637 LogicalResult AliasInitializer::generateAlias(
638     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
639   SmallString<16> tempBuffer;
640   for (const auto &interface : interfaces) {
641     if (failed(interface.getAlias(symbol, aliasOS)))
642       continue;
643     StringRef name = aliasOS.str();
644     assert(!name.empty() && "expected valid alias name");
645     name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
646                               /*allowTrailingDigit=*/false);
647     name = name.copy(aliasAllocator);
648 
649     aliasToSymbol[name].push_back(symbol);
650     aliasBuffer.clear();
651     return success();
652   }
653   return failure();
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // AliasState
658 //===----------------------------------------------------------------------===//
659 
660 namespace {
/// This class manages the state for type and attribute aliases: the mapping
/// from each aliased attribute/type to its alias, and the storage for the
/// alias names.
class AliasState {
public:
  /// Initialize the internal aliases by walking `op` as it would be printed
  /// with `printerFlags`, querying `interfaces` for alias names.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`
  /// (as `#name`). Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os` (as
  /// `!name`). Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior, one alias definition per line.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias.
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names; the StringRefs held by the SymbolAlias
  /// entries above point into it.
  llvm::BumpPtrAllocator aliasAllocator;
};
702 } // end anonymous namespace
703 
704 void AliasState::initialize(
705     Operation *op, const OpPrintingFlags &printerFlags,
706     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
707   AliasInitializer initializer(interfaces, aliasAllocator);
708   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
709 }
710 
711 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
712   auto it = attrToAlias.find(attr);
713   if (it == attrToAlias.end())
714     return failure();
715   it->second.print(os << '#');
716   return success();
717 }
718 
719 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
720   auto it = typeToAlias.find(ty);
721   if (it == typeToAlias.end())
722     return failure();
723 
724   it->second.print(os << '!');
725   return success();
726 }
727 
728 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
729                               bool isDeferred) const {
730   auto filterFn = [=](const auto &aliasIt) {
731     return aliasIt.second.canBeDeferred() == isDeferred;
732   };
733   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
734     it.second.print(os << '#');
735     os << " = " << it.first << newLine;
736   }
737   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
738     it.second.print(os << '!');
739     os << " = type " << it.first << newLine;
740   }
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // SSANameState
745 //===----------------------------------------------------------------------===//
746 
747 namespace {
/// This class manages the state of SSA value names: the numeric IDs or custom
/// names assigned to values, the IDs assigned to blocks, and the result groups
/// of operations.
class SSANameState {
public:
  /// A sentinel value used for values with names set.
  enum : unsigned { NameSentinel = ~0U };

  /// Number the values and blocks nested within `op`, using `interfaces`
  /// (presumably to obtain custom value names — confirm in the definition).
  SSANameState(Operation *op,
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the ID for the given block.
  unsigned getBlockID(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(
      Region &region,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
  void numberValuesInBlock(
      Block &block,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
  void numberValuesInOp(
      Operation &op,
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  /// The custom names of values whose ID is NameSentinel.
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// The ID assigned to each numbered block.
  DenseMap<Block *, unsigned> blockIDs;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Storage for the non-numeric names tracked in `usedNames`.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;
};
826 } // end anonymous namespace
827 
828 SSANameState::SSANameState(
829     Operation *op,
830     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
831   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
832   numberValuesInOp(*op, interfaces);
833 
834   for (auto &region : op->getRegions())
835     numberValuesInRegion(region, interfaces);
836 }
837 
838 void SSANameState::printValueID(Value value, bool printResultNo,
839                                 raw_ostream &stream) const {
840   if (!value) {
841     stream << "<<NULL>>";
842     return;
843   }
844 
845   Optional<int> resultNo;
846   auto lookupValue = value;
847 
848   // If this is an operation result, collect the head lookup value of the result
849   // group and the result number of 'result' within that group.
850   if (OpResult result = value.dyn_cast<OpResult>())
851     getResultIDAndNumber(result, lookupValue, resultNo);
852 
853   auto it = valueIDs.find(lookupValue);
854   if (it == valueIDs.end()) {
855     stream << "<<UNKNOWN SSA VALUE>>";
856     return;
857   }
858 
859   stream << '%';
860   if (it->second != NameSentinel) {
861     stream << it->second;
862   } else {
863     auto nameIt = valueNames.find(lookupValue);
864     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
865     stream << nameIt->second;
866   }
867 
868   if (resultNo.hasValue() && printResultNo)
869     stream << '#' << resultNo;
870 }
871 
872 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
873   auto it = opResultGroups.find(op);
874   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
875 }
876 
877 unsigned SSANameState::getBlockID(Block *block) {
878   auto it = blockIDs.find(block);
879   return it != blockIDs.end() ? it->second : NameSentinel;
880 }
881 
882 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
883   assert(!region.empty() && "cannot shadow arguments of an empty region");
884   assert(region.getNumArguments() == namesToUse.size() &&
885          "incorrect number of names passed in");
886   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
887          "only KnownIsolatedFromAbove ops can shadow names");
888 
889   SmallVector<char, 16> nameStr;
890   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
891     auto nameToUse = namesToUse[i];
892     if (nameToUse == nullptr)
893       continue;
894     auto nameToReplace = region.getArgument(i);
895 
896     nameStr.clear();
897     llvm::raw_svector_ostream nameStream(nameStr);
898     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
899 
900     // Entry block arguments should already have a pretty "arg" name.
901     assert(valueIDs[nameToReplace] == NameSentinel);
902 
903     // Use the name without the leading %.
904     auto name = StringRef(nameStream.str()).drop_front();
905 
906     // Overwrite the name.
907     valueNames[nameToReplace] = name.copy(usedNameAllocator);
908   }
909 }
910 
911 void SSANameState::numberValuesInRegion(
912     Region &region,
913     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
914   // Save the current value ids to allow for numbering values in sibling regions
915   // the same.
916   llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
917   llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
918   llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
919 
920   // Push a new used names scope.
921   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
922 
923   // Number the values within this region in a breadth-first order.
924   unsigned nextBlockID = 0;
925   for (auto &block : region) {
926     // Each block gets a unique ID, and all of the operations within it get
927     // numbered as well.
928     blockIDs[&block] = nextBlockID++;
929     numberValuesInBlock(block, interfaces);
930   }
931 
932   // After that we traverse the nested regions.
933   // TODO: Rework this loop to not use recursion.
934   for (auto &block : region) {
935     for (auto &op : block)
936       for (auto &nestedRegion : op.getRegions())
937         numberValuesInRegion(nestedRegion, interfaces);
938   }
939 }
940 
941 void SSANameState::numberValuesInBlock(
942     Block &block,
943     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
944   auto setArgNameFn = [&](Value arg, StringRef name) {
945     assert(!valueIDs.count(arg) && "arg numbered multiple times");
946     assert(arg.cast<BlockArgument>().getOwner() == &block &&
947            "arg not defined in 'block'");
948     setValueName(arg, name);
949   };
950 
951   bool isEntryBlock = block.isEntryBlock();
952   if (isEntryBlock) {
953     if (auto *op = block.getParentOp()) {
954       if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
955         asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
956     }
957   }
958 
959   // Number the block arguments. We give entry block arguments a special name
960   // 'arg'.
961   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
962   llvm::raw_svector_ostream specialName(specialNameBuffer);
963   for (auto arg : block.getArguments()) {
964     if (valueIDs.count(arg))
965       continue;
966     if (isEntryBlock) {
967       specialNameBuffer.resize(strlen("arg"));
968       specialName << nextArgumentID++;
969     }
970     setValueName(arg, specialName.str());
971   }
972 
973   // Number the operations in this block.
974   for (auto &op : block)
975     numberValuesInOp(op, interfaces);
976 }
977 
978 void SSANameState::numberValuesInOp(
979     Operation &op,
980     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
981   unsigned numResults = op.getNumResults();
982   if (numResults == 0)
983     return;
984   Value resultBegin = op.getResult(0);
985 
986   // Function used to set the special result names for the operation.
987   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
988   auto setResultNameFn = [&](Value result, StringRef name) {
989     assert(!valueIDs.count(result) && "result numbered multiple times");
990     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
991     setValueName(result, name);
992 
993     // Record the result number for groups not anchored at 0.
994     if (int resultNo = result.cast<OpResult>().getResultNumber())
995       resultGroups.push_back(resultNo);
996   };
997   if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
998     asmInterface.getAsmResultNames(setResultNameFn);
999   else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
1000     asmInterface->getAsmResultNames(&op, setResultNameFn);
1001 
1002   // If the first result wasn't numbered, give it a default number.
1003   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1004     ++nextValueID;
1005 
1006   // If this operation has multiple result groups, mark it.
1007   if (resultGroups.size() != 1) {
1008     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1009     opResultGroups.try_emplace(&op, std::move(resultGroups));
1010   }
1011 }
1012 
1013 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1014                                         Optional<int> &lookupResultNo) const {
1015   Operation *owner = result.getOwner();
1016   if (owner->getNumResults() == 1)
1017     return;
1018   int resultNo = result.getResultNumber();
1019 
1020   // If this operation has multiple result groups, we will need to find the
1021   // one corresponding to this result.
1022   auto resultGroupIt = opResultGroups.find(owner);
1023   if (resultGroupIt == opResultGroups.end()) {
1024     // If not, just use the first result.
1025     lookupResultNo = resultNo;
1026     lookupValue = owner->getResult(0);
1027     return;
1028   }
1029 
1030   // Find the correct index using a binary search, as the groups are ordered.
1031   ArrayRef<int> resultGroups = resultGroupIt->second;
1032   auto it = llvm::upper_bound(resultGroups, resultNo);
1033   int groupResultNo = 0, groupSize = 0;
1034 
1035   // If there are no smaller elements, the last result group is the lookup.
1036   if (it == resultGroups.end()) {
1037     groupResultNo = resultGroups.back();
1038     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1039   } else {
1040     // Otherwise, the previous element is the lookup.
1041     groupResultNo = *std::prev(it);
1042     groupSize = *it - groupResultNo;
1043   }
1044 
1045   // We only record the result number for a group of size greater than 1.
1046   if (groupSize != 1)
1047     lookupResultNo = resultNo - groupResultNo;
1048   lookupValue = owner->getResult(groupResultNo);
1049 }
1050 
1051 void SSANameState::setValueName(Value value, StringRef name) {
1052   // If the name is empty, the value uses the default numbering.
1053   if (name.empty()) {
1054     valueIDs[value] = nextValueID++;
1055     return;
1056   }
1057 
1058   valueIDs[value] = NameSentinel;
1059   valueNames[value] = uniqueValueName(name);
1060 }
1061 
1062 StringRef SSANameState::uniqueValueName(StringRef name) {
1063   SmallString<16> tmpBuffer;
1064   name = sanitizeIdentifier(name, tmpBuffer);
1065 
1066   // Check to see if this name is already unique.
1067   if (!usedNames.count(name)) {
1068     name = name.copy(usedNameAllocator);
1069   } else {
1070     // Otherwise, we had a conflict - probe until we find a unique name. This
1071     // is guaranteed to terminate (and usually in a single iteration) because it
1072     // generates new names by incrementing nextConflictID.
1073     SmallString<64> probeName(name);
1074     probeName.push_back('_');
1075     while (true) {
1076       probeName += llvm::utostr(nextConflictID++);
1077       if (!usedNames.count(probeName)) {
1078         name = StringRef(probeName).copy(usedNameAllocator);
1079         break;
1080       }
1081       probeName.resize(name.size() + 1);
1082     }
1083   }
1084 
1085   usedNames.insert(name, char());
1086   return name;
1087 }
1088 
1089 //===----------------------------------------------------------------------===//
1090 // AsmState
1091 //===----------------------------------------------------------------------===//
1092 
namespace mlir {
namespace detail {
/// Printer state shared across a printing session: the OpAsm dialect
/// interfaces, the alias state, the SSA value/block names, and an optional
/// location map.
class AsmStateImpl {
public:
  /// Note: `nameState` is constructed from `interfaces`, so `interfaces` must
  /// remain declared before `nameState` in the member list below. Building
  /// `nameState` numbers the values and blocks nested under `op`.
  explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, interfaces),
        locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This is a no-op when no location map was
  /// provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  /// Must be declared before `nameState` (see the constructor).
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// An optional location map to be populated.
  AsmState::LocationMap *locationMap;
};
} // end namespace detail
} // end namespace mlir
1140 
1141 AsmState::AsmState(Operation *op, LocationMap *locationMap)
1142     : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}
1143 AsmState::~AsmState() {}
1144 
1145 //===----------------------------------------------------------------------===//
1146 // ModulePrinter
1147 //===----------------------------------------------------------------------===//
1148 
namespace {
/// Prints attributes, types, locations, affine maps and integer sets to a
/// raw_ostream, consulting an optional AsmStateImpl for aliases.
class ModulePrinter {
public:
  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
                AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing the stream, flags and state of 'printer'.
  /// Note: 'newLine' is not copied; this printer gets a default-initialized
  /// counter of its own.
  explicit ModulePrinter(ModulePrinter &printer)
      : os(printer.os), printerFlags(printer.printerFlags),
        state(printer.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of 'c' to the stream separated by ", ", using
  /// 'each_fn' to print each element.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
    llvm::interleaveComma(c, os, each_fn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided.
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the given type.
  void printType(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  /// Print " loc(...)" after an entity, but only when debug info is enabled.
  void printTrailingLocation(Location loc);
  /// Print the body of a location; 'pretty' selects the human-readable form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense int or float elements attribute. If 'allowHex' is true, a
  /// hex string is used instead of individual elements when the elements attr
  /// is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module.
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // end anonymous namespace
1243 
1244 void ModulePrinter::printTrailingLocation(Location loc) {
1245   // Check to see if we are printing debug information.
1246   if (!printerFlags.shouldPrintDebugInfo())
1247     return;
1248 
1249   os << " ";
1250   printLocation(loc, /*allowAlias=*/true);
1251 }
1252 
1253 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
1254   TypeSwitch<LocationAttr>(loc)
1255       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1256         printLocationInternal(loc.getFallbackLocation(), pretty);
1257       })
1258       .Case<UnknownLoc>([&](UnknownLoc loc) {
1259         if (pretty)
1260           os << "[unknown]";
1261         else
1262           os << "unknown";
1263       })
1264       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1265         if (pretty) {
1266           os << loc.getFilename();
1267         } else {
1268           os << "\"";
1269           printEscapedString(loc.getFilename(), os);
1270           os << "\"";
1271         }
1272         os << ':' << loc.getLine() << ':' << loc.getColumn();
1273       })
1274       .Case<NameLoc>([&](NameLoc loc) {
1275         os << '\"';
1276         printEscapedString(loc.getName(), os);
1277         os << '\"';
1278 
1279         // Print the child if it isn't unknown.
1280         auto childLoc = loc.getChildLoc();
1281         if (!childLoc.isa<UnknownLoc>()) {
1282           os << '(';
1283           printLocationInternal(childLoc, pretty);
1284           os << ')';
1285         }
1286       })
1287       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1288         Location caller = loc.getCaller();
1289         Location callee = loc.getCallee();
1290         if (!pretty)
1291           os << "callsite(";
1292         printLocationInternal(callee, pretty);
1293         if (pretty) {
1294           if (callee.isa<NameLoc>()) {
1295             if (caller.isa<FileLineColLoc>()) {
1296               os << " at ";
1297             } else {
1298               os << newLine << " at ";
1299             }
1300           } else {
1301             os << newLine << " at ";
1302           }
1303         } else {
1304           os << " at ";
1305         }
1306         printLocationInternal(caller, pretty);
1307         if (!pretty)
1308           os << ")";
1309       })
1310       .Case<FusedLoc>([&](FusedLoc loc) {
1311         if (!pretty)
1312           os << "fused";
1313         if (Attribute metadata = loc.getMetadata())
1314           os << '<' << metadata << '>';
1315         os << '[';
1316         interleave(
1317             loc.getLocations(),
1318             [&](Location loc) { printLocationInternal(loc, pretty); },
1319             [&]() { os << ", "; });
1320         os << ']';
1321       });
1322 }
1323 
1324 /// Print a floating point value in a way that the parser will be able to
1325 /// round-trip losslessly.
1326 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1327   // We would like to output the FP constant value in exponential notation,
1328   // but we cannot do this if doing so will lose precision.  Check here to
1329   // make sure that we only output it in exponential format if we can parse
1330   // the value back and get the same value.
1331   bool isInf = apValue.isInfinity();
1332   bool isNaN = apValue.isNaN();
1333   if (!isInf && !isNaN) {
1334     SmallString<128> strValue;
1335     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1336                      /*TruncateZero=*/false);
1337 
1338     // Check to make sure that the stringized number is not some string like
1339     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1340     // that the string matches the "[-+]?[0-9]" regex.
1341     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1342             ((strValue[0] == '-' || strValue[0] == '+') &&
1343              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1344            "[-+]?[0-9] regex does not match!");
1345 
1346     // Parse back the stringized version and check that the value is equal
1347     // (i.e., there is no precision loss).
1348     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1349       os << strValue;
1350       return;
1351     }
1352 
1353     // If it is not, use the default format of APFloat instead of the
1354     // exponential notation.
1355     strValue.clear();
1356     apValue.toString(strValue);
1357 
1358     // Make sure that we can parse the default form as a float.
1359     if (StringRef(strValue).contains('.')) {
1360       os << strValue;
1361       return;
1362     }
1363   }
1364 
1365   // Print special values in hexadecimal format. The sign bit should be included
1366   // in the literal.
1367   SmallVector<char, 16> str;
1368   APInt apInt = apValue.bitcastToAPInt();
1369   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1370                  /*formatAsCLiteral=*/true);
1371   os << str;
1372 }
1373 
1374 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
1375   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1376     return printLocationInternal(loc, /*pretty=*/true);
1377 
1378   os << "loc(";
1379   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1380     printLocationInternal(loc);
1381   os << ')';
1382 }
1383 
1384 /// Returns true if the given dialect symbol data is simple enough to print in
1385 /// the pretty form, i.e. without the enclosing "".
1386 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1387   // The name must start with an identifier.
1388   if (symName.empty() || !isalpha(symName.front()))
1389     return false;
1390 
1391   // Ignore all the characters that are valid in an identifier in the symbol
1392   // name.
1393   symName = symName.drop_while(
1394       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1395   if (symName.empty())
1396     return true;
1397 
1398   // If we got to an unexpected character, then it must be a <>.  Check those
1399   // recursively.
1400   if (symName.front() != '<' || symName.back() != '>')
1401     return false;
1402 
1403   SmallVector<char, 8> nestedPunctuation;
1404   do {
1405     // If we ran out of characters, then we had a punctuation mismatch.
1406     if (symName.empty())
1407       return false;
1408 
1409     auto c = symName.front();
1410     symName = symName.drop_front();
1411 
1412     switch (c) {
1413     // We never allow null characters. This is an EOF indicator for the lexer
1414     // which we could handle, but isn't important for any known dialect.
1415     case '\0':
1416       return false;
1417     case '<':
1418     case '[':
1419     case '(':
1420     case '{':
1421       nestedPunctuation.push_back(c);
1422       continue;
1423     case '-':
1424       // Treat `->` as a special token.
1425       if (!symName.empty() && symName.front() == '>') {
1426         symName = symName.drop_front();
1427         continue;
1428       }
1429       break;
1430     // Reject types with mismatched brackets.
1431     case '>':
1432       if (nestedPunctuation.pop_back_val() != '<')
1433         return false;
1434       break;
1435     case ']':
1436       if (nestedPunctuation.pop_back_val() != '[')
1437         return false;
1438       break;
1439     case ')':
1440       if (nestedPunctuation.pop_back_val() != '(')
1441         return false;
1442       break;
1443     case '}':
1444       if (nestedPunctuation.pop_back_val() != '{')
1445         return false;
1446       break;
1447     default:
1448       continue;
1449     }
1450 
1451     // We're done when the punctuation is fully matched.
1452   } while (!nestedPunctuation.empty());
1453 
1454   // If there were extra characters, then we failed.
1455   return symName.empty();
1456 }
1457 
1458 /// Print the given dialect symbol to the stream.
1459 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1460                                StringRef dialectName, StringRef symString) {
1461   os << symPrefix << dialectName;
1462 
1463   // If this symbol name is simple enough, print it directly in pretty form,
1464   // otherwise, we print it as an escaped string.
1465   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1466     os << '.' << symString;
1467     return;
1468   }
1469 
1470   // TODO: escape the symbol name, it could contain " characters.
1471   os << "<\"" << symString << "\">";
1472 }
1473 
1474 /// Returns true if the given string can be represented as a bare identifier.
1475 static bool isBareIdentifier(StringRef name) {
1476   assert(!name.empty() && "invalid name");
1477 
1478   // By making this unsigned, the value passed in to isalnum will always be
1479   // in the range 0-255. This is important when building with MSVC because
1480   // its implementation will assert. This situation can arise when dealing
1481   // with UTF-8 multibyte characters.
1482   unsigned char firstChar = static_cast<unsigned char>(name[0]);
1483   if (!isalpha(firstChar) && firstChar != '_')
1484     return false;
1485   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1486     return isalnum(c) || c == '_' || c == '$' || c == '.';
1487   });
1488 }
1489 
1490 /// Print the given string as a symbol reference. A symbol reference is
1491 /// represented as a string prefixed with '@'. The reference is surrounded with
1492 /// ""'s and escaped if it has any special or non-printable characters in it.
1493 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1494   assert(!symbolRef.empty() && "expected valid symbol reference");
1495 
1496   // If the symbol can be represented as a bare identifier, write it directly.
1497   if (isBareIdentifier(symbolRef)) {
1498     os << '@' << symbolRef;
1499     return;
1500   }
1501 
1502   // Otherwise, output the reference wrapped in quotes with proper escaping.
1503   os << "@\"";
1504   printEscapedString(symbolRef, os);
1505   os << '"';
1506 }
1507 
1508 // Print out a valid ElementsAttr that is succinct and can represent any
1509 // potential shape/type, for use when eliding a large ElementsAttr.
1510 //
1511 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1512 // hopefully alert readers to the fact that this has been elided.
1513 //
1514 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1515 // accept the string "elided". The first string must be a registered dialect
1516 // name and the latter must be a hex constant.
1517 static void printElidedElementsAttr(raw_ostream &os) {
1518   os << R"(opaque<"", "0xDEADBEEF">)";
1519 }
1520 
1521 void ModulePrinter::printAttribute(Attribute attr,
1522                                    AttrTypeElision typeElision) {
1523   if (!attr) {
1524     os << "<<NULL ATTRIBUTE>>";
1525     return;
1526   }
1527 
1528   // Try to print an alias for this attribute.
1529   if (state && succeeded(state->getAliasState().getAlias(attr, os)))
1530     return;
1531 
1532   auto attrType = attr.getType();
1533   if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1534     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1535                        opaqueAttr.getAttrData());
1536   } else if (attr.isa<UnitAttr>()) {
1537     os << "unit";
1538     return;
1539   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1540     os << '{';
1541     interleaveComma(dictAttr.getValue(),
1542                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1543     os << '}';
1544 
1545   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1546     if (attrType.isSignlessInteger(1)) {
1547       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1548 
1549       // Boolean integer attributes always elides the type.
1550       return;
1551     }
1552 
1553     // Only print attributes as unsigned if they are explicitly unsigned or are
1554     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1555     // values print as signed.
1556     bool isUnsigned =
1557         attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1558     intAttr.getValue().print(os, !isUnsigned);
1559 
1560     // IntegerAttr elides the type if I64.
1561     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1562       return;
1563 
1564   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1565     printFloatValue(floatAttr.getValue(), os);
1566 
1567     // FloatAttr elides the type if F64.
1568     if (typeElision == AttrTypeElision::May && attrType.isF64())
1569       return;
1570 
1571   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1572     os << '"';
1573     printEscapedString(strAttr.getValue(), os);
1574     os << '"';
1575 
1576   } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1577     os << '[';
1578     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1579       printAttribute(attr, AttrTypeElision::May);
1580     });
1581     os << ']';
1582 
1583   } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1584     os << "affine_map<";
1585     affineMapAttr.getValue().print(os);
1586     os << '>';
1587 
1588     // AffineMap always elides the type.
1589     return;
1590 
1591   } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1592     os << "affine_set<";
1593     integerSetAttr.getValue().print(os);
1594     os << '>';
1595 
1596     // IntegerSet always elides the type.
1597     return;
1598 
1599   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1600     printType(typeAttr.getValue());
1601 
1602   } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1603     printSymbolReference(refAttr.getRootReference(), os);
1604     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1605       os << "::";
1606       printSymbolReference(nestedRef.getValue(), os);
1607     }
1608 
1609   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1610     if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1611       printElidedElementsAttr(os);
1612     } else {
1613       os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
1614       os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
1615     }
1616 
1617   } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1618     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1619       printElidedElementsAttr(os);
1620     } else {
1621       os << "dense<";
1622       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1623       os << '>';
1624     }
1625 
1626   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1627     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1628       printElidedElementsAttr(os);
1629     } else {
1630       os << "dense<";
1631       printDenseStringElementsAttr(strEltAttr);
1632       os << '>';
1633     }
1634 
1635   } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1636     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1637         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1638       printElidedElementsAttr(os);
1639     } else {
1640       os << "sparse<";
1641       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1642       if (indices.getNumElements() != 0) {
1643         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1644         os << ", ";
1645         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1646       }
1647       os << '>';
1648     }
1649 
1650   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1651     printLocation(locAttr);
1652 
1653   } else {
1654     return printDialectAttribute(attr);
1655   }
1656 
1657   // Don't print the type if we must elide it, or if it is a None type.
1658   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1659     os << " : ";
1660     printType(attrType);
1661   }
1662 }
1663 
1664 /// Print the integer element of a DenseElementsAttr.
1665 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1666                                  bool isSigned) {
1667   if (value.getBitWidth() == 1)
1668     os << (value.getBoolValue() ? "true" : "false");
1669   else
1670     value.print(os, isSigned);
1671 }
1672 
/// Print the elements of a dense elements attribute as nested, bracketed,
/// comma-separated lists that follow the shape of `type`. Each element is
/// emitted by calling `printEltFn` with its flat (row-major) index.
///
/// - A splat attribute prints only element 0, with no brackets.
/// - An attribute whose type has zero elements prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  // Advance the counter by one element. Each digit that wraps around closes
  // one bracket. Digit 0 is never checked for roll-over because the outer
  // loop stops after `numElements` elements.
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed since the previous element, so that exactly
    // `rank` brackets are open before the element is printed.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets are still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
1722 
1723 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1724                                            bool allowHex) {
1725   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1726     return printDenseStringElementsAttr(stringAttr);
1727 
1728   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1729                                 allowHex);
1730 }
1731 
1732 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1733                                                   bool allowHex) {
1734   auto type = attr.getType();
1735   auto elementType = type.getElementType();
1736 
1737   // Check to see if we should format this attribute as a hex string.
1738   auto numElements = type.getNumElements();
1739   if (!attr.isSplat() && allowHex &&
1740       shouldPrintElementsAttrWithHex(numElements)) {
1741     ArrayRef<char> rawData = attr.getRawData();
1742     if (llvm::support::endian::system_endianness() ==
1743         llvm::support::endianness::big) {
1744       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1745       // machines. It is converted here to print in LE format.
1746       SmallVector<char, 64> outDataVec(rawData.size());
1747       MutableArrayRef<char> convRawData(outDataVec);
1748       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1749           rawData, convRawData, type);
1750       os << '"' << "0x"
1751          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1752          << "\"";
1753     } else {
1754       os << '"' << "0x"
1755          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1756     }
1757 
1758     return;
1759   }
1760 
1761   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1762     Type complexElementType = complexTy.getElementType();
1763     // Note: The if and else below had a common lambda function which invoked
1764     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1765     // and hence was replaced.
1766     if (complexElementType.isa<IntegerType>()) {
1767       bool isSigned = !complexElementType.isUnsignedInteger();
1768       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1769         auto complexValue = *(attr.getComplexIntValues().begin() + index);
1770         os << "(";
1771         printDenseIntElement(complexValue.real(), os, isSigned);
1772         os << ",";
1773         printDenseIntElement(complexValue.imag(), os, isSigned);
1774         os << ")";
1775       });
1776     } else {
1777       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1778         auto complexValue = *(attr.getComplexFloatValues().begin() + index);
1779         os << "(";
1780         printFloatValue(complexValue.real(), os);
1781         os << ",";
1782         printFloatValue(complexValue.imag(), os);
1783         os << ")";
1784       });
1785     }
1786   } else if (elementType.isIntOrIndex()) {
1787     bool isSigned = !elementType.isUnsignedInteger();
1788     auto intValues = attr.getIntValues();
1789     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1790       printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1791     });
1792   } else {
1793     assert(elementType.isa<FloatType>() && "unexpected element type");
1794     auto floatValues = attr.getFloatValues();
1795     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1796       printFloatValue(*(floatValues.begin() + index), os);
1797     });
1798   }
1799 }
1800 
1801 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1802   ArrayRef<StringRef> data = attr.getRawStringData();
1803   auto printFn = [&](unsigned index) {
1804     os << "\"";
1805     printEscapedString(data[index], os);
1806     os << "\"";
1807   };
1808   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1809 }
1810 
1811 void ModulePrinter::printType(Type type) {
1812   if (!type) {
1813     os << "<<NULL TYPE>>";
1814     return;
1815   }
1816 
1817   // Try to print an alias for this type.
1818   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1819     return;
1820 
1821   TypeSwitch<Type>(type)
1822       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1823         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1824                            opaqueTy.getTypeData());
1825       })
1826       .Case<IndexType>([&](Type) { os << "index"; })
1827       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1828       .Case<Float16Type>([&](Type) { os << "f16"; })
1829       .Case<Float32Type>([&](Type) { os << "f32"; })
1830       .Case<Float64Type>([&](Type) { os << "f64"; })
1831       .Case<Float80Type>([&](Type) { os << "f80"; })
1832       .Case<Float128Type>([&](Type) { os << "f128"; })
1833       .Case<IntegerType>([&](IntegerType integerTy) {
1834         if (integerTy.isSigned())
1835           os << 's';
1836         else if (integerTy.isUnsigned())
1837           os << 'u';
1838         os << 'i' << integerTy.getWidth();
1839       })
1840       .Case<FunctionType>([&](FunctionType funcTy) {
1841         os << '(';
1842         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1843         os << ") -> ";
1844         ArrayRef<Type> results = funcTy.getResults();
1845         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1846           os << results[0];
1847         } else {
1848           os << '(';
1849           interleaveComma(results, [&](Type ty) { printType(ty); });
1850           os << ')';
1851         }
1852       })
1853       .Case<VectorType>([&](VectorType vectorTy) {
1854         os << "vector<";
1855         for (int64_t dim : vectorTy.getShape())
1856           os << dim << 'x';
1857         os << vectorTy.getElementType() << '>';
1858       })
1859       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1860         os << "tensor<";
1861         for (int64_t dim : tensorTy.getShape()) {
1862           if (ShapedType::isDynamic(dim))
1863             os << '?';
1864           else
1865             os << dim;
1866           os << 'x';
1867         }
1868         os << tensorTy.getElementType() << '>';
1869       })
1870       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1871         os << "tensor<*x";
1872         printType(tensorTy.getElementType());
1873         os << '>';
1874       })
1875       .Case<MemRefType>([&](MemRefType memrefTy) {
1876         os << "memref<";
1877         for (int64_t dim : memrefTy.getShape()) {
1878           if (ShapedType::isDynamic(dim))
1879             os << '?';
1880           else
1881             os << dim;
1882           os << 'x';
1883         }
1884         printType(memrefTy.getElementType());
1885         for (auto map : memrefTy.getAffineMaps()) {
1886           os << ", ";
1887           printAttribute(AffineMapAttr::get(map));
1888         }
1889         // Only print the memory space if it is the non-default one.
1890         if (memrefTy.getMemorySpaceAsInt())
1891           os << ", " << memrefTy.getMemorySpaceAsInt();
1892         os << '>';
1893       })
1894       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1895         os << "memref<*x";
1896         printType(memrefTy.getElementType());
1897         // Only print the memory space if it is the non-default one.
1898         if (memrefTy.getMemorySpaceAsInt())
1899           os << ", " << memrefTy.getMemorySpaceAsInt();
1900         os << '>';
1901       })
1902       .Case<ComplexType>([&](ComplexType complexTy) {
1903         os << "complex<";
1904         printType(complexTy.getElementType());
1905         os << '>';
1906       })
1907       .Case<TupleType>([&](TupleType tupleTy) {
1908         os << "tuple<";
1909         interleaveComma(tupleTy.getTypes(),
1910                         [&](Type type) { printType(type); });
1911         os << '>';
1912       })
1913       .Case<NoneType>([&](Type) { os << "none"; })
1914       .Default([&](Type type) { return printDialectType(type); });
1915 }
1916 
1917 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1918                                           ArrayRef<StringRef> elidedAttrs,
1919                                           bool withKeyword) {
1920   // If there are no attributes, then there is nothing to be done.
1921   if (attrs.empty())
1922     return;
1923 
1924   // Functor used to print a filtered attribute list.
1925   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
1926     // Print the 'attributes' keyword if necessary.
1927     if (withKeyword)
1928       os << " attributes";
1929 
1930     // Otherwise, print them all out in braces.
1931     os << " {";
1932     interleaveComma(filteredAttrs,
1933                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1934     os << '}';
1935   };
1936 
1937   // If no attributes are elided, we can directly print with no filtering.
1938   if (elidedAttrs.empty())
1939     return printFilteredAttributesFn(attrs);
1940 
1941   // Otherwise, filter out any attributes that shouldn't be included.
1942   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
1943                                                 elidedAttrs.end());
1944   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1945     return !elidedAttrsSet.contains(attr.first.strref());
1946   });
1947   if (!filteredAttrs.empty())
1948     printFilteredAttributesFn(filteredAttrs);
1949 }
1950 
1951 void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
1952   if (isBareIdentifier(attr.first)) {
1953     os << attr.first;
1954   } else {
1955     os << '"';
1956     printEscapedString(attr.first.strref(), os);
1957     os << '"';
1958   }
1959 
1960   // Pretty printing elides the attribute value for unit attributes.
1961   if (attr.second.isa<UnitAttr>())
1962     return;
1963 
1964   os << " = ";
1965   printAttribute(attr.second);
1966 }
1967 
1968 //===----------------------------------------------------------------------===//
1969 // CustomDialectAsmPrinter
1970 //===----------------------------------------------------------------------===//
1971 
1972 namespace {
1973 /// This class provides the main specialization of the DialectAsmPrinter that is
1974 /// used to provide support for print attributes and types. This hooks allows
1975 /// for dialects to hook into the main ModulePrinter.
1976 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1977 public:
1978   CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
1979   ~CustomDialectAsmPrinter() override {}
1980 
1981   raw_ostream &getStream() const override { return printer.getStream(); }
1982 
1983   /// Print the given attribute to the stream.
1984   void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1985 
1986   /// Print the given floating point value in a stablized form.
1987   void printFloat(const APFloat &value) override {
1988     printFloatValue(value, getStream());
1989   }
1990 
1991   /// Print the given type to the stream.
1992   void printType(Type type) override { printer.printType(type); }
1993 
1994   /// The main module printer.
1995   ModulePrinter &printer;
1996 };
1997 } // end anonymous namespace
1998 
1999 void ModulePrinter::printDialectAttribute(Attribute attr) {
2000   auto &dialect = attr.getDialect();
2001 
2002   // Ask the dialect to serialize the attribute to a string.
2003   std::string attrName;
2004   {
2005     llvm::raw_string_ostream attrNameStr(attrName);
2006     ModulePrinter subPrinter(attrNameStr, printerFlags, state);
2007     CustomDialectAsmPrinter printer(subPrinter);
2008     dialect.printAttribute(attr, printer);
2009   }
2010   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2011 }
2012 
2013 void ModulePrinter::printDialectType(Type type) {
2014   auto &dialect = type.getDialect();
2015 
2016   // Ask the dialect to serialize the type to a string.
2017   std::string typeName;
2018   {
2019     llvm::raw_string_ostream typeNameStr(typeName);
2020     ModulePrinter subPrinter(typeNameStr, printerFlags, state);
2021     CustomDialectAsmPrinter printer(subPrinter);
2022     dialect.printType(type, printer);
2023   }
2024   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2025 }
2026 
2027 //===----------------------------------------------------------------------===//
2028 // Affine expressions and maps
2029 //===----------------------------------------------------------------------===//
2030 
2031 void ModulePrinter::printAffineExpr(
2032     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2033   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2034 }
2035 
/// Recursively print `expr`.
///
/// `enclosingTightness` says whether the parent operator binds tightly
/// (Strong). In that case any binary expression printed here is wrapped in
/// parentheses. If `printValueName` is provided, it prints dimension and symbol
/// identifiers (called with the position and an isSymbol flag). Otherwise they
/// are printed as `d<pos>` / `s<pos>`.
void ModulePrinter::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  // Leaves (symbols, dims, constants) print and return immediately. Binary
  // kinds only select their spelling here and fall through to the code below.
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  // (Every non-Add kind: mul, floordiv, ceildiv, mod.) Both operands are
  // printed with Strong tightness, so nested binary operands get parentheses.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    // `x * -1` is printed as `-x`.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  // Only a constant right operand of the product is considered:
  //   `a + b * -1` prints as `a - b`, and `a + b * -c` (c > 1) as `a - b * c`.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A negated sum must be parenthesized: `a - (b + c)`.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  // `a + -c` prints as `a - c`.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition: both operands use Weak tightness, so they are not
  // parenthesized.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2168 
2169 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
2170   printAffineExprInternal(expr, BindingStrength::Weak);
2171   isEq ? os << " == 0" : os << " >= 0";
2172 }
2173 
2174 void ModulePrinter::printAffineMap(AffineMap map) {
2175   // Dimension identifiers.
2176   os << '(';
2177   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2178     os << 'd' << i << ", ";
2179   if (map.getNumDims() >= 1)
2180     os << 'd' << map.getNumDims() - 1;
2181   os << ')';
2182 
2183   // Symbolic identifiers.
2184   if (map.getNumSymbols() != 0) {
2185     os << '[';
2186     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2187       os << 's' << i << ", ";
2188     if (map.getNumSymbols() >= 1)
2189       os << 's' << map.getNumSymbols() - 1;
2190     os << ']';
2191   }
2192 
2193   // Result affine expressions.
2194   os << " -> (";
2195   interleaveComma(map.getResults(),
2196                   [&](AffineExpr expr) { printAffineExpr(expr); });
2197   os << ')';
2198 }
2199 
2200 void ModulePrinter::printIntegerSet(IntegerSet set) {
2201   // Dimension identifiers.
2202   os << '(';
2203   for (unsigned i = 1; i < set.getNumDims(); ++i)
2204     os << 'd' << i - 1 << ", ";
2205   if (set.getNumDims() >= 1)
2206     os << 'd' << set.getNumDims() - 1;
2207   os << ')';
2208 
2209   // Symbolic identifiers.
2210   if (set.getNumSymbols() != 0) {
2211     os << '[';
2212     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2213       os << 's' << i << ", ";
2214     if (set.getNumSymbols() >= 1)
2215       os << 's' << set.getNumSymbols() - 1;
2216     os << ']';
2217   }
2218 
2219   // Print constraints.
2220   os << " : (";
2221   int numConstraints = set.getNumConstraints();
2222   for (int i = 1; i < numConstraints; ++i) {
2223     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2224     os << ", ";
2225   }
2226   if (numConstraints >= 1)
2227     printAffineConstraint(set.getConstraint(numConstraints - 1),
2228                           set.isEq(numConstraints - 1));
2229   os << ')';
2230 }
2231 
2232 //===----------------------------------------------------------------------===//
2233 // OperationPrinter
2234 //===----------------------------------------------------------------------===//
2235 
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
///
/// It extends ModulePrinter, which provides attribute/type/affine printing. It
/// also implements OpAsmPrinter privately, so that custom operation printers
/// (invoked via `printAssembly(op, *this)`) print through this object.
/// Construction requires an AsmStateImpl, which supplies SSA value/block names
/// and alias information.
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
public:
  explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
                            AsmStateImpl &state)
      : ModulePrinter(os, flags, &state) {}

  /// Print the given top-level operation.
  /// Non-deferred aliases are printed before it and deferred aliases after it.
  void printTopLevelOperation(Operation *op);

  /// Print the given operation with its indent and location.
  void print(Operation *op);
  /// Print the bare operation, not including indentation/location/etc.
  void printOperation(Operation *op);
  /// Print the given operation in the generic form.
  void printGenericOp(Operation *op) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number.
  /// If `streamOverride` is non-null, it is used instead of the printer's own
  /// stream.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Return the current stream of the printer.
  raw_ostream &getStream() const override { return os; }

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }

  /// Print the given type.
  void printType(Type type) override { ModulePrinter::printType(type); }

  /// Print the given attribute.
  void printAttribute(Attribute attr) override {
    ModulePrinter::printAttribute(attr);
  }

  /// Print the given attribute without its type. The corresponding parser must
  /// provide a valid type for the attribute.
  void printAttributeWithoutType(Attribute attr) override {
    ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
  }

  /// Print the ID for the given value.
  void printOperand(Value value) override { printValueID(value); }
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as printOptionalAttrDict, but prefixes the dictionary with the
  /// `attributes` keyword.
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
                                         /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given string as a symbol reference.
  void printSymbolName(StringRef symbolRef) override {
    ::printSymbolReference(symbolRef, os);
  }

private:
  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  // It is used by print(Operation *) and printNewline().
  unsigned currentIndent = 0;
};
} // end anonymous namespace
2351 
2352 void OperationPrinter::printTopLevelOperation(Operation *op) {
2353   // Output the aliases at the top level that can't be deferred.
2354   state->getAliasState().printNonDeferredAliases(os, newLine);
2355 
2356   // Print the module.
2357   print(op);
2358   os << newLine;
2359 
2360   // Output the aliases at the top level that can be deferred.
2361   state->getAliasState().printDeferredAliases(os, newLine);
2362 }
2363 
2364 void OperationPrinter::print(Operation *op) {
2365   // Track the location of this operation.
2366   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2367 
2368   os.indent(currentIndent);
2369   printOperation(op);
2370   printTrailingLocation(op->getLoc());
2371 }
2372 
2373 void OperationPrinter::printOperation(Operation *op) {
2374   if (size_t numResults = op->getNumResults()) {
2375     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2376       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2377       if (resultCount > 1)
2378         os << ':' << resultCount;
2379     };
2380 
2381     // Check to see if this operation has multiple result groups.
2382     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2383     if (!resultGroups.empty()) {
2384       // Interleave the groups excluding the last one, this one will be handled
2385       // separately.
2386       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2387         printResultGroup(resultGroups[i],
2388                          resultGroups[i + 1] - resultGroups[i]);
2389       });
2390       os << ", ";
2391       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2392 
2393     } else {
2394       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2395     }
2396 
2397     os << " = ";
2398   }
2399 
2400   // If requested, always print the generic form.
2401   if (!printerFlags.shouldPrintGenericOpForm()) {
2402     // Check to see if this is a known operation.  If so, use the registered
2403     // custom printer hook.
2404     if (auto *opInfo = op->getAbstractOperation()) {
2405       opInfo->printAssembly(op, *this);
2406       return;
2407     }
2408   }
2409 
2410   // Otherwise print with the generic assembly form.
2411   printGenericOp(op);
2412 }
2413 
2414 void OperationPrinter::printGenericOp(Operation *op) {
2415   os << '"';
2416   printEscapedString(op->getName().getStringRef(), os);
2417   os << "\"(";
2418   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2419   os << ')';
2420 
2421   // For terminators, print the list of successors and their operands.
2422   if (op->getNumSuccessors() != 0) {
2423     os << '[';
2424     interleaveComma(op->getSuccessors(),
2425                     [&](Block *successor) { printBlockName(successor); });
2426     os << ']';
2427   }
2428 
2429   // Print regions.
2430   if (op->getNumRegions() != 0) {
2431     os << " (";
2432     interleaveComma(op->getRegions(), [&](Region &region) {
2433       printRegion(region, /*printEntryBlockArgs=*/true,
2434                   /*printBlockTerminators=*/true);
2435     });
2436     os << ')';
2437   }
2438 
2439   auto attrs = op->getAttrs();
2440   printOptionalAttrDict(attrs);
2441 
2442   // Print the type signature of the operation.
2443   os << " : ";
2444   printFunctionalType(op);
2445 }
2446 
2447 void OperationPrinter::printBlockName(Block *block) {
2448   auto id = state->getSSANameState().getBlockID(block);
2449   if (id != SSANameState::NameSentinel)
2450     os << "^bb" << id;
2451   else
2452     os << "^INVALIDBLOCK";
2453 }
2454 
2455 void OperationPrinter::print(Block *block, bool printBlockArgs,
2456                              bool printBlockTerminator) {
2457   // Print the block label and argument list if requested.
2458   if (printBlockArgs) {
2459     os.indent(currentIndent);
2460     printBlockName(block);
2461 
2462     // Print the argument list if non-empty.
2463     if (!block->args_empty()) {
2464       os << '(';
2465       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2466         printValueID(arg);
2467         os << ": ";
2468         printType(arg.getType());
2469       });
2470       os << ')';
2471     }
2472     os << ':';
2473 
2474     // Print out some context information about the predecessors of this block.
2475     if (!block->getParent()) {
2476       os << "  // block is not in a region!";
2477     } else if (block->hasNoPredecessors()) {
2478       os << "  // no predecessors";
2479     } else if (auto *pred = block->getSinglePredecessor()) {
2480       os << "  // pred: ";
2481       printBlockName(pred);
2482     } else {
2483       // We want to print the predecessors in increasing numeric order, not in
2484       // whatever order the use-list is in, so gather and sort them.
2485       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2486       for (auto *pred : block->getPredecessors())
2487         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2488       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2489 
2490       os << "  // " << predIDs.size() << " preds: ";
2491 
2492       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2493         printBlockName(pred.second);
2494       });
2495     }
2496     os << newLine;
2497   }
2498 
2499   currentIndent += indentWidth;
2500   auto range = llvm::make_range(
2501       block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
2502   for (auto &op : range) {
2503     print(&op);
2504     os << newLine;
2505   }
2506   currentIndent -= indentWidth;
2507 }
2508 
2509 void OperationPrinter::printValueID(Value value, bool printResultNo,
2510                                     raw_ostream *streamOverride) const {
2511   state->getSSANameState().printValueID(value, printResultNo,
2512                                         streamOverride ? *streamOverride : os);
2513 }
2514 
2515 void OperationPrinter::printSuccessor(Block *successor) {
2516   printBlockName(successor);
2517 }
2518 
2519 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2520                                                 ValueRange succOperands) {
2521   printBlockName(successor);
2522   if (succOperands.empty())
2523     return;
2524 
2525   os << '(';
2526   interleaveComma(succOperands,
2527                   [this](Value operand) { printValueID(operand); });
2528   os << " : ";
2529   interleaveComma(succOperands,
2530                   [this](Value operand) { printType(operand.getType()); });
2531   os << ')';
2532 }
2533 
2534 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2535                                    bool printBlockTerminators) {
2536   os << " {" << newLine;
2537   if (!region.empty()) {
2538     auto *entryBlock = &region.front();
2539     print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
2540           printBlockTerminators);
2541     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2542       print(&b);
2543   }
2544   os.indent(currentIndent) << "}";
2545 }
2546 
2547 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2548                                               ValueRange operands) {
2549   AffineMap map = mapAttr.getValue();
2550   unsigned numDims = map.getNumDims();
2551   auto printValueName = [&](unsigned pos, bool isSymbol) {
2552     unsigned index = isSymbol ? numDims + pos : pos;
2553     assert(index < operands.size());
2554     if (isSymbol)
2555       os << "symbol(";
2556     printValueID(operands[index]);
2557     if (isSymbol)
2558       os << ')';
2559   };
2560 
2561   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2562     printAffineExpr(expr, printValueName);
2563   });
2564 }
2565 
2566 //===----------------------------------------------------------------------===//
2567 // print and dump methods
2568 //===----------------------------------------------------------------------===//
2569 
2570 void Attribute::print(raw_ostream &os) const {
2571   ModulePrinter(os).printAttribute(*this);
2572 }
2573 
2574 void Attribute::dump() const {
2575   print(llvm::errs());
2576   llvm::errs() << "\n";
2577 }
2578 
2579 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2580 
2581 void Type::dump() { print(llvm::errs()); }
2582 
2583 void AffineMap::dump() const {
2584   print(llvm::errs());
2585   llvm::errs() << "\n";
2586 }
2587 
2588 void IntegerSet::dump() const {
2589   print(llvm::errs());
2590   llvm::errs() << "\n";
2591 }
2592 
2593 void AffineExpr::print(raw_ostream &os) const {
2594   if (!expr) {
2595     os << "<<NULL AFFINE EXPR>>";
2596     return;
2597   }
2598   ModulePrinter(os).printAffineExpr(*this);
2599 }
2600 
2601 void AffineExpr::dump() const {
2602   print(llvm::errs());
2603   llvm::errs() << "\n";
2604 }
2605 
2606 void AffineMap::print(raw_ostream &os) const {
2607   if (!map) {
2608     os << "<<NULL AFFINE MAP>>";
2609     return;
2610   }
2611   ModulePrinter(os).printAffineMap(*this);
2612 }
2613 
2614 void IntegerSet::print(raw_ostream &os) const {
2615   ModulePrinter(os).printIntegerSet(*this);
2616 }
2617 
2618 void Value::print(raw_ostream &os) {
2619   if (auto *op = getDefiningOp())
2620     return op->print(os);
2621   // TODO: Improve this.
2622   BlockArgument arg = this->cast<BlockArgument>();
2623   os << "<block argument> of type '" << arg.getType()
2624      << "' at index: " << arg.getArgNumber() << '\n';
2625 }
2626 void Value::print(raw_ostream &os, AsmState &state) {
2627   if (auto *op = getDefiningOp())
2628     return op->print(os, state);
2629 
2630   // TODO: Improve this.
2631   BlockArgument arg = this->cast<BlockArgument>();
2632   os << "<block argument> of type '" << arg.getType()
2633      << "' at index: " << arg.getArgNumber() << '\n';
2634 }
2635 
2636 void Value::dump() {
2637   print(llvm::errs());
2638   llvm::errs() << "\n";
2639 }
2640 
2641 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2642   // TODO: This doesn't necessarily capture all potential cases.
2643   // Currently, region arguments can be shadowed when printing the main
2644   // operation. If the IR hasn't been printed, this will produce the old SSA
2645   // name and not the shadowed name.
2646   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2647                                                  os);
2648 }
2649 
2650 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2651   // If this is a top level operation, we also print aliases.
2652   if (!getParent() && !flags.shouldUseLocalScope()) {
2653     AsmState state(this);
2654     state.getImpl().initializeAliases(this, flags);
2655     print(os, state, flags);
2656     return;
2657   }
2658 
2659   // Find the operation to number from based upon the provided flags.
2660   Operation *op = this;
2661   bool shouldUseLocalScope = flags.shouldUseLocalScope();
2662   do {
2663     // If we are printing local scope, stop at the first operation that is
2664     // isolated from above.
2665     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2666       break;
2667 
2668     // Otherwise, traverse up to the next parent.
2669     Operation *parentOp = op->getParentOp();
2670     if (!parentOp)
2671       break;
2672     op = parentOp;
2673   } while (true);
2674 
2675   AsmState state(op);
2676   print(os, state, flags);
2677 }
2678 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2679   OperationPrinter printer(os, flags, state.getImpl());
2680   if (!getParent() && !flags.shouldUseLocalScope())
2681     printer.printTopLevelOperation(this);
2682   else
2683     printer.print(this);
2684 }
2685 
2686 void Operation::dump() {
2687   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2688   llvm::errs() << "\n";
2689 }
2690 
2691 void Block::print(raw_ostream &os) {
2692   Operation *parentOp = getParentOp();
2693   if (!parentOp) {
2694     os << "<<UNLINKED BLOCK>>\n";
2695     return;
2696   }
2697   // Get the top-level op.
2698   while (auto *nextOp = parentOp->getParentOp())
2699     parentOp = nextOp;
2700 
2701   AsmState state(parentOp);
2702   print(os, state);
2703 }
2704 void Block::print(raw_ostream &os, AsmState &state) {
2705   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2706 }
2707 
2708 void Block::dump() { print(llvm::errs()); }
2709 
2710 /// Print out the name of the block without printing its body.
2711 void Block::printAsOperand(raw_ostream &os, bool printType) {
2712   Operation *parentOp = getParentOp();
2713   if (!parentOp) {
2714     os << "<<UNLINKED BLOCK>>\n";
2715     return;
2716   }
2717   AsmState state(parentOp);
2718   printAsOperand(os, state);
2719 }
2720 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2721   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2722   printer.printBlockName(this);
2723 }
2724