1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/SubElementInterfaces.h"
28 #include "mlir/IR/Verifier.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/ScopedHashTable.h"
35 #include "llvm/ADT/SetVector.h"
36 #include "llvm/ADT/SmallString.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/StringSet.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/Endian.h"
43 #include "llvm/Support/Regex.h"
44 #include "llvm/Support/SaveAndRestore.h"
45 #include "llvm/Support/Threading.h"
46
47 #include <tuple>
48
49 using namespace mlir;
50 using namespace mlir::detail;
51
52 #define DEBUG_TYPE "mlir-asm-printer"
53
print(raw_ostream & os) const54 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
55
dump() const56 void OperationName::dump() const { print(llvm::errs()); }
57
58 //===--------------------------------------------------------------------===//
59 // AsmParser
60 //===--------------------------------------------------------------------===//
61
62 AsmParser::~AsmParser() = default;
63 DialectAsmParser::~DialectAsmParser() = default;
64 OpAsmParser::~OpAsmParser() = default;
65
getContext() const66 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
67
68 //===----------------------------------------------------------------------===//
69 // DialectAsmPrinter
70 //===----------------------------------------------------------------------===//
71
72 DialectAsmPrinter::~DialectAsmPrinter() = default;
73
74 //===----------------------------------------------------------------------===//
75 // OpAsmPrinter
76 //===----------------------------------------------------------------------===//
77
78 OpAsmPrinter::~OpAsmPrinter() = default;
79
printFunctionalType(Operation * op)80 void OpAsmPrinter::printFunctionalType(Operation *op) {
81 auto &os = getStream();
82 os << '(';
83 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
84 // Print the types of null values as <<NULL TYPE>>.
85 *this << (operand ? operand.getType() : Type());
86 });
87 os << ") -> ";
88
89 // Print the result list. We don't parenthesize single result types unless
90 // it is a function (avoiding a grammar ambiguity).
91 bool wrapped = op->getNumResults() != 1;
92 if (!wrapped && op->getResult(0).getType() &&
93 op->getResult(0).getType().isa<FunctionType>())
94 wrapped = true;
95
96 if (wrapped)
97 os << '(';
98
99 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
100 // Print the types of null values as <<NULL TYPE>>.
101 *this << (result ? result.getType() : Type());
102 });
103
104 if (wrapped)
105 os << ')';
106 }
107
108 //===----------------------------------------------------------------------===//
109 // Operation OpAsm interface.
110 //===----------------------------------------------------------------------===//
111
112 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
113 #include "mlir/IR/OpAsmInterface.cpp.inc"
114
115 LogicalResult
parseResource(AsmParsedResourceEntry & entry) const116 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
117 return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
118 << "' for dialect '" << getDialect()->getNamespace()
119 << "'";
120 }
121
122 //===----------------------------------------------------------------------===//
123 // OpPrintingFlags
124 //===----------------------------------------------------------------------===//
125
126 namespace {
127 /// This struct contains command line options that can be used to initialize
128 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
129 /// for global command line options.
130 struct AsmPrinterOptions {
131 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
132 "mlir-print-elementsattrs-with-hex-if-larger",
133 llvm::cl::desc(
134 "Print DenseElementsAttrs with a hex string that have "
135 "more elements than the given upper limit (use -1 to disable)")};
136
137 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
138 "mlir-elide-elementsattrs-if-larger",
139 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
140 "more elements than the given upper limit")};
141
142 llvm::cl::opt<bool> printDebugInfoOpt{
143 "mlir-print-debuginfo", llvm::cl::init(false),
144 llvm::cl::desc("Print debug info in MLIR output")};
145
146 llvm::cl::opt<bool> printPrettyDebugInfoOpt{
147 "mlir-pretty-debuginfo", llvm::cl::init(false),
148 llvm::cl::desc("Print pretty debug info in MLIR output")};
149
150 // Use the generic op output form in the operation printer even if the custom
151 // form is defined.
152 llvm::cl::opt<bool> printGenericOpFormOpt{
153 "mlir-print-op-generic", llvm::cl::init(false),
154 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
155
156 llvm::cl::opt<bool> assumeVerifiedOpt{
157 "mlir-print-assume-verified", llvm::cl::init(false),
158 llvm::cl::desc("Skip op verification when using custom printers"),
159 llvm::cl::Hidden};
160
161 llvm::cl::opt<bool> printLocalScopeOpt{
162 "mlir-print-local-scope", llvm::cl::init(false),
163 llvm::cl::desc("Print with local scope and inline information (eliding "
164 "aliases for attributes, types, and locations")};
165
166 llvm::cl::opt<bool> printValueUsers{
167 "mlir-print-value-users", llvm::cl::init(false),
168 llvm::cl::desc(
169 "Print users of operation results and block arguments as a comment")};
170 };
171 } // namespace
172
173 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
174
175 /// Register a set of useful command-line options that can be used to configure
176 /// various flags within the AsmPrinter.
registerAsmPrinterCLOptions()177 void mlir::registerAsmPrinterCLOptions() {
178 // Make sure that the options struct has been initialized.
179 *clOptions;
180 }
181
182 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()183 OpPrintingFlags::OpPrintingFlags()
184 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
185 printGenericOpFormFlag(false), assumeVerifiedFlag(false),
186 printLocalScope(false), printValueUsersFlag(false) {
187 // Initialize based upon command line options, if they are available.
188 if (!clOptions.isConstructed())
189 return;
190 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
191 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
192 printDebugInfoFlag = clOptions->printDebugInfoOpt;
193 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
194 printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
195 assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
196 printLocalScope = clOptions->printLocalScopeOpt;
197 printValueUsersFlag = clOptions->printValueUsers;
198 }
199
200 /// Enable the elision of large elements attributes, by printing a '...'
201 /// instead of the element data, when the number of elements is greater than
202 /// `largeElementLimit`. Note: The IR generated with this option is not
203 /// parsable.
204 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)205 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
206 elementsAttrElementLimit = largeElementLimit;
207 return *this;
208 }
209
210 /// Enable printing of debug information. If 'prettyForm' is set to true,
211 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)212 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
213 printDebugInfoFlag = true;
214 printDebugInfoPrettyFormFlag = prettyForm;
215 return *this;
216 }
217
218 /// Always print operations in the generic form.
printGenericOpForm()219 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
220 printGenericOpFormFlag = true;
221 return *this;
222 }
223
224 /// Do not verify the operation when using custom operation printers.
assumeVerified()225 OpPrintingFlags &OpPrintingFlags::assumeVerified() {
226 assumeVerifiedFlag = true;
227 return *this;
228 }
229
230 /// Use local scope when printing the operation. This allows for using the
231 /// printer in a more localized and thread-safe setting, but may not necessarily
232 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()233 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
234 printLocalScope = true;
235 return *this;
236 }
237
238 /// Print users of values as comments.
printValueUsers()239 OpPrintingFlags &OpPrintingFlags::printValueUsers() {
240 printValueUsersFlag = true;
241 return *this;
242 }
243
244 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const245 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
246 return elementsAttrElementLimit &&
247 *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
248 !attr.isa<SplatElementsAttr>();
249 }
250
251 /// Return the size limit for printing large ElementsAttr.
getLargeElementsAttrLimit() const252 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
253 return elementsAttrElementLimit;
254 }
255
256 /// Return if debug information should be printed.
shouldPrintDebugInfo() const257 bool OpPrintingFlags::shouldPrintDebugInfo() const {
258 return printDebugInfoFlag;
259 }
260
261 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const262 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
263 return printDebugInfoPrettyFormFlag;
264 }
265
266 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const267 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
268 return printGenericOpFormFlag;
269 }
270
271 /// Return if operation verification should be skipped.
shouldAssumeVerified() const272 bool OpPrintingFlags::shouldAssumeVerified() const {
273 return assumeVerifiedFlag;
274 }
275
276 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const277 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
278
279 /// Return if the printer should print users of values.
shouldPrintValueUsers() const280 bool OpPrintingFlags::shouldPrintValueUsers() const {
281 return printValueUsersFlag;
282 }
283
284 /// Returns true if an ElementsAttr with the given number of elements should be
285 /// printed with hex.
shouldPrintElementsAttrWithHex(int64_t numElements)286 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
287 // Check to see if a command line option was provided for the limit.
288 if (clOptions.isConstructed()) {
289 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
290 // -1 is used to disable hex printing.
291 if (clOptions->printElementsAttrWithHexIfLarger == -1)
292 return false;
293 return numElements > clOptions->printElementsAttrWithHexIfLarger;
294 }
295 }
296
297 // Otherwise, default to printing with hex if the number of elements is >100.
298 return numElements > 100;
299 }
300
301 //===----------------------------------------------------------------------===//
302 // NewLineCounter
303 //===----------------------------------------------------------------------===//
304
305 namespace {
306 /// This class is a simple formatter that emits a new line when inputted into a
307 /// stream, that enables counting the number of newlines emitted. This class
308 /// should be used whenever emitting newlines in the printer.
309 struct NewLineCounter {
310 unsigned curLine = 1;
311 };
312
operator <<(raw_ostream & os,NewLineCounter & newLine)313 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
314 ++newLine.curLine;
315 return os << '\n';
316 }
317 } // namespace
318
319 //===----------------------------------------------------------------------===//
320 // AliasInitializer
321 //===----------------------------------------------------------------------===//
322
323 namespace {
324 /// This class represents a specific instance of a symbol Alias.
325 class SymbolAlias {
326 public:
SymbolAlias(StringRef name,bool isDeferrable)327 SymbolAlias(StringRef name, bool isDeferrable)
328 : name(name), suffixIndex(0), hasSuffixIndex(false),
329 isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name,uint32_t suffixIndex,bool isDeferrable)330 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
331 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
332 isDeferrable(isDeferrable) {}
333
334 /// Print this alias to the given stream.
print(raw_ostream & os) const335 void print(raw_ostream &os) const {
336 os << name;
337 if (hasSuffixIndex)
338 os << suffixIndex;
339 }
340
341 /// Returns true if this alias supports deferred resolution when parsing.
canBeDeferred() const342 bool canBeDeferred() const { return isDeferrable; }
343
344 private:
345 /// The main name of the alias.
346 StringRef name;
347 /// The optional suffix index of the alias, if multiple aliases had the same
348 /// name.
349 uint32_t suffixIndex : 30;
350 /// A flag indicating whether this alias has a suffix or not.
351 bool hasSuffixIndex : 1;
352 /// A flag indicating whether this alias may be deferred or not.
353 bool isDeferrable : 1;
354 };
355
/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// `interfaces` are queried for alias names. Generated names are copied into
  /// `aliasAllocator`, so the allocator must outlive the resulting alias maps.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` the way it would be printed under `printerFlags`, then fill
  /// `attrToAlias` / `typeToAlias` with the aliases that were generated.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
  /// NOTE: `aliasOS` is constructed over `aliasBuffer`, and members are
  /// initialized in declaration order, so `aliasBuffer` must stay declared
  /// before `aliasOS`.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};
411
412 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
413 /// and merely collects the attributes and types that *would* be printed in a
414 /// normal print invocation so that we can generate proper aliases. This allows
415 /// for us to generate aliases only for the attributes and types that would be
416 /// in the output, and trims down unnecessary output.
417 class DummyAliasOperationPrinter : private OpAsmPrinter {
418 public:
DummyAliasOperationPrinter(const OpPrintingFlags & printerFlags,AliasInitializer & initializer)419 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
420 AliasInitializer &initializer)
421 : printerFlags(printerFlags), initializer(initializer) {}
422
423 /// Print the given operation.
print(Operation * op)424 void print(Operation *op) {
425 // Visit the operation location.
426 if (printerFlags.shouldPrintDebugInfo())
427 initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
428
429 // If requested, always print the generic form.
430 if (!printerFlags.shouldPrintGenericOpForm()) {
431 // Check to see if this is a known operation. If so, use the registered
432 // custom printer hook.
433 if (auto opInfo = op->getRegisteredInfo()) {
434 opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
435 return;
436 }
437 }
438
439 // Otherwise print with the generic assembly form.
440 printGenericOp(op);
441 }
442
443 private:
444 /// Print the given operation in the generic form.
printGenericOp(Operation * op,bool printOpName=true)445 void printGenericOp(Operation *op, bool printOpName = true) override {
446 // Consider nested operations for aliases.
447 if (op->getNumRegions() != 0) {
448 for (Region ®ion : op->getRegions())
449 printRegion(region, /*printEntryBlockArgs=*/true,
450 /*printBlockTerminators=*/true);
451 }
452
453 // Visit all the types used in the operation.
454 for (Type type : op->getOperandTypes())
455 printType(type);
456 for (Type type : op->getResultTypes())
457 printType(type);
458
459 // Consider the attributes of the operation for aliases.
460 for (const NamedAttribute &attr : op->getAttrs())
461 printAttribute(attr.getValue());
462 }
463
464 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
465 /// block are not printed. If 'printBlockTerminator' is false, the terminator
466 /// operation of the block is not printed.
print(Block * block,bool printBlockArgs=true,bool printBlockTerminator=true)467 void print(Block *block, bool printBlockArgs = true,
468 bool printBlockTerminator = true) {
469 // Consider the types of the block arguments for aliases if 'printBlockArgs'
470 // is set to true.
471 if (printBlockArgs) {
472 for (BlockArgument arg : block->getArguments()) {
473 printType(arg.getType());
474
475 // Visit the argument location.
476 if (printerFlags.shouldPrintDebugInfo())
477 // TODO: Allow deferring argument locations.
478 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
479 }
480 }
481
482 // Consider the operations within this block, ignoring the terminator if
483 // requested.
484 bool hasTerminator =
485 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
486 auto range = llvm::make_range(
487 block->begin(),
488 std::prev(block->end(),
489 (!hasTerminator || printBlockTerminator) ? 0 : 1));
490 for (Operation &op : range)
491 print(&op);
492 }
493
494 /// Print the given region.
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock=false)495 void printRegion(Region ®ion, bool printEntryBlockArgs,
496 bool printBlockTerminators,
497 bool printEmptyBlock = false) override {
498 if (region.empty())
499 return;
500
501 auto *entryBlock = ®ion.front();
502 print(entryBlock, printEntryBlockArgs, printBlockTerminators);
503 for (Block &b : llvm::drop_begin(region, 1))
504 print(&b);
505 }
506
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)507 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
508 bool omitType) override {
509 printType(arg.getType());
510 // Visit the argument location.
511 if (printerFlags.shouldPrintDebugInfo())
512 // TODO: Allow deferring argument locations.
513 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
514 }
515
516 /// Consider the given type to be printed for an alias.
printType(Type type)517 void printType(Type type) override { initializer.visit(type); }
518
519 /// Consider the given attribute to be printed for an alias.
printAttribute(Attribute attr)520 void printAttribute(Attribute attr) override { initializer.visit(attr); }
printAttributeWithoutType(Attribute attr)521 void printAttributeWithoutType(Attribute attr) override {
522 printAttribute(attr);
523 }
printAlias(Attribute attr)524 LogicalResult printAlias(Attribute attr) override {
525 initializer.visit(attr);
526 return success();
527 }
printAlias(Type type)528 LogicalResult printAlias(Type type) override {
529 initializer.visit(type);
530 return success();
531 }
532
533 /// Print the given set of attributes with names not included within
534 /// 'elidedAttrs'.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})535 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
536 ArrayRef<StringRef> elidedAttrs = {}) override {
537 if (attrs.empty())
538 return;
539 if (elidedAttrs.empty()) {
540 for (const NamedAttribute &attr : attrs)
541 printAttribute(attr.getValue());
542 return;
543 }
544 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
545 elidedAttrs.end());
546 for (const NamedAttribute &attr : attrs)
547 if (!elidedAttrsSet.contains(attr.getName().strref()))
548 printAttribute(attr.getValue());
549 }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})550 void printOptionalAttrDictWithKeyword(
551 ArrayRef<NamedAttribute> attrs,
552 ArrayRef<StringRef> elidedAttrs = {}) override {
553 printOptionalAttrDict(attrs, elidedAttrs);
554 }
555
556 /// Return a null stream as the output stream, this will ignore any data fed
557 /// to it.
getStream() const558 raw_ostream &getStream() const override { return os; }
559
560 /// The following are hooks of `OpAsmPrinter` that are not necessary for
561 /// determining potential aliases.
printFloat(const APFloat & value)562 void printFloat(const APFloat &value) override {}
printAffineMapOfSSAIds(AffineMapAttr,ValueRange)563 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
printAffineExprOfSSAIds(AffineExpr,ValueRange,ValueRange)564 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
printNewline()565 void printNewline() override {}
printOperand(Value)566 void printOperand(Value) override {}
printOperand(Value,raw_ostream & os)567 void printOperand(Value, raw_ostream &os) override {
568 // Users expect the output string to have at least the prefixed % to signal
569 // a value name. To maintain this invariant, emit a name even if it is
570 // guaranteed to go unused.
571 os << "%";
572 }
printKeywordOrString(StringRef)573 void printKeywordOrString(StringRef) override {}
printSymbolName(StringRef)574 void printSymbolName(StringRef) override {}
printSuccessor(Block *)575 void printSuccessor(Block *) override {}
printSuccessorAndUseList(Block *,ValueRange)576 void printSuccessorAndUseList(Block *, ValueRange) override {}
shadowRegionArgs(Region &,ValueRange)577 void shadowRegionArgs(Region &, ValueRange) override {}
578
579 /// The printer flags to use when determining potential aliases.
580 const OpPrintingFlags &printerFlags;
581
582 /// The initializer to use when identifying aliases.
583 AliasInitializer &initializer;
584
585 /// A dummy output stream.
586 mutable llvm::raw_null_ostream os;
587 };
588 } // namespace
589
590 /// Sanitize the given name such that it can be used as a valid identifier. If
591 /// the string needs to be modified in any way, the provided buffer is used to
592 /// store the new copy,
sanitizeIdentifier(StringRef name,SmallString<16> & buffer,StringRef allowedPunctChars="$._-",bool allowTrailingDigit=true)593 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
594 StringRef allowedPunctChars = "$._-",
595 bool allowTrailingDigit = true) {
596 assert(!name.empty() && "Shouldn't have an empty name here");
597
598 auto copyNameToBuffer = [&] {
599 for (char ch : name) {
600 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
601 buffer.push_back(ch);
602 else if (ch == ' ')
603 buffer.push_back('_');
604 else
605 buffer.append(llvm::utohexstr((unsigned char)ch));
606 }
607 };
608
609 // Check to see if this name is valid. If it starts with a digit, then it
610 // could conflict with the autogenerated numeric ID's, so add an underscore
611 // prefix to avoid problems.
612 if (isdigit(name[0])) {
613 buffer.push_back('_');
614 copyNameToBuffer();
615 return buffer;
616 }
617
618 // If the name ends with a trailing digit, add a '_' to avoid potential
619 // conflicts with autogenerated ID's.
620 if (!allowTrailingDigit && isdigit(name.back())) {
621 copyNameToBuffer();
622 buffer.push_back('_');
623 return buffer;
624 }
625
626 // Check to see that the name consists of only valid identifier characters.
627 for (char ch : name) {
628 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
629 copyNameToBuffer();
630 return buffer;
631 }
632 }
633
634 // If there are no invalid characters, return the original name.
635 return name;
636 }
637
638 /// Given a collection of aliases and symbols, initialize a mapping from a
639 /// symbol to a given alias.
640 template <typename T>
641 static void
initializeAliases(llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol,llvm::MapVector<T,SymbolAlias> & symbolToAlias,DenseSet<T> * deferrableAliases=nullptr)642 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
643 llvm::MapVector<T, SymbolAlias> &symbolToAlias,
644 DenseSet<T> *deferrableAliases = nullptr) {
645 std::vector<std::pair<StringRef, std::vector<T>>> aliases =
646 aliasToSymbol.takeVector();
647 llvm::array_pod_sort(aliases.begin(), aliases.end(),
648 [](const auto *lhs, const auto *rhs) {
649 return lhs->first.compare(rhs->first);
650 });
651
652 for (auto &it : aliases) {
653 // If there is only one instance for this alias, use the name directly.
654 if (it.second.size() == 1) {
655 T symbol = it.second.front();
656 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
657 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
658 continue;
659 }
660 // Otherwise, add the index to the name.
661 for (int i = 0, e = it.second.size(); i < e; ++i) {
662 T symbol = it.second[i];
663 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
664 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
665 }
666 }
667 }
668
initialize(Operation * op,const OpPrintingFlags & printerFlags,llvm::MapVector<Attribute,SymbolAlias> & attrToAlias,llvm::MapVector<Type,SymbolAlias> & typeToAlias)669 void AliasInitializer::initialize(
670 Operation *op, const OpPrintingFlags &printerFlags,
671 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
672 llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
673 // Use a dummy printer when walking the IR so that we can collect the
674 // attributes/types that will actually be used during printing when
675 // considering aliases.
676 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
677 aliasPrinter.print(op);
678
679 // Initialize the aliases sorted by name.
680 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
681 initializeAliases(aliasToType, typeToAlias);
682 }
683
visit(Attribute attr,bool canBeDeferred)684 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
685 if (!visitedAttributes.insert(attr).second) {
686 // If this attribute already has an alias and this instance can't be
687 // deferred, make sure that the alias isn't deferred.
688 if (!canBeDeferred)
689 deferrableAttributes.erase(attr);
690 return;
691 }
692
693 // Try to generate an alias for this attribute.
694 if (succeeded(generateAlias(attr, aliasToAttr))) {
695 if (canBeDeferred)
696 deferrableAttributes.insert(attr);
697 return;
698 }
699
700 // Check for any sub elements.
701 if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
702 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
703 [&](Type type) { visit(type); });
704 }
705 }
706
visit(Type type)707 void AliasInitializer::visit(Type type) {
708 if (!visitedTypes.insert(type).second)
709 return;
710
711 // Try to generate an alias for this type.
712 if (succeeded(generateAlias(type, aliasToType)))
713 return;
714
715 // Check for any sub elements.
716 if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
717 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
718 [&](Type type) { visit(type); });
719 }
720 }
721
722 template <typename T>
generateAlias(T symbol,llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol)723 LogicalResult AliasInitializer::generateAlias(
724 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
725 SmallString<32> nameBuffer;
726 for (const auto &interface : interfaces) {
727 OpAsmDialectInterface::AliasResult result =
728 interface.getAlias(symbol, aliasOS);
729 if (result == OpAsmDialectInterface::AliasResult::NoAlias)
730 continue;
731 nameBuffer = std::move(aliasBuffer);
732 assert(!nameBuffer.empty() && "expected valid alias name");
733 if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
734 break;
735 }
736
737 if (nameBuffer.empty())
738 return failure();
739
740 SmallString<16> tempBuffer;
741 StringRef name =
742 sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
743 /*allowTrailingDigit=*/false);
744 name = name.copy(aliasAllocator);
745 aliasToSymbol[name].push_back(symbol);
746 return success();
747 }
748
749 //===----------------------------------------------------------------------===//
750 // AliasState
751 //===----------------------------------------------------------------------===//
752
753 namespace {
754 /// This class manages the state for type and attribute aliases.
755 class AliasState {
756 public:
757 // Initialize the internal aliases.
758 void
759 initialize(Operation *op, const OpPrintingFlags &printerFlags,
760 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
761
762 /// Get an alias for the given attribute if it has one and print it in `os`.
763 /// Returns success if an alias was printed, failure otherwise.
764 LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
765
766 /// Get an alias for the given type if it has one and print it in `os`.
767 /// Returns success if an alias was printed, failure otherwise.
768 LogicalResult getAlias(Type ty, raw_ostream &os) const;
769
770 /// Print all of the referenced aliases that can not be resolved in a deferred
771 /// manner.
printNonDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const772 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
773 printAliases(os, newLine, /*isDeferred=*/false);
774 }
775
776 /// Print all of the referenced aliases that support deferred resolution.
printDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const777 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
778 printAliases(os, newLine, /*isDeferred=*/true);
779 }
780
781 private:
782 /// Print all of the referenced aliases that support the provided resolution
783 /// behavior.
784 void printAliases(raw_ostream &os, NewLineCounter &newLine,
785 bool isDeferred) const;
786
787 /// Mapping between attribute and alias.
788 llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
789 /// Mapping between type and alias.
790 llvm::MapVector<Type, SymbolAlias> typeToAlias;
791
792 /// An allocator used for alias names.
793 llvm::BumpPtrAllocator aliasAllocator;
794 };
795 } // namespace
796
initialize(Operation * op,const OpPrintingFlags & printerFlags,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)797 void AliasState::initialize(
798 Operation *op, const OpPrintingFlags &printerFlags,
799 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
800 AliasInitializer initializer(interfaces, aliasAllocator);
801 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
802 }
803
getAlias(Attribute attr,raw_ostream & os) const804 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
805 auto it = attrToAlias.find(attr);
806 if (it == attrToAlias.end())
807 return failure();
808 it->second.print(os << '#');
809 return success();
810 }
811
getAlias(Type ty,raw_ostream & os) const812 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
813 auto it = typeToAlias.find(ty);
814 if (it == typeToAlias.end())
815 return failure();
816
817 it->second.print(os << '!');
818 return success();
819 }
820
printAliases(raw_ostream & os,NewLineCounter & newLine,bool isDeferred) const821 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
822 bool isDeferred) const {
823 auto filterFn = [=](const auto &aliasIt) {
824 return aliasIt.second.canBeDeferred() == isDeferred;
825 };
826 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
827 it.second.print(os << '#');
828 os << " = " << it.first << newLine;
829 }
830 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
831 it.second.print(os << '!');
832 os << " = " << it.first << newLine;
833 }
834 }
835
836 //===----------------------------------------------------------------------===//
837 // SSANameState
838 //===----------------------------------------------------------------------===//
839
840 namespace {
841 /// Info about block printing: a number which is its position in the visitation
842 /// order, and a name that is used to print reference to it, e.g. ^bb42.
843 struct BlockInfo {
844 int ordering;
845 StringRef name;
846 };
847
848 /// This class manages the state of SSA value names.
849 class SSANameState {
850 public:
851 /// A sentinel value used for values with names set.
852 enum : unsigned { NameSentinel = ~0U };
853
854 SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
855
856 /// Print the SSA identifier for the given value to 'stream'. If
857 /// 'printResultNo' is true, it also presents the result number ('#' number)
858 /// of this value.
859 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
860
861 /// Print the operation identifier.
862 void printOperationID(Operation *op, raw_ostream &stream) const;
863
864 /// Return the result indices for each of the result groups registered by this
865 /// operation, or empty if none exist.
866 ArrayRef<int> getOpResultGroups(Operation *op);
867
868 /// Get the info for the given block.
869 BlockInfo getBlockInfo(Block *block);
870
871 /// Renumber the arguments for the specified region to the same names as the
872 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
873 /// details.
874 void shadowRegionArgs(Region ®ion, ValueRange namesToUse);
875
876 private:
877 /// Number the SSA values within the given IR unit.
878 void numberValuesInRegion(Region ®ion);
879 void numberValuesInBlock(Block &block);
880 void numberValuesInOp(Operation &op);
881
882 /// Given a result of an operation 'result', find the result group head
883 /// 'lookupValue' and the result of 'result' within that group in
884 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
885 /// has more than 1 result.
886 void getResultIDAndNumber(OpResult result, Value &lookupValue,
887 Optional<int> &lookupResultNo) const;
888
889 /// Set a special value name for the given value.
890 void setValueName(Value value, StringRef name);
891
892 /// Uniques the given value name within the printer. If the given name
893 /// conflicts, it is automatically renamed.
894 StringRef uniqueValueName(StringRef name);
895
896 /// This is the value ID for each SSA value. If this returns NameSentinel,
897 /// then the valueID has an entry in valueNames.
898 DenseMap<Value, unsigned> valueIDs;
899 DenseMap<Value, StringRef> valueNames;
900
901 /// When printing users of values, an operation without a result might
902 /// be the user. This map holds ids for such operations.
903 DenseMap<Operation *, unsigned> operationIDs;
904
905 /// This is a map of operations that contain multiple named result groups,
906 /// i.e. there may be multiple names for the results of the operation. The
907 /// value of this map are the result numbers that start a result group.
908 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
909
910 /// This maps blocks to there visitation number in the current region as well
911 /// as the string representing their name.
912 DenseMap<Block *, BlockInfo> blockNames;
913
914 /// This keeps track of all of the non-numeric names that are in flight,
915 /// allowing us to check for duplicates.
916 /// Note: the value of the map is unused.
917 llvm::ScopedHashTable<StringRef, char> usedNames;
918 llvm::BumpPtrAllocator usedNameAllocator;
919
920 /// This is the next value ID to assign in numbering.
921 unsigned nextValueID = 0;
922 /// This is the next ID to assign to a region entry block argument.
923 unsigned nextArgumentID = 0;
924 /// This is the next ID to assign when a name conflict is detected.
925 unsigned nextConflictID = 0;
926
927 /// These are the printing flags. They control, eg., whether to print in
928 /// generic form.
929 OpPrintingFlags printerFlags;
930 };
931 } // namespace
932
/// Assign names and IDs to everything nested under `op`.
///
/// Regions are processed from an explicit worklist (`nameContext`). Each
/// worklist entry records the counter values and the `usedNames` scope that
/// were current when it was pushed. As a result, every region starts numbering
/// from its parent's state, and all regions of the same operation start from
/// the same state.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are overwritten from worklist entries below. These savers
  // reset them to the values they had on entry once the constructor returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and the `usedNames` scoped hash table. This information
  // is carried over from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'ed into it and
  // destroyed explicitly below. Their memory is only reclaimed when this
  // local allocator goes away.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Seed the worklist with the regions of `op`, all sharing the initial
  // counters and the top-level scope, then number `op` itself.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  // Worklist entries are popped from the back (last-in, first-out).
  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Reload the counters that were current when this region was queued.
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When switching from one subtree to another, destroy the scopes that
    // are no longer needed until the parent scope is innermost again. The
    // assertion guards against popping past the bottom of the scope stack.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions nested one level down, inheriting the counters and
    // scope produced by numbering this region.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes, so that `usedNames` is left with no
  // active scope.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
991
printValueID(Value value,bool printResultNo,raw_ostream & stream) const992 void SSANameState::printValueID(Value value, bool printResultNo,
993 raw_ostream &stream) const {
994 if (!value) {
995 stream << "<<NULL VALUE>>";
996 return;
997 }
998
999 Optional<int> resultNo;
1000 auto lookupValue = value;
1001
1002 // If this is an operation result, collect the head lookup value of the result
1003 // group and the result number of 'result' within that group.
1004 if (OpResult result = value.dyn_cast<OpResult>())
1005 getResultIDAndNumber(result, lookupValue, resultNo);
1006
1007 auto it = valueIDs.find(lookupValue);
1008 if (it == valueIDs.end()) {
1009 stream << "<<UNKNOWN SSA VALUE>>";
1010 return;
1011 }
1012
1013 stream << '%';
1014 if (it->second != NameSentinel) {
1015 stream << it->second;
1016 } else {
1017 auto nameIt = valueNames.find(lookupValue);
1018 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1019 stream << nameIt->second;
1020 }
1021
1022 if (resultNo && printResultNo)
1023 stream << '#' << resultNo;
1024 }
1025
printOperationID(Operation * op,raw_ostream & stream) const1026 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1027 auto it = operationIDs.find(op);
1028 if (it == operationIDs.end()) {
1029 stream << "<<UNKOWN OPERATION>>";
1030 } else {
1031 stream << '%' << it->second;
1032 }
1033 }
1034
getOpResultGroups(Operation * op)1035 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1036 auto it = opResultGroups.find(op);
1037 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1038 }
1039
getBlockInfo(Block * block)1040 BlockInfo SSANameState::getBlockInfo(Block *block) {
1041 auto it = blockNames.find(block);
1042 BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1043 return it != blockNames.end() ? it->second : invalidBlock;
1044 }
1045
shadowRegionArgs(Region & region,ValueRange namesToUse)1046 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
1047 assert(!region.empty() && "cannot shadow arguments of an empty region");
1048 assert(region.getNumArguments() == namesToUse.size() &&
1049 "incorrect number of names passed in");
1050 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1051 "only KnownIsolatedFromAbove ops can shadow names");
1052
1053 SmallVector<char, 16> nameStr;
1054 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1055 auto nameToUse = namesToUse[i];
1056 if (nameToUse == nullptr)
1057 continue;
1058 auto nameToReplace = region.getArgument(i);
1059
1060 nameStr.clear();
1061 llvm::raw_svector_ostream nameStream(nameStr);
1062 printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1063
1064 // Entry block arguments should already have a pretty "arg" name.
1065 assert(valueIDs[nameToReplace] == NameSentinel);
1066
1067 // Use the name without the leading %.
1068 auto name = StringRef(nameStream.str()).drop_front();
1069
1070 // Overwrite the name.
1071 valueNames[nameToReplace] = name.copy(usedNameAllocator);
1072 }
1073 }
1074
numberValuesInRegion(Region & region)1075 void SSANameState::numberValuesInRegion(Region ®ion) {
1076 auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1077 assert(!valueIDs.count(arg) && "arg numbered multiple times");
1078 assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion &&
1079 "arg not defined in current region");
1080 setValueName(arg, name);
1081 };
1082
1083 if (!printerFlags.shouldPrintGenericOpForm()) {
1084 if (Operation *op = region.getParentOp()) {
1085 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1086 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1087 }
1088 }
1089
1090 // Number the values within this region in a breadth-first order.
1091 unsigned nextBlockID = 0;
1092 for (auto &block : region) {
1093 // Each block gets a unique ID, and all of the operations within it get
1094 // numbered as well.
1095 auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1096 if (blockInfoIt.second) {
1097 // This block hasn't been named through `getAsmBlockArgumentNames`, use
1098 // default `^bbNNN` format.
1099 std::string name;
1100 llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1101 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1102 }
1103 blockInfoIt.first->second.ordering = nextBlockID++;
1104
1105 numberValuesInBlock(block);
1106 }
1107 }
1108
numberValuesInBlock(Block & block)1109 void SSANameState::numberValuesInBlock(Block &block) {
1110 // Number the block arguments. We give entry block arguments a special name
1111 // 'arg'.
1112 bool isEntryBlock = block.isEntryBlock();
1113 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1114 llvm::raw_svector_ostream specialName(specialNameBuffer);
1115 for (auto arg : block.getArguments()) {
1116 if (valueIDs.count(arg))
1117 continue;
1118 if (isEntryBlock) {
1119 specialNameBuffer.resize(strlen("arg"));
1120 specialName << nextArgumentID++;
1121 }
1122 setValueName(arg, specialName.str());
1123 }
1124
1125 // Number the operations in this block.
1126 for (auto &op : block)
1127 numberValuesInOp(op);
1128 }
1129
numberValuesInOp(Operation & op)1130 void SSANameState::numberValuesInOp(Operation &op) {
1131 // Function used to set the special result names for the operation.
1132 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1133 auto setResultNameFn = [&](Value result, StringRef name) {
1134 assert(!valueIDs.count(result) && "result numbered multiple times");
1135 assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1136 setValueName(result, name);
1137
1138 // Record the result number for groups not anchored at 0.
1139 if (int resultNo = result.cast<OpResult>().getResultNumber())
1140 resultGroups.push_back(resultNo);
1141 };
1142 // Operations can customize the printing of block names in OpAsmOpInterface.
1143 auto setBlockNameFn = [&](Block *block, StringRef name) {
1144 assert(block->getParentOp() == &op &&
1145 "getAsmBlockArgumentNames callback invoked on a block not directly "
1146 "nested under the current operation");
1147 assert(!blockNames.count(block) && "block numbered multiple times");
1148 SmallString<16> tmpBuffer{"^"};
1149 name = sanitizeIdentifier(name, tmpBuffer);
1150 if (name.data() != tmpBuffer.data()) {
1151 tmpBuffer.append(name);
1152 name = tmpBuffer.str();
1153 }
1154 name = name.copy(usedNameAllocator);
1155 blockNames[block] = {-1, name};
1156 };
1157
1158 if (!printerFlags.shouldPrintGenericOpForm()) {
1159 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1160 asmInterface.getAsmBlockNames(setBlockNameFn);
1161 asmInterface.getAsmResultNames(setResultNameFn);
1162 }
1163 }
1164
1165 unsigned numResults = op.getNumResults();
1166 if (numResults == 0) {
1167 // If value users should be printed, operations with no result need an id.
1168 if (printerFlags.shouldPrintValueUsers()) {
1169 if (operationIDs.try_emplace(&op, nextValueID).second)
1170 ++nextValueID;
1171 }
1172 return;
1173 }
1174 Value resultBegin = op.getResult(0);
1175
1176 // If the first result wasn't numbered, give it a default number.
1177 if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1178 ++nextValueID;
1179
1180 // If this operation has multiple result groups, mark it.
1181 if (resultGroups.size() != 1) {
1182 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1183 opResultGroups.try_emplace(&op, std::move(resultGroups));
1184 }
1185 }
1186
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const1187 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1188 Optional<int> &lookupResultNo) const {
1189 Operation *owner = result.getOwner();
1190 if (owner->getNumResults() == 1)
1191 return;
1192 int resultNo = result.getResultNumber();
1193
1194 // If this operation has multiple result groups, we will need to find the
1195 // one corresponding to this result.
1196 auto resultGroupIt = opResultGroups.find(owner);
1197 if (resultGroupIt == opResultGroups.end()) {
1198 // If not, just use the first result.
1199 lookupResultNo = resultNo;
1200 lookupValue = owner->getResult(0);
1201 return;
1202 }
1203
1204 // Find the correct index using a binary search, as the groups are ordered.
1205 ArrayRef<int> resultGroups = resultGroupIt->second;
1206 const auto *it = llvm::upper_bound(resultGroups, resultNo);
1207 int groupResultNo = 0, groupSize = 0;
1208
1209 // If there are no smaller elements, the last result group is the lookup.
1210 if (it == resultGroups.end()) {
1211 groupResultNo = resultGroups.back();
1212 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1213 } else {
1214 // Otherwise, the previous element is the lookup.
1215 groupResultNo = *std::prev(it);
1216 groupSize = *it - groupResultNo;
1217 }
1218
1219 // We only record the result number for a group of size greater than 1.
1220 if (groupSize != 1)
1221 lookupResultNo = resultNo - groupResultNo;
1222 lookupValue = owner->getResult(groupResultNo);
1223 }
1224
setValueName(Value value,StringRef name)1225 void SSANameState::setValueName(Value value, StringRef name) {
1226 // If the name is empty, the value uses the default numbering.
1227 if (name.empty()) {
1228 valueIDs[value] = nextValueID++;
1229 return;
1230 }
1231
1232 valueIDs[value] = NameSentinel;
1233 valueNames[value] = uniqueValueName(name);
1234 }
1235
uniqueValueName(StringRef name)1236 StringRef SSANameState::uniqueValueName(StringRef name) {
1237 SmallString<16> tmpBuffer;
1238 name = sanitizeIdentifier(name, tmpBuffer);
1239
1240 // Check to see if this name is already unique.
1241 if (!usedNames.count(name)) {
1242 name = name.copy(usedNameAllocator);
1243 } else {
1244 // Otherwise, we had a conflict - probe until we find a unique name. This
1245 // is guaranteed to terminate (and usually in a single iteration) because it
1246 // generates new names by incrementing nextConflictID.
1247 SmallString<64> probeName(name);
1248 probeName.push_back('_');
1249 while (true) {
1250 probeName += llvm::utostr(nextConflictID++);
1251 if (!usedNames.count(probeName)) {
1252 name = probeName.str().copy(usedNameAllocator);
1253 break;
1254 }
1255 probeName.resize(name.size() + 1);
1256 }
1257 }
1258
1259 usedNames.insert(name, char());
1260 return name;
1261 }
1262
1263 //===----------------------------------------------------------------------===//
1264 // Resources
1265 //===----------------------------------------------------------------------===//
1266
// Out-of-line defaulted destructors for the resource-related interfaces.
// NOTE(review): presumably these are the virtual destructors declared in
// mlir/IR/AsmState.h, defined here to anchor the vtables in this translation
// unit; confirm against the header declarations.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
1271
1272 //===----------------------------------------------------------------------===//
1273 // AsmState
1274 //===----------------------------------------------------------------------===//
1275
1276 namespace mlir {
1277 namespace detail {
1278 class AsmStateImpl {
1279 public:
AsmStateImpl(Operation * op,const OpPrintingFlags & printerFlags,AsmState::LocationMap * locationMap)1280 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
1281 AsmState::LocationMap *locationMap)
1282 : interfaces(op->getContext()), nameState(op, printerFlags),
1283 printerFlags(printerFlags), locationMap(locationMap) {}
1284
1285 /// Initialize the alias state to enable the printing of aliases.
initializeAliases(Operation * op)1286 void initializeAliases(Operation *op) {
1287 aliasState.initialize(op, printerFlags, interfaces);
1288 }
1289
1290 /// Get the state used for aliases.
getAliasState()1291 AliasState &getAliasState() { return aliasState; }
1292
1293 /// Get the state used for SSA names.
getSSANameState()1294 SSANameState &getSSANameState() { return nameState; }
1295
1296 /// Return the dialects within the context that implement
1297 /// OpAsmDialectInterface.
getDialectInterfaces()1298 DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
1299 return interfaces;
1300 }
1301
1302 /// Return the non-dialect resource printers.
getResourcePrinters()1303 auto getResourcePrinters() {
1304 return llvm::make_pointee_range(externalResourcePrinters);
1305 }
1306
1307 /// Get the printer flags.
getPrinterFlags() const1308 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
1309
1310 /// Register the location, line and column, within the buffer that the given
1311 /// operation was printed at.
registerOperationLocation(Operation * op,unsigned line,unsigned col)1312 void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1313 if (locationMap)
1314 (*locationMap)[op] = std::make_pair(line, col);
1315 }
1316
1317 private:
1318 /// Collection of OpAsm interfaces implemented in the context.
1319 DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
1320
1321 /// A collection of non-dialect resource printers.
1322 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
1323
1324 /// The state used for attribute and type aliases.
1325 AliasState aliasState;
1326
1327 /// The state used for SSA value names.
1328 SSANameState nameState;
1329
1330 /// Flags that control op output.
1331 OpPrintingFlags printerFlags;
1332
1333 /// An optional location map to be populated.
1334 AsmState::LocationMap *locationMap;
1335
1336 // Allow direct access to the impl fields.
1337 friend AsmState;
1338 };
1339 } // namespace detail
1340 } // namespace mlir
1341
1342 /// Verifies the operation and switches to generic op printing if verification
1343 /// fails. We need to do this because custom print functions may fail for
1344 /// invalid ops.
verifyOpAndAdjustFlags(Operation * op,OpPrintingFlags printerFlags)1345 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1346 OpPrintingFlags printerFlags) {
1347 if (printerFlags.shouldPrintGenericOpForm() ||
1348 printerFlags.shouldAssumeVerified())
1349 return printerFlags;
1350
1351 LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
1352 << op->getName() << "\n");
1353
1354 // Ignore errors emitted by the verifier. We check the thread id to avoid
1355 // consuming other threads' errors.
1356 auto parentThreadId = llvm::get_threadid();
1357 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1358 if (parentThreadId == llvm::get_threadid()) {
1359 LLVM_DEBUG({
1360 diag.print(llvm::dbgs());
1361 llvm::dbgs() << "\n";
1362 });
1363 return success();
1364 }
1365 return failure();
1366 });
1367 if (failed(verify(op))) {
1368 LLVM_DEBUG(llvm::dbgs()
1369 << DEBUG_TYPE << ": '" << op->getName()
1370 << "' failed to verify and will be printed in generic form\n");
1371 printerFlags.printGenericOpForm();
1372 }
1373
1374 return printerFlags;
1375 }
1376
AsmState(Operation * op,const OpPrintingFlags & printerFlags,LocationMap * locationMap)1377 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1378 LocationMap *locationMap)
1379 : impl(std::make_unique<AsmStateImpl>(
1380 op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
1381 AsmState::~AsmState() = default;
1382
getPrinterFlags() const1383 const OpPrintingFlags &AsmState::getPrinterFlags() const {
1384 return impl->getPrinterFlags();
1385 }
1386
attachResourcePrinter(std::unique_ptr<AsmResourcePrinter> printer)1387 void AsmState::attachResourcePrinter(
1388 std::unique_ptr<AsmResourcePrinter> printer) {
1389 impl->externalResourcePrinters.emplace_back(std::move(printer));
1390 }
1391
1392 //===----------------------------------------------------------------------===//
1393 // AsmPrinter::Impl
1394 //===----------------------------------------------------------------------===//
1395
1396 namespace mlir {
1397 class AsmPrinter::Impl {
1398 public:
Impl(raw_ostream & os,OpPrintingFlags flags=llvm::None,AsmStateImpl * state=nullptr)1399 Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
1400 AsmStateImpl *state = nullptr)
1401 : os(os), printerFlags(flags), state(state) {}
Impl(Impl & other)1402 explicit Impl(Impl &other)
1403 : Impl(other.os, other.printerFlags, other.state) {}
1404
1405 /// Returns the output stream of the printer.
getStream()1406 raw_ostream &getStream() { return os; }
1407
1408 template <typename Container, typename UnaryFunctor>
interleaveComma(const Container & c,UnaryFunctor eachFn) const1409 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
1410 llvm::interleaveComma(c, os, eachFn);
1411 }
1412
1413 /// This enum describes the different kinds of elision for the type of an
1414 /// attribute when printing it.
1415 enum class AttrTypeElision {
1416 /// The type must not be elided,
1417 Never,
1418 /// The type may be elided when it matches the default used in the parser
1419 /// (for example i64 is the default for integer attributes).
1420 May,
1421 /// The type must be elided.
1422 Must
1423 };
1424
1425 /// Print the given attribute.
1426 void printAttribute(Attribute attr,
1427 AttrTypeElision typeElision = AttrTypeElision::Never);
1428
1429 /// Print the alias for the given attribute, return failure if no alias could
1430 /// be printed.
1431 LogicalResult printAlias(Attribute attr);
1432
1433 void printType(Type type);
1434
1435 /// Print the alias for the given type, return failure if no alias could
1436 /// be printed.
1437 LogicalResult printAlias(Type type);
1438
1439 /// Print the given location to the stream. If `allowAlias` is true, this
1440 /// allows for the internal location to use an attribute alias.
1441 void printLocation(LocationAttr loc, bool allowAlias = false);
1442
1443 /// Print a reference to the given resource that is owned by the given
1444 /// dialect.
printResourceHandle(const AsmDialectResourceHandle & resource)1445 void printResourceHandle(const AsmDialectResourceHandle &resource) {
1446 auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
1447 os << interface->getResourceKey(resource);
1448 dialectResources[resource.getDialect()].insert(resource);
1449 }
1450
1451 void printAffineMap(AffineMap map);
1452 void
1453 printAffineExpr(AffineExpr expr,
1454 function_ref<void(unsigned, bool)> printValueName = nullptr);
1455 void printAffineConstraint(AffineExpr expr, bool isEq);
1456 void printIntegerSet(IntegerSet set);
1457
1458 protected:
1459 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1460 ArrayRef<StringRef> elidedAttrs = {},
1461 bool withKeyword = false);
1462 void printNamedAttribute(NamedAttribute attr);
1463 void printTrailingLocation(Location loc, bool allowAlias = true);
1464 void printLocationInternal(LocationAttr loc, bool pretty = false);
1465
1466 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1467 /// used instead of individual elements when the elements attr is large.
1468 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
1469
1470 /// Print a dense string elements attribute.
1471 void printDenseStringElementsAttr(DenseStringElementsAttr attr);
1472
1473 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1474 /// used instead of individual elements when the elements attr is large.
1475 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1476 bool allowHex);
1477
1478 void printDialectAttribute(Attribute attr);
1479 void printDialectType(Type type);
1480
1481 /// Print an escaped string, wrapped with "".
1482 void printEscapedString(StringRef str);
1483
1484 /// Print a hex string, wrapped with "".
1485 void printHexString(StringRef str);
1486 void printHexString(ArrayRef<char> data);
1487
1488 /// This enum is used to represent the binding strength of the enclosing
1489 /// context that an AffineExprStorage is being printed in, so we can
1490 /// intelligently produce parens.
1491 enum class BindingStrength {
1492 Weak, // + and -
1493 Strong, // All other binary operators.
1494 };
1495 void printAffineExprInternal(
1496 AffineExpr expr, BindingStrength enclosingTightness,
1497 function_ref<void(unsigned, bool)> printValueName = nullptr);
1498
1499 /// The output stream for the printer.
1500 raw_ostream &os;
1501
1502 /// A set of flags to control the printer's behavior.
1503 OpPrintingFlags printerFlags;
1504
1505 /// An optional printer state for the module.
1506 AsmStateImpl *state;
1507
1508 /// A tracker for the number of new lines emitted during printing.
1509 NewLineCounter newLine;
1510
1511 /// A set of dialect resources that were referenced during printing.
1512 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
1513 };
1514 } // namespace mlir
1515
printTrailingLocation(Location loc,bool allowAlias)1516 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1517 // Check to see if we are printing debug information.
1518 if (!printerFlags.shouldPrintDebugInfo())
1519 return;
1520
1521 os << " ";
1522 printLocation(loc, /*allowAlias=*/allowAlias);
1523 }
1524
printLocationInternal(LocationAttr loc,bool pretty)1525 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
1526 TypeSwitch<LocationAttr>(loc)
1527 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1528 printLocationInternal(loc.getFallbackLocation(), pretty);
1529 })
1530 .Case<UnknownLoc>([&](UnknownLoc loc) {
1531 if (pretty)
1532 os << "[unknown]";
1533 else
1534 os << "unknown";
1535 })
1536 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1537 if (pretty)
1538 os << loc.getFilename().getValue();
1539 else
1540 printEscapedString(loc.getFilename());
1541 os << ':' << loc.getLine() << ':' << loc.getColumn();
1542 })
1543 .Case<NameLoc>([&](NameLoc loc) {
1544 printEscapedString(loc.getName());
1545
1546 // Print the child if it isn't unknown.
1547 auto childLoc = loc.getChildLoc();
1548 if (!childLoc.isa<UnknownLoc>()) {
1549 os << '(';
1550 printLocationInternal(childLoc, pretty);
1551 os << ')';
1552 }
1553 })
1554 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1555 Location caller = loc.getCaller();
1556 Location callee = loc.getCallee();
1557 if (!pretty)
1558 os << "callsite(";
1559 printLocationInternal(callee, pretty);
1560 if (pretty) {
1561 if (callee.isa<NameLoc>()) {
1562 if (caller.isa<FileLineColLoc>()) {
1563 os << " at ";
1564 } else {
1565 os << newLine << " at ";
1566 }
1567 } else {
1568 os << newLine << " at ";
1569 }
1570 } else {
1571 os << " at ";
1572 }
1573 printLocationInternal(caller, pretty);
1574 if (!pretty)
1575 os << ")";
1576 })
1577 .Case<FusedLoc>([&](FusedLoc loc) {
1578 if (!pretty)
1579 os << "fused";
1580 if (Attribute metadata = loc.getMetadata())
1581 os << '<' << metadata << '>';
1582 os << '[';
1583 interleave(
1584 loc.getLocations(),
1585 [&](Location loc) { printLocationInternal(loc, pretty); },
1586 [&]() { os << ", "; });
1587 os << ']';
1588 });
1589 }
1590
1591 /// Print a floating point value in a way that the parser will be able to
1592 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)1593 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1594 // We would like to output the FP constant value in exponential notation,
1595 // but we cannot do this if doing so will lose precision. Check here to
1596 // make sure that we only output it in exponential format if we can parse
1597 // the value back and get the same value.
1598 bool isInf = apValue.isInfinity();
1599 bool isNaN = apValue.isNaN();
1600 if (!isInf && !isNaN) {
1601 SmallString<128> strValue;
1602 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1603 /*TruncateZero=*/false);
1604
1605 // Check to make sure that the stringized number is not some string like
1606 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
1607 // that the string matches the "[-+]?[0-9]" regex.
1608 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1609 ((strValue[0] == '-' || strValue[0] == '+') &&
1610 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1611 "[-+]?[0-9] regex does not match!");
1612
1613 // Parse back the stringized version and check that the value is equal
1614 // (i.e., there is no precision loss).
1615 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1616 os << strValue;
1617 return;
1618 }
1619
1620 // If it is not, use the default format of APFloat instead of the
1621 // exponential notation.
1622 strValue.clear();
1623 apValue.toString(strValue);
1624
1625 // Make sure that we can parse the default form as a float.
1626 if (strValue.str().contains('.')) {
1627 os << strValue;
1628 return;
1629 }
1630 }
1631
1632 // Print special values in hexadecimal format. The sign bit should be included
1633 // in the literal.
1634 SmallVector<char, 16> str;
1635 APInt apInt = apValue.bitcastToAPInt();
1636 apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1637 /*formatAsCLiteral=*/true);
1638 os << str;
1639 }
1640
printLocation(LocationAttr loc,bool allowAlias)1641 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
1642 if (printerFlags.shouldPrintDebugInfoPrettyForm())
1643 return printLocationInternal(loc, /*pretty=*/true);
1644
1645 os << "loc(";
1646 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1647 printLocationInternal(loc);
1648 os << ')';
1649 }
1650
1651 /// Returns true if the given dialect symbol data is simple enough to print in
1652 /// the pretty form. This is essentially when the symbol takes the form:
1653 /// identifier (`<` body `>`)?
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1654 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1655 // The name must start with an identifier.
1656 if (symName.empty() || !isalpha(symName.front()))
1657 return false;
1658
1659 // Ignore all the characters that are valid in an identifier in the symbol
1660 // name.
1661 symName = symName.drop_while(
1662 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1663 if (symName.empty())
1664 return true;
1665
1666 // If we got to an unexpected character, then it must be a <>. Check that the
1667 // rest of the symbol is wrapped within <>.
1668 return symName.front() == '<' && symName.back() == '>';
1669 }
1670
1671 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1672 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1673 StringRef dialectName, StringRef symString) {
1674 os << symPrefix << dialectName;
1675
1676 // If this symbol name is simple enough, print it directly in pretty form,
1677 // otherwise, we print it as an escaped string.
1678 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1679 os << '.' << symString;
1680 return;
1681 }
1682
1683 os << '<' << symString << '>';
1684 }
1685
1686 /// Returns true if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1687 static bool isBareIdentifier(StringRef name) {
1688 // By making this unsigned, the value passed in to isalnum will always be
1689 // in the range 0-255. This is important when building with MSVC because
1690 // its implementation will assert. This situation can arise when dealing
1691 // with UTF-8 multibyte characters.
1692 if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
1693 return false;
1694 return llvm::all_of(name.drop_front(), [](unsigned char c) {
1695 return isalnum(c) || c == '_' || c == '$' || c == '.';
1696 });
1697 }
1698
1699 /// Print the given string as a keyword, or a quoted and escaped string if it
1700 /// has any special or non-printable characters in it.
printKeywordOrString(StringRef keyword,raw_ostream & os)1701 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
1702 // If it can be represented as a bare identifier, write it directly.
1703 if (isBareIdentifier(keyword)) {
1704 os << keyword;
1705 return;
1706 }
1707
1708 // Otherwise, output the keyword wrapped in quotes with proper escaping.
1709 os << "\"";
1710 printEscapedString(keyword, os);
1711 os << '"';
1712 }
1713
1714 /// Print the given string as a symbol reference. A symbol reference is
1715 /// represented as a string prefixed with '@'. The reference is surrounded with
1716 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1717 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1718 assert(!symbolRef.empty() && "expected valid symbol reference");
1719 os << '@';
1720 printKeywordOrString(symbolRef, os);
1721 }
1722
1723 // Print out a valid ElementsAttr that is succinct and can represent any
1724 // potential shape/type, for use when eliding a large ElementsAttr.
1725 //
1726 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1727 // hopefully alert readers to the fact that this has been elided.
1728 //
1729 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1730 // accept the string "elided". The first string must be a registered dialect
1731 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1732 static void printElidedElementsAttr(raw_ostream &os) {
1733 os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
1734 }
1735
printAlias(Attribute attr)1736 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
1737 return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
1738 }
1739
printAlias(Type type)1740 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
1741 return success(state && succeeded(state->getAliasState().getAlias(type, os)));
1742 }
1743
/// Print `attr` in its textual assembly form.
///
/// A null attribute prints a placeholder, and an alias is used when one can be
/// printed. Non-builtin attributes are delegated to their dialect; builtin
/// attributes are printed inline. After the value, ` : <type>` is appended
/// unless `typeElision` is `Must` or the attribute's type is NoneType. Some
/// kinds below return early so their type is never printed, and some return
/// early only when `typeElision` is `May` and the type is the default one.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  auto attrType = attr.getType();
  if (!isa<BuiltinDialect>(attr.getDialect())) {
    // Non-builtin attributes are rendered by their owning dialect; the
    // trailing type below still applies to them.
    printDialectAttribute(attr);
  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elides the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values. Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE(review): the isSignlessInteger(1) term is redundant here; that case
    // already returned above.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    printEscapedString(strAttr.getValue());

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Array members may elide their own default types.
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Nested symbol references print as `@root::@nested::...`.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", ";
      printHexString(opaqueAttr.getValue());
      os << ">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide when either the indices or the values would be elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      // `sparse<indices, values>`, or `sparse<>` when there are no indices.
      // Indices are never printed as hex.
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }
  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
    // Dense arrays print as `[:<elt> ...]`. The element type is spelled in
    // the prefix, so the trailing type is always elided.
    typeElision = AttrTypeElision::Must;
    switch (denseArrayAttr.getElementType()) {
    case DenseArrayBaseAttr::EltType::I8:
      os << "[:i8";
      break;
    case DenseArrayBaseAttr::EltType::I16:
      os << "[:i16";
      break;
    case DenseArrayBaseAttr::EltType::I32:
      os << "[:i32";
      break;
    case DenseArrayBaseAttr::EltType::I64:
      os << "[:i64";
      break;
    case DenseArrayBaseAttr::EltType::F32:
      os << "[:f32";
      break;
    case DenseArrayBaseAttr::EltType::F64:
      os << "[:f64";
      break;
    }
    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
      os << " ";
    denseArrayAttr.printWithoutBraces(os);
    os << "]";
  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  } else {
    llvm::report_fatal_error("Unknown builtin attribute");
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}
1910
1911 /// Print the integer element of a DenseElementsAttr.
printDenseIntElement(const APInt & value,raw_ostream & os,bool isSigned)1912 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1913 bool isSigned) {
1914 if (value.getBitWidth() == 1)
1915 os << (value.getBoolValue() ? "true" : "false");
1916 else
1917 value.print(os, isSigned);
1918 }
1919
1920 static void
printDenseElementsAttrImpl(bool isSplat,ShapedType type,raw_ostream & os,function_ref<void (unsigned)> printEltFn)1921 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1922 function_ref<void(unsigned)> printEltFn) {
1923 // Special case for 0-d and splat tensors.
1924 if (isSplat)
1925 return printEltFn(0);
1926
1927 // Special case for degenerate tensors.
1928 auto numElements = type.getNumElements();
1929 if (numElements == 0)
1930 return;
1931
1932 // We use a mixed-radix counter to iterate through the shape. When we bump a
1933 // non-least-significant digit, we emit a close bracket. When we next emit an
1934 // element we re-open all closed brackets.
1935
1936 // The mixed-radix counter, with radices in 'shape'.
1937 int64_t rank = type.getRank();
1938 SmallVector<unsigned, 4> counter(rank, 0);
1939 // The number of brackets that have been opened and not closed.
1940 unsigned openBrackets = 0;
1941
1942 auto shape = type.getShape();
1943 auto bumpCounter = [&] {
1944 // Bump the least significant digit.
1945 ++counter[rank - 1];
1946 // Iterate backwards bubbling back the increment.
1947 for (unsigned i = rank - 1; i > 0; --i)
1948 if (counter[i] >= shape[i]) {
1949 // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1950 counter[i] = 0;
1951 ++counter[i - 1];
1952 --openBrackets;
1953 os << ']';
1954 }
1955 };
1956
1957 for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1958 if (idx != 0)
1959 os << ", ";
1960 while (openBrackets++ < rank)
1961 os << '[';
1962 openBrackets = rank;
1963 printEltFn(idx);
1964 bumpCounter();
1965 }
1966 while (openBrackets-- > 0)
1967 os << ']';
1968 }
1969
printDenseElementsAttr(DenseElementsAttr attr,bool allowHex)1970 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
1971 bool allowHex) {
1972 if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1973 return printDenseStringElementsAttr(stringAttr);
1974
1975 printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1976 allowHex);
1977 }
1978
/// Print the element list of an integer, index, floating-point, or complex
/// dense elements attribute. The surrounding `dense<...>` is emitted by the
/// caller.
///
/// When `allowHex` is set, the attribute is not a splat, and
/// shouldPrintElementsAttrWithHex(numElements) agrees, the raw storage buffer
/// is printed as a quoted hex string, always in little-endian byte order.
/// Otherwise each element is printed individually: complex values as
/// `(real,imag)`, integers honoring the element type's signedness, and floats
/// via printFloatValue.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines; it is converted here so the hex string is printed in LE
      // format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(convRawData);
    } else {
      printHexString(rawData);
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    if (complexElementType.isa<IntegerType>()) {
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Everything except explicitly unsigned integers prints as signed.
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
2046
printDenseStringElementsAttr(DenseStringElementsAttr attr)2047 void AsmPrinter::Impl::printDenseStringElementsAttr(
2048 DenseStringElementsAttr attr) {
2049 ArrayRef<StringRef> data = attr.getRawStringData();
2050 auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
2051 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2052 }
2053
/// Print `type` in its textual assembly form: a placeholder for a null type,
/// an alias when the AsmState provides one, an inline spelling for each
/// builtin type handled below, and the owning dialect's printer otherwise.
void AsmPrinter::Impl::printType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Try to print an alias for this type.
  if (state && succeeded(state->getAliasState().getAlias(type, os)))
    return;

  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // Optional signedness prefix ('s' signed, 'u' unsigned, none for
        // signless), then 'i' and the bit width, e.g. si32, ui8, i1.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        // A single non-function result prints bare. Zero or several results,
        // or a single function-typed result, are parenthesized.
        ArrayRef<Type> results = funcTy.getResults();
        if (results.size() == 1 && !results[0].isa<FunctionType>()) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        // Fixed dimensions print as `NxMx...`. The trailing scalable
        // dimensions are grouped in square brackets, e.g. `vector<2x[4]xf32>`.
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
        unsigned dimIdx = 0;
        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
          os << vShape[dimIdx] << 'x';
        if (vectorTy.isScalable()) {
          os << '[';
          unsigned secondToLastDim = lastDim - 1;
          for (; dimIdx < secondToLastDim; dimIdx++)
            os << vShape[dimIdx] << 'x';
          os << vShape[dimIdx] << "]x";
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        // Dynamic dimensions print as '?'.
        os << "tensor<";
        for (int64_t dim : tensorTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        // Dynamic dimensions print as '?'.
        os << "memref<";
        for (int64_t dim : memrefTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(memrefTy.getElementType());
        // Only print the layout if it is not the identity layout.
        if (!memrefTy.getLayout().isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Anything else is printed by the dialect that owns it.
      .Default([&](Type type) { return printDialectType(type); });
}
2181
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)2182 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2183 ArrayRef<StringRef> elidedAttrs,
2184 bool withKeyword) {
2185 // If there are no attributes, then there is nothing to be done.
2186 if (attrs.empty())
2187 return;
2188
2189 // Functor used to print a filtered attribute list.
2190 auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2191 // Print the 'attributes' keyword if necessary.
2192 if (withKeyword)
2193 os << " attributes";
2194
2195 // Otherwise, print them all out in braces.
2196 os << " {";
2197 interleaveComma(filteredAttrs,
2198 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2199 os << '}';
2200 };
2201
2202 // If no attributes are elided, we can directly print with no filtering.
2203 if (elidedAttrs.empty())
2204 return printFilteredAttributesFn(attrs);
2205
2206 // Otherwise, filter out any attributes that shouldn't be included.
2207 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2208 elidedAttrs.end());
2209 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2210 return !elidedAttrsSet.contains(attr.getName().strref());
2211 });
2212 if (!filteredAttrs.empty())
2213 printFilteredAttributesFn(filteredAttrs);
2214 }
2215
printNamedAttribute(NamedAttribute attr)2216 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2217 // Print the name without quotes if possible.
2218 ::printKeywordOrString(attr.getName().strref(), os);
2219
2220 // Pretty printing elides the attribute value for unit attributes.
2221 if (attr.getValue().isa<UnitAttr>())
2222 return;
2223
2224 os << " = ";
2225 printAttribute(attr.getValue());
2226 }
2227
printDialectAttribute(Attribute attr)2228 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2229 auto &dialect = attr.getDialect();
2230
2231 // Ask the dialect to serialize the attribute to a string.
2232 std::string attrName;
2233 {
2234 llvm::raw_string_ostream attrNameStr(attrName);
2235 Impl subPrinter(attrNameStr, printerFlags, state);
2236 DialectAsmPrinter printer(subPrinter);
2237 dialect.printAttribute(attr, printer);
2238
2239 // FIXME: Delete this when we no longer require a nested printer.
2240 for (auto &it : subPrinter.dialectResources)
2241 for (const auto &resource : it.second)
2242 dialectResources[it.first].insert(resource);
2243 }
2244 printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2245 }
2246
printDialectType(Type type)2247 void AsmPrinter::Impl::printDialectType(Type type) {
2248 auto &dialect = type.getDialect();
2249
2250 // Ask the dialect to serialize the type to a string.
2251 std::string typeName;
2252 {
2253 llvm::raw_string_ostream typeNameStr(typeName);
2254 Impl subPrinter(typeNameStr, printerFlags, state);
2255 DialectAsmPrinter printer(subPrinter);
2256 dialect.printType(type, printer);
2257
2258 // FIXME: Delete this when we no longer require a nested printer.
2259 for (auto &it : subPrinter.dialectResources)
2260 for (const auto &resource : it.second)
2261 dialectResources[it.first].insert(resource);
2262 }
2263 printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2264 }
2265
printEscapedString(StringRef str)2266 void AsmPrinter::Impl::printEscapedString(StringRef str) {
2267 os << "\"";
2268 llvm::printEscapedString(str, os);
2269 os << "\"";
2270 }
2271
printHexString(StringRef str)2272 void AsmPrinter::Impl::printHexString(StringRef str) {
2273 os << "\"0x" << llvm::toHex(str) << "\"";
2274 }
printHexString(ArrayRef<char> data)2275 void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2276 printHexString(StringRef(data.data(), data.size()));
2277 }
2278
2279 //===--------------------------------------------------------------------===//
2280 // AsmPrinter
2281 //===--------------------------------------------------------------------===//
2282
// Destructor, defaulted out of line. NOTE(review): presumably defined here so
// that it is emitted in this translation unit; confirm against the header.
AsmPrinter::~AsmPrinter() = default;
2284
getStream() const2285 raw_ostream &AsmPrinter::getStream() const {
2286 assert(impl && "expected AsmPrinter::getStream to be overriden");
2287 return impl->getStream();
2288 }
2289
2290 /// Print the given floating point value in a stablized form.
printFloat(const APFloat & value)2291 void AsmPrinter::printFloat(const APFloat &value) {
2292 assert(impl && "expected AsmPrinter::printFloat to be overriden");
2293 printFloatValue(value, impl->getStream());
2294 }
2295
printType(Type type)2296 void AsmPrinter::printType(Type type) {
2297 assert(impl && "expected AsmPrinter::printType to be overriden");
2298 impl->printType(type);
2299 }
2300
printAttribute(Attribute attr)2301 void AsmPrinter::printAttribute(Attribute attr) {
2302 assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2303 impl->printAttribute(attr);
2304 }
2305
printAlias(Attribute attr)2306 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2307 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2308 return impl->printAlias(attr);
2309 }
2310
printAlias(Type type)2311 LogicalResult AsmPrinter::printAlias(Type type) {
2312 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2313 return impl->printAlias(type);
2314 }
2315
printAttributeWithoutType(Attribute attr)2316 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2317 assert(impl &&
2318 "expected AsmPrinter::printAttributeWithoutType to be overriden");
2319 impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2320 }
2321
printKeywordOrString(StringRef keyword)2322 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2323 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2324 ::printKeywordOrString(keyword, impl->getStream());
2325 }
2326
printSymbolName(StringRef symbolRef)2327 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2328 assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2329 ::printSymbolReference(symbolRef, impl->getStream());
2330 }
2331
printResourceHandle(const AsmDialectResourceHandle & resource)2332 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2333 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2334 impl->printResourceHandle(resource);
2335 }
2336
2337 //===----------------------------------------------------------------------===//
2338 // Affine expressions and maps
2339 //===----------------------------------------------------------------------===//
2340
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)2341 void AsmPrinter::Impl::printAffineExpr(
2342 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2343 printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2344 }
2345
printAffineExprInternal(AffineExpr expr,BindingStrength enclosingTightness,function_ref<void (unsigned,bool)> printValueName)2346 void AsmPrinter::Impl::printAffineExprInternal(
2347 AffineExpr expr, BindingStrength enclosingTightness,
2348 function_ref<void(unsigned, bool)> printValueName) {
2349 const char *binopSpelling = nullptr;
2350 switch (expr.getKind()) {
2351 case AffineExprKind::SymbolId: {
2352 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
2353 if (printValueName)
2354 printValueName(pos, /*isSymbol=*/true);
2355 else
2356 os << 's' << pos;
2357 return;
2358 }
2359 case AffineExprKind::DimId: {
2360 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2361 if (printValueName)
2362 printValueName(pos, /*isSymbol=*/false);
2363 else
2364 os << 'd' << pos;
2365 return;
2366 }
2367 case AffineExprKind::Constant:
2368 os << expr.cast<AffineConstantExpr>().getValue();
2369 return;
2370 case AffineExprKind::Add:
2371 binopSpelling = " + ";
2372 break;
2373 case AffineExprKind::Mul:
2374 binopSpelling = " * ";
2375 break;
2376 case AffineExprKind::FloorDiv:
2377 binopSpelling = " floordiv ";
2378 break;
2379 case AffineExprKind::CeilDiv:
2380 binopSpelling = " ceildiv ";
2381 break;
2382 case AffineExprKind::Mod:
2383 binopSpelling = " mod ";
2384 break;
2385 }
2386
2387 auto binOp = expr.cast<AffineBinaryOpExpr>();
2388 AffineExpr lhsExpr = binOp.getLHS();
2389 AffineExpr rhsExpr = binOp.getRHS();
2390
2391 // Handle tightly binding binary operators.
2392 if (binOp.getKind() != AffineExprKind::Add) {
2393 if (enclosingTightness == BindingStrength::Strong)
2394 os << '(';
2395
2396 // Pretty print multiplication with -1.
2397 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2398 if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2399 rhsConst.getValue() == -1) {
2400 os << "-";
2401 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2402 if (enclosingTightness == BindingStrength::Strong)
2403 os << ')';
2404 return;
2405 }
2406
2407 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2408
2409 os << binopSpelling;
2410 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2411
2412 if (enclosingTightness == BindingStrength::Strong)
2413 os << ')';
2414 return;
2415 }
2416
2417 // Print out special "pretty" forms for add.
2418 if (enclosingTightness == BindingStrength::Strong)
2419 os << '(';
2420
2421 // Pretty print addition to a product that has a negative operand as a
2422 // subtraction.
2423 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2424 if (rhs.getKind() == AffineExprKind::Mul) {
2425 AffineExpr rrhsExpr = rhs.getRHS();
2426 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2427 if (rrhs.getValue() == -1) {
2428 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2429 printValueName);
2430 os << " - ";
2431 if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2432 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2433 printValueName);
2434 } else {
2435 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2436 printValueName);
2437 }
2438
2439 if (enclosingTightness == BindingStrength::Strong)
2440 os << ')';
2441 return;
2442 }
2443
2444 if (rrhs.getValue() < -1) {
2445 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2446 printValueName);
2447 os << " - ";
2448 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2449 printValueName);
2450 os << " * " << -rrhs.getValue();
2451 if (enclosingTightness == BindingStrength::Strong)
2452 os << ')';
2453 return;
2454 }
2455 }
2456 }
2457 }
2458
2459 // Pretty print addition to a negative number as a subtraction.
2460 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2461 if (rhsConst.getValue() < 0) {
2462 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2463 os << " - " << -rhsConst.getValue();
2464 if (enclosingTightness == BindingStrength::Strong)
2465 os << ')';
2466 return;
2467 }
2468 }
2469
2470 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2471
2472 os << " + ";
2473 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2474
2475 if (enclosingTightness == BindingStrength::Strong)
2476 os << ')';
2477 }
2478
printAffineConstraint(AffineExpr expr,bool isEq)2479 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2480 printAffineExprInternal(expr, BindingStrength::Weak);
2481 isEq ? os << " == 0" : os << " >= 0";
2482 }
2483
printAffineMap(AffineMap map)2484 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2485 // Dimension identifiers.
2486 os << '(';
2487 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2488 os << 'd' << i << ", ";
2489 if (map.getNumDims() >= 1)
2490 os << 'd' << map.getNumDims() - 1;
2491 os << ')';
2492
2493 // Symbolic identifiers.
2494 if (map.getNumSymbols() != 0) {
2495 os << '[';
2496 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2497 os << 's' << i << ", ";
2498 if (map.getNumSymbols() >= 1)
2499 os << 's' << map.getNumSymbols() - 1;
2500 os << ']';
2501 }
2502
2503 // Result affine expressions.
2504 os << " -> (";
2505 interleaveComma(map.getResults(),
2506 [&](AffineExpr expr) { printAffineExpr(expr); });
2507 os << ')';
2508 }
2509
printIntegerSet(IntegerSet set)2510 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2511 // Dimension identifiers.
2512 os << '(';
2513 for (unsigned i = 1; i < set.getNumDims(); ++i)
2514 os << 'd' << i - 1 << ", ";
2515 if (set.getNumDims() >= 1)
2516 os << 'd' << set.getNumDims() - 1;
2517 os << ')';
2518
2519 // Symbolic identifiers.
2520 if (set.getNumSymbols() != 0) {
2521 os << '[';
2522 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2523 os << 's' << i << ", ";
2524 if (set.getNumSymbols() >= 1)
2525 os << 's' << set.getNumSymbols() - 1;
2526 os << ']';
2527 }
2528
2529 // Print constraints.
2530 os << " : (";
2531 int numConstraints = set.getNumConstraints();
2532 for (int i = 1; i < numConstraints; ++i) {
2533 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2534 os << ", ";
2535 }
2536 if (numConstraints >= 1)
2537 printAffineConstraint(set.getConstraint(numConstraints - 1),
2538 set.isEq(numConstraints - 1));
2539 os << ')';
2540 }
2541
2542 //===----------------------------------------------------------------------===//
2543 // OperationPrinter
2544 //===----------------------------------------------------------------------===//
2545
2546 namespace {
2547 /// This class contains the logic for printing operations, regions, and blocks.
2548 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
2549 public:
2550 using Impl = AsmPrinter::Impl;
2551 using Impl::printType;
2552
OperationPrinter(raw_ostream & os,AsmStateImpl & state)2553 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
2554 : Impl(os, state.getPrinterFlags(), &state),
2555 OpAsmPrinter(static_cast<Impl &>(*this)) {}
2556
2557 /// Print the given top-level operation.
2558 void printTopLevelOperation(Operation *op);
2559
2560 /// Print the given operation with its indent and location.
2561 void print(Operation *op);
2562 /// Print the bare location, not including indentation/location/etc.
2563 void printOperation(Operation *op);
2564 /// Print the given operation in the generic form.
2565 void printGenericOp(Operation *op, bool printOpName) override;
2566
2567 /// Print the name of the given block.
2568 void printBlockName(Block *block);
2569
2570 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2571 /// block are not printed. If 'printBlockTerminator' is false, the terminator
2572 /// operation of the block is not printed.
2573 void print(Block *block, bool printBlockArgs = true,
2574 bool printBlockTerminator = true);
2575
2576 /// Print the ID of the given value, optionally with its result number.
2577 void printValueID(Value value, bool printResultNo = true,
2578 raw_ostream *streamOverride = nullptr) const;
2579
2580 /// Print the ID of the given operation.
2581 void printOperationID(Operation *op,
2582 raw_ostream *streamOverride = nullptr) const;
2583
2584 //===--------------------------------------------------------------------===//
2585 // OpAsmPrinter methods
2586 //===--------------------------------------------------------------------===//
2587
2588 /// Print a newline and indent the printer to the start of the current
2589 /// operation.
printNewline()2590 void printNewline() override {
2591 os << newLine;
2592 os.indent(currentIndent);
2593 }
2594
2595 /// Print a block argument in the usual format of:
2596 /// %ssaName : type {attr1=42} loc("here")
2597 /// where location printing is controlled by the standard internal option.
2598 /// You may pass omitType=true to not print a type, and pass an empty
2599 /// attribute list if you don't care for attributes.
2600 void printRegionArgument(BlockArgument arg,
2601 ArrayRef<NamedAttribute> argAttrs = {},
2602 bool omitType = false) override;
2603
2604 /// Print the ID for the given value.
printOperand(Value value)2605 void printOperand(Value value) override { printValueID(value); }
printOperand(Value value,raw_ostream & os)2606 void printOperand(Value value, raw_ostream &os) override {
2607 printValueID(value, /*printResultNo=*/true, &os);
2608 }
2609
2610 /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2611 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2612 ArrayRef<StringRef> elidedAttrs = {}) override {
2613 Impl::printOptionalAttrDict(attrs, elidedAttrs);
2614 }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2615 void printOptionalAttrDictWithKeyword(
2616 ArrayRef<NamedAttribute> attrs,
2617 ArrayRef<StringRef> elidedAttrs = {}) override {
2618 Impl::printOptionalAttrDict(attrs, elidedAttrs,
2619 /*withKeyword=*/true);
2620 }
2621
2622 /// Print the given successor.
2623 void printSuccessor(Block *successor) override;
2624
2625 /// Print an operation successor with the operands used for the block
2626 /// arguments.
2627 void printSuccessorAndUseList(Block *successor,
2628 ValueRange succOperands) override;
2629
2630 /// Print the given region.
2631 void printRegion(Region ®ion, bool printEntryBlockArgs,
2632 bool printBlockTerminators, bool printEmptyBlock) override;
2633
2634 /// Renumber the arguments for the specified region to the same names as the
2635 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2636 /// operations. If any entry in namesToUse is null, the corresponding
2637 /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)2638 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override {
2639 state->getSSANameState().shadowRegionArgs(region, namesToUse);
2640 }
2641
2642 /// Print the given affine map with the symbol and dimension operands printed
2643 /// inline with the map.
2644 void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2645 ValueRange operands) override;
2646
2647 /// Print the given affine expression with the symbol and dimension operands
2648 /// printed inline with the expression.
2649 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
2650 ValueRange symOperands) override;
2651
2652 /// Print users of this operation or id of this operation if it has no result.
2653 void printUsersComment(Operation *op);
2654
2655 /// Print users of this block arg.
2656 void printUsersComment(BlockArgument arg);
2657
2658 /// Print the users of a value.
2659 void printValueUsers(Value value);
2660
2661 /// Print either the ids of the result values or the id of the operation if
2662 /// the operation has no results.
2663 void printUserIDs(Operation *user, bool prefixComma = false);
2664
2665 private:
2666 /// This class represents a resource builder implementation for the MLIR
2667 /// textual assembly format.
2668 class ResourceBuilder : public AsmResourceBuilder {
2669 public:
2670 using ValueFn = function_ref<void(raw_ostream &)>;
2671 using PrintFn = function_ref<void(StringRef, ValueFn)>;
2672
ResourceBuilder(OperationPrinter & p,PrintFn printFn)2673 ResourceBuilder(OperationPrinter &p, PrintFn printFn)
2674 : p(p), printFn(printFn) {}
2675 ~ResourceBuilder() override = default;
2676
buildBool(StringRef key,bool data)2677 void buildBool(StringRef key, bool data) final {
2678 printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); });
2679 }
2680
buildString(StringRef key,StringRef data)2681 void buildString(StringRef key, StringRef data) final {
2682 printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); });
2683 }
2684
buildBlob(StringRef key,ArrayRef<char> data,uint32_t dataAlignment)2685 void buildBlob(StringRef key, ArrayRef<char> data,
2686 uint32_t dataAlignment) final {
2687 printFn(key, [&](raw_ostream &os) {
2688 // Store the blob in a hex string containing the alignment and the data.
2689 llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
2690 os << "\"0x"
2691 << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
2692 sizeof(dataAlignment)))
2693 << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
2694 });
2695 }
2696
2697 private:
2698 OperationPrinter &p;
2699 PrintFn printFn;
2700 };
2701
2702 /// Print the metadata dictionary for the file, eliding it if it is empty.
2703 void printFileMetadataDictionary(Operation *op);
2704
2705 /// Print the resource sections for the file metadata dictionary.
2706 /// `checkAddMetadataDict` is used to indicate that metadata is going to be
2707 /// added, and the file metadata dictionary should be started if it hasn't
2708 /// yet.
2709 void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
2710 Operation *op);
2711
2712 // Contains the stack of default dialects to use when printing regions.
2713 // A new dialect is pushed to the stack before parsing regions nested under an
2714 // operation implementing `OpAsmOpInterface`, and popped when done. At the
2715 // top-level we start with "builtin" as the default, so that the top-level
2716 // `module` operation prints as-is.
2717 SmallVector<StringRef> defaultDialectStack{"builtin"};
2718
2719 /// The number of spaces used for indenting nested operations.
2720 const static unsigned indentWidth = 2;
2721
2722 // This is the current indentation level for nested structures.
2723 unsigned currentIndent = 0;
2724 };
2725 } // namespace
2726
printTopLevelOperation(Operation * op)2727 void OperationPrinter::printTopLevelOperation(Operation *op) {
2728 // Output the aliases at the top level that can't be deferred.
2729 state->getAliasState().printNonDeferredAliases(os, newLine);
2730
2731 // Print the module.
2732 print(op);
2733 os << newLine;
2734
2735 // Output the aliases at the top level that can be deferred.
2736 state->getAliasState().printDeferredAliases(os, newLine);
2737
2738 // Output any file level metadata.
2739 printFileMetadataDictionary(op);
2740 }
2741
printFileMetadataDictionary(Operation * op)2742 void OperationPrinter::printFileMetadataDictionary(Operation *op) {
2743 bool sawMetadataEntry = false;
2744 auto checkAddMetadataDict = [&] {
2745 if (!std::exchange(sawMetadataEntry, true))
2746 os << newLine << "{-#" << newLine;
2747 };
2748
2749 // Add the various types of metadata.
2750 printResourceFileMetadata(checkAddMetadataDict, op);
2751
2752 // If the file dictionary exists, close it.
2753 if (sawMetadataEntry)
2754 os << newLine << "#-}" << newLine;
2755 }
2756
/// Emits the `dialect_resources` section followed by the
/// `external_resources` section of the file metadata dictionary. A section
/// header, and each provider's entry within it, is only printed once that
/// provider emits its first resource. `checkAddMetadataDict` is invoked before
/// every emitted resource so the caller can lazily open the enclosing
/// dictionary.
void OperationPrinter::printResourceFileMetadata(
    function_ref<void()> checkAddMetadataDict, Operation *op) {
  // Functor used to add data entries to the file metadata dictionary.
  // `hadResource` records whether the current `<dictName>_resources` section
  // has been opened; it is reset between the dialect and external sections.
  bool hadResource = false;
  // Runs a single resource provider, forwarding any extra `providerArgs` to
  // its `buildResources` hook ahead of the builder.
  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
                             auto &&...providerArgs) {
    // Whether this provider's `<name>: {` entry has been opened yet.
    bool hadEntry = false;
    auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
      checkAddMetadataDict();

      // Emit the top-level resource entry if we haven't yet.
      if (!std::exchange(hadResource, true))
        os << " " << dictName << "_resources: {" << newLine;
      // Emit the parent resource entry if we haven't yet.
      if (!std::exchange(hadEntry, true))
        os << " " << name << ": {" << newLine;
      else
        os << "," << newLine;

      os << " " << key << ": ";
      valueFn(os);
    };
    ResourceBuilder entryBuilder(*this, printFn);
    provider.buildResources(op, providerArgs..., entryBuilder);

    // Close this provider's entry only if it emitted at least one resource.
    if (hadEntry)
      os << newLine << " }";
  };

  // Print the `dialect_resources` section if we have any dialects with
  // resources.
  for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
    StringRef name = interface.getDialect()->getNamespace();
    // Forward the handles stored in `dialectResources` for this dialect, or an
    // empty set when there is no entry for it.
    auto it = dialectResources.find(interface.getDialect());
    if (it != dialectResources.end())
      processProvider("dialect", name, interface, it->second);
    else
      processProvider("dialect", name, interface,
                      SetVector<AsmDialectResourceHandle>());
  }
  if (hadResource)
    os << newLine << " }";

  // Print the `external_resources` section if we have any external clients with
  // resources.
  hadResource = false;
  for (const auto &printer : state->getResourcePrinters())
    processProvider("external", printer.getName(), printer);
  if (hadResource)
    os << newLine << " }";
}
2808
2809 /// Print a block argument in the usual format of:
2810 /// %ssaName : type {attr1=42} loc("here")
2811 /// where location printing is controlled by the standard internal option.
2812 /// You may pass omitType=true to not print a type, and pass an empty
2813 /// attribute list if you don't care for attributes.
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)2814 void OperationPrinter::printRegionArgument(BlockArgument arg,
2815 ArrayRef<NamedAttribute> argAttrs,
2816 bool omitType) {
2817 printOperand(arg);
2818 if (!omitType) {
2819 os << ": ";
2820 printType(arg.getType());
2821 }
2822 printOptionalAttrDict(argAttrs);
2823 // TODO: We should allow location aliases on block arguments.
2824 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2825 }
2826
print(Operation * op)2827 void OperationPrinter::print(Operation *op) {
2828 // Track the location of this operation.
2829 state->registerOperationLocation(op, newLine.curLine, currentIndent);
2830
2831 os.indent(currentIndent);
2832 printOperation(op);
2833 printTrailingLocation(op->getLoc());
2834 if (printerFlags.shouldPrintValueUsers())
2835 printUsersComment(op);
2836 }
2837
printOperation(Operation * op)2838 void OperationPrinter::printOperation(Operation *op) {
2839 if (size_t numResults = op->getNumResults()) {
2840 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2841 printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2842 if (resultCount > 1)
2843 os << ':' << resultCount;
2844 };
2845
2846 // Check to see if this operation has multiple result groups.
2847 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2848 if (!resultGroups.empty()) {
2849 // Interleave the groups excluding the last one, this one will be handled
2850 // separately.
2851 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2852 printResultGroup(resultGroups[i],
2853 resultGroups[i + 1] - resultGroups[i]);
2854 });
2855 os << ", ";
2856 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2857
2858 } else {
2859 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2860 }
2861
2862 os << " = ";
2863 }
2864
2865 // If requested, always print the generic form.
2866 if (!printerFlags.shouldPrintGenericOpForm()) {
2867 // Check to see if this is a known operation. If so, use the registered
2868 // custom printer hook.
2869 if (auto opInfo = op->getRegisteredInfo()) {
2870 opInfo->printAssembly(op, *this, defaultDialectStack.back());
2871 return;
2872 }
2873 // Otherwise try to dispatch to the dialect, if available.
2874 if (Dialect *dialect = op->getDialect()) {
2875 if (auto opPrinter = dialect->getOperationPrinter(op)) {
2876 // Print the op name first.
2877 StringRef name = op->getName().getStringRef();
2878 // Only drop the default dialect prefix when it cannot lead to
2879 // ambiguities.
2880 if (name.count('.') == 1)
2881 name.consume_front((defaultDialectStack.back() + ".").str());
2882 os << name;
2883
2884 // Print the rest of the op now.
2885 opPrinter(op, *this);
2886 return;
2887 }
2888 }
2889 }
2890
2891 // Otherwise print with the generic assembly form.
2892 printGenericOp(op, /*printOpName=*/true);
2893 }
2894
printUsersComment(Operation * op)2895 void OperationPrinter::printUsersComment(Operation *op) {
2896 unsigned numResults = op->getNumResults();
2897 if (!numResults && op->getNumOperands()) {
2898 os << " // id: ";
2899 printOperationID(op);
2900 } else if (numResults && op->use_empty()) {
2901 os << " // unused";
2902 } else if (numResults && !op->use_empty()) {
2903 // Print "user" if the operation has one result used to compute one other
2904 // result, or is used in one operation with no result.
2905 unsigned usedInNResults = 0;
2906 unsigned usedInNOperations = 0;
2907 SmallPtrSet<Operation *, 1> userSet;
2908 for (Operation *user : op->getUsers()) {
2909 if (userSet.insert(user).second) {
2910 ++usedInNOperations;
2911 usedInNResults += user->getNumResults();
2912 }
2913 }
2914
2915 // We already know that users is not empty.
2916 bool exactlyOneUniqueUse =
2917 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
2918 os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
2919 bool shouldPrintBrackets = numResults > 1;
2920 auto printOpResult = [&](OpResult opResult) {
2921 if (shouldPrintBrackets)
2922 os << "(";
2923 printValueUsers(opResult);
2924 if (shouldPrintBrackets)
2925 os << ")";
2926 };
2927
2928 interleaveComma(op->getResults(), printOpResult);
2929 }
2930 }
2931
printUsersComment(BlockArgument arg)2932 void OperationPrinter::printUsersComment(BlockArgument arg) {
2933 os << "// ";
2934 printValueID(arg);
2935 if (arg.use_empty()) {
2936 os << " is unused";
2937 } else {
2938 os << " is used by ";
2939 printValueUsers(arg);
2940 }
2941 os << newLine;
2942 }
2943
printValueUsers(Value value)2944 void OperationPrinter::printValueUsers(Value value) {
2945 if (value.use_empty())
2946 os << "unused";
2947
2948 // One value might be used as the operand of an operation more than once.
2949 // Only print the operations results once in that case.
2950 SmallPtrSet<Operation *, 1> userSet;
2951 for (auto &indexedUser : enumerate(value.getUsers())) {
2952 if (userSet.insert(indexedUser.value()).second)
2953 printUserIDs(indexedUser.value(), indexedUser.index());
2954 }
2955 }
2956
printUserIDs(Operation * user,bool prefixComma)2957 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
2958 if (prefixComma)
2959 os << ", ";
2960
2961 if (!user->getNumResults()) {
2962 printOperationID(user);
2963 } else {
2964 interleaveComma(user->getResults(),
2965 [this](Value result) { printValueID(result); });
2966 }
2967 }
2968
printGenericOp(Operation * op,bool printOpName)2969 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
2970 if (printOpName)
2971 printEscapedString(op->getName().getStringRef());
2972 os << '(';
2973 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2974 os << ')';
2975
2976 // For terminators, print the list of successors and their operands.
2977 if (op->getNumSuccessors() != 0) {
2978 os << '[';
2979 interleaveComma(op->getSuccessors(),
2980 [&](Block *successor) { printBlockName(successor); });
2981 os << ']';
2982 }
2983
2984 // Print regions.
2985 if (op->getNumRegions() != 0) {
2986 os << " (";
2987 interleaveComma(op->getRegions(), [&](Region ®ion) {
2988 printRegion(region, /*printEntryBlockArgs=*/true,
2989 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2990 });
2991 os << ')';
2992 }
2993
2994 auto attrs = op->getAttrs();
2995 printOptionalAttrDict(attrs);
2996
2997 // Print the type signature of the operation.
2998 os << " : ";
2999 printFunctionalType(op);
3000 }
3001
printBlockName(Block * block)3002 void OperationPrinter::printBlockName(Block *block) {
3003 os << state->getSSANameState().getBlockInfo(block).name;
3004 }
3005
/// Prints `block` with its operations indented one level (`indentWidth`)
/// deeper than the current indentation. When `printBlockArgs` is set, a header
/// line is printed first: the block name, its argument list (if non-empty)
/// with types and locations, a `:`, and a comment about its predecessors.
/// When `printBlockTerminator` is false and the last operation has the
/// `OpTrait::IsTerminator` trait, that operation is not printed.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument arg) {
        printValueID(arg);
        os << ": ";
        printType(arg.getType());
        // TODO: We should allow location aliases on block arguments.
        printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    if (!block->getParent()) {
      os << " // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      // Only non-entry blocks get the "no predecessors" note.
      if (!block->isEntryBlock())
        os << " // no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << " // pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in a stable order, not in
      // whatever order the use-list is in, so gather and sort them.
      SmallVector<BlockInfo, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
      llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
        return lhs.ordering < rhs.ordering;
      });

      os << " // " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
    }
    os << newLine;
  }

  // Everything below the header is nested one indentation level deeper.
  currentIndent += indentWidth;

  // When value users are requested, emit one comment line per block argument
  // before the block's operations.
  if (printerFlags.shouldPrintValueUsers()) {
    for (BlockArgument arg : block->getArguments()) {
      os.indent(currentIndent);
      printUsersComment(arg);
    }
  }

  // Trim the last operation from the printed range only when it is a
  // terminator and the caller asked to omit terminators.
  bool hasTerminator =
      !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
  auto range = llvm::make_range(
      block->begin(),
      std::prev(block->end(),
                (!hasTerminator || printBlockTerminator) ? 0 : 1));
  for (auto &op : range) {
    print(&op);
    os << newLine;
  }
  currentIndent -= indentWidth;
}
3074
printValueID(Value value,bool printResultNo,raw_ostream * streamOverride) const3075 void OperationPrinter::printValueID(Value value, bool printResultNo,
3076 raw_ostream *streamOverride) const {
3077 state->getSSANameState().printValueID(value, printResultNo,
3078 streamOverride ? *streamOverride : os);
3079 }
3080
printOperationID(Operation * op,raw_ostream * streamOverride) const3081 void OperationPrinter::printOperationID(Operation *op,
3082 raw_ostream *streamOverride) const {
3083 state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride
3084 : os);
3085 }
3086
printSuccessor(Block * successor)3087 void OperationPrinter::printSuccessor(Block *successor) {
3088 printBlockName(successor);
3089 }
3090
printSuccessorAndUseList(Block * successor,ValueRange succOperands)3091 void OperationPrinter::printSuccessorAndUseList(Block *successor,
3092 ValueRange succOperands) {
3093 printBlockName(successor);
3094 if (succOperands.empty())
3095 return;
3096
3097 os << '(';
3098 interleaveComma(succOperands,
3099 [this](Value operand) { printValueID(operand); });
3100 os << " : ";
3101 interleaveComma(succOperands,
3102 [this](Value operand) { printType(operand.getType()); });
3103 os << ')';
3104 }
3105
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock)3106 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
3107 bool printBlockTerminators,
3108 bool printEmptyBlock) {
3109 os << "{" << newLine;
3110 if (!region.empty()) {
3111 auto restoreDefaultDialect =
3112 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
3113 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3114 defaultDialectStack.push_back(iface.getDefaultDialect());
3115 else
3116 defaultDialectStack.push_back("");
3117
3118 auto *entryBlock = ®ion.front();
3119 // Force printing the block header if printEmptyBlock is set and the block
3120 // is empty or if printEntryBlockArgs is set and there are arguments to
3121 // print.
3122 bool shouldAlwaysPrintBlockHeader =
3123 (printEmptyBlock && entryBlock->empty()) ||
3124 (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3125 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
3126 for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
3127 print(&b);
3128 }
3129 os.indent(currentIndent) << "}";
3130 }
3131
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)3132 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3133 ValueRange operands) {
3134 AffineMap map = mapAttr.getValue();
3135 unsigned numDims = map.getNumDims();
3136 auto printValueName = [&](unsigned pos, bool isSymbol) {
3137 unsigned index = isSymbol ? numDims + pos : pos;
3138 assert(index < operands.size());
3139 if (isSymbol)
3140 os << "symbol(";
3141 printValueID(operands[index]);
3142 if (isSymbol)
3143 os << ')';
3144 };
3145
3146 interleaveComma(map.getResults(), [&](AffineExpr expr) {
3147 printAffineExpr(expr, printValueName);
3148 });
3149 }
3150
printAffineExprOfSSAIds(AffineExpr expr,ValueRange dimOperands,ValueRange symOperands)3151 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3152 ValueRange dimOperands,
3153 ValueRange symOperands) {
3154 auto printValueName = [&](unsigned pos, bool isSymbol) {
3155 if (!isSymbol)
3156 return printValueID(dimOperands[pos]);
3157 os << "symbol(";
3158 printValueID(symOperands[pos]);
3159 os << ')';
3160 };
3161 printAffineExpr(expr, printValueName);
3162 }
3163
3164 //===----------------------------------------------------------------------===//
3165 // print and dump methods
3166 //===----------------------------------------------------------------------===//
3167
print(raw_ostream & os) const3168 void Attribute::print(raw_ostream &os) const {
3169 AsmPrinter::Impl(os).printAttribute(*this);
3170 }
3171
dump() const3172 void Attribute::dump() const {
3173 print(llvm::errs());
3174 llvm::errs() << "\n";
3175 }
3176
print(raw_ostream & os) const3177 void Type::print(raw_ostream &os) const {
3178 AsmPrinter::Impl(os).printType(*this);
3179 }
3180
dump() const3181 void Type::dump() const { print(llvm::errs()); }
3182
dump() const3183 void AffineMap::dump() const {
3184 print(llvm::errs());
3185 llvm::errs() << "\n";
3186 }
3187
dump() const3188 void IntegerSet::dump() const {
3189 print(llvm::errs());
3190 llvm::errs() << "\n";
3191 }
3192
print(raw_ostream & os) const3193 void AffineExpr::print(raw_ostream &os) const {
3194 if (!expr) {
3195 os << "<<NULL AFFINE EXPR>>";
3196 return;
3197 }
3198 AsmPrinter::Impl(os).printAffineExpr(*this);
3199 }
3200
dump() const3201 void AffineExpr::dump() const {
3202 print(llvm::errs());
3203 llvm::errs() << "\n";
3204 }
3205
print(raw_ostream & os) const3206 void AffineMap::print(raw_ostream &os) const {
3207 if (!map) {
3208 os << "<<NULL AFFINE MAP>>";
3209 return;
3210 }
3211 AsmPrinter::Impl(os).printAffineMap(*this);
3212 }
3213
print(raw_ostream & os) const3214 void IntegerSet::print(raw_ostream &os) const {
3215 AsmPrinter::Impl(os).printIntegerSet(*this);
3216 }
3217
print(raw_ostream & os)3218 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
print(raw_ostream & os,const OpPrintingFlags & flags)3219 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
3220 if (!impl) {
3221 os << "<<NULL VALUE>>";
3222 return;
3223 }
3224
3225 if (auto *op = getDefiningOp())
3226 return op->print(os, flags);
3227 // TODO: Improve BlockArgument print'ing.
3228 BlockArgument arg = this->cast<BlockArgument>();
3229 os << "<block argument> of type '" << arg.getType()
3230 << "' at index: " << arg.getArgNumber();
3231 }
print(raw_ostream & os,AsmState & state)3232 void Value::print(raw_ostream &os, AsmState &state) {
3233 if (!impl) {
3234 os << "<<NULL VALUE>>";
3235 return;
3236 }
3237
3238 if (auto *op = getDefiningOp())
3239 return op->print(os, state);
3240
3241 // TODO: Improve BlockArgument print'ing.
3242 BlockArgument arg = this->cast<BlockArgument>();
3243 os << "<block argument> of type '" << arg.getType()
3244 << "' at index: " << arg.getArgNumber();
3245 }
3246
dump()3247 void Value::dump() {
3248 print(llvm::errs());
3249 llvm::errs() << "\n";
3250 }
3251
printAsOperand(raw_ostream & os,AsmState & state)3252 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
3253 // TODO: This doesn't necessarily capture all potential cases.
3254 // Currently, region arguments can be shadowed when printing the main
3255 // operation. If the IR hasn't been printed, this will produce the old SSA
3256 // name and not the shadowed name.
3257 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
3258 os);
3259 }
3260
print(raw_ostream & os,const OpPrintingFlags & printerFlags)3261 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
3262 // If this is a top level operation, we also print aliases.
3263 if (!getParent() && !printerFlags.shouldUseLocalScope()) {
3264 AsmState state(this, printerFlags);
3265 state.getImpl().initializeAliases(this);
3266 print(os, state);
3267 return;
3268 }
3269
3270 // Find the operation to number from based upon the provided flags.
3271 Operation *op = this;
3272 bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
3273 do {
3274 // If we are printing local scope, stop at the first operation that is
3275 // isolated from above.
3276 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3277 break;
3278
3279 // Otherwise, traverse up to the next parent.
3280 Operation *parentOp = op->getParentOp();
3281 if (!parentOp)
3282 break;
3283 op = parentOp;
3284 } while (true);
3285
3286 AsmState state(op, printerFlags);
3287 print(os, state);
3288 }
print(raw_ostream & os,AsmState & state)3289 void Operation::print(raw_ostream &os, AsmState &state) {
3290 OperationPrinter printer(os, state.getImpl());
3291 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
3292 printer.printTopLevelOperation(this);
3293 else
3294 printer.print(this);
3295 }
3296
dump()3297 void Operation::dump() {
3298 print(llvm::errs(), OpPrintingFlags().useLocalScope());
3299 llvm::errs() << "\n";
3300 }
3301
print(raw_ostream & os)3302 void Block::print(raw_ostream &os) {
3303 Operation *parentOp = getParentOp();
3304 if (!parentOp) {
3305 os << "<<UNLINKED BLOCK>>\n";
3306 return;
3307 }
3308 // Get the top-level op.
3309 while (auto *nextOp = parentOp->getParentOp())
3310 parentOp = nextOp;
3311
3312 AsmState state(parentOp);
3313 print(os, state);
3314 }
print(raw_ostream & os,AsmState & state)3315 void Block::print(raw_ostream &os, AsmState &state) {
3316 OperationPrinter(os, state.getImpl()).print(this);
3317 }
3318
dump()3319 void Block::dump() { print(llvm::errs()); }
3320
3321 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)3322 void Block::printAsOperand(raw_ostream &os, bool printType) {
3323 Operation *parentOp = getParentOp();
3324 if (!parentOp) {
3325 os << "<<UNLINKED BLOCK>>\n";
3326 return;
3327 }
3328 AsmState state(parentOp);
3329 printAsOperand(os, state);
3330 }
printAsOperand(raw_ostream & os,AsmState & state)3331 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3332 OperationPrinter printer(os, state.getImpl());
3333 printer.printBlockName(this);
3334 }
3335