1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/IR/SubElementInterfaces.h" 28 #include "mlir/IR/Verifier.h" 29 #include "llvm/ADT/APFloat.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/MapVector.h" 32 #include "llvm/ADT/STLExtras.h" 33 #include "llvm/ADT/ScopeExit.h" 34 #include "llvm/ADT/ScopedHashTable.h" 35 #include "llvm/ADT/SetVector.h" 36 #include "llvm/ADT/SmallString.h" 37 #include "llvm/ADT/StringExtras.h" 38 #include "llvm/ADT/StringSet.h" 39 #include "llvm/ADT/TypeSwitch.h" 40 #include "llvm/Support/CommandLine.h" 41 #include "llvm/Support/Endian.h" 42 #include "llvm/Support/Regex.h" 43 #include "llvm/Support/SaveAndRestore.h" 44 #include "llvm/Support/Threading.h" 45 46 #include <tuple> 47 48 using namespace mlir; 49 using namespace mlir::detail; 50 51 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 52 53 void OperationName::dump() const { print(llvm::errs()); } 54 55 
//===--------------------------------------------------------------------===// 56 // AsmParser 57 //===--------------------------------------------------------------------===// 58 59 AsmParser::~AsmParser() = default; 60 DialectAsmParser::~DialectAsmParser() = default; 61 OpAsmParser::~OpAsmParser() = default; 62 63 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } 64 65 //===----------------------------------------------------------------------===// 66 // DialectAsmPrinter 67 //===----------------------------------------------------------------------===// 68 69 DialectAsmPrinter::~DialectAsmPrinter() = default; 70 71 //===----------------------------------------------------------------------===// 72 // OpAsmPrinter 73 //===----------------------------------------------------------------------===// 74 75 OpAsmPrinter::~OpAsmPrinter() = default; 76 77 void OpAsmPrinter::printFunctionalType(Operation *op) { 78 auto &os = getStream(); 79 os << '('; 80 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 81 // Print the types of null values as <<NULL TYPE>>. 82 *this << (operand ? operand.getType() : Type()); 83 }); 84 os << ") -> "; 85 86 // Print the result list. We don't parenthesize single result types unless 87 // it is a function (avoiding a grammar ambiguity). 88 bool wrapped = op->getNumResults() != 1; 89 if (!wrapped && op->getResult(0).getType() && 90 op->getResult(0).getType().isa<FunctionType>()) 91 wrapped = true; 92 93 if (wrapped) 94 os << '('; 95 96 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 97 // Print the types of null values as <<NULL TYPE>>. 98 *this << (result ? result.getType() : Type()); 99 }); 100 101 if (wrapped) 102 os << ')'; 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Operation OpAsm interface. 
107 //===----------------------------------------------------------------------===// 108 109 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 110 #include "mlir/IR/OpAsmInterface.cpp.inc" 111 112 //===----------------------------------------------------------------------===// 113 // OpPrintingFlags 114 //===----------------------------------------------------------------------===// 115 116 namespace { 117 /// This struct contains command line options that can be used to initialize 118 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 119 /// for global command line options. 120 struct AsmPrinterOptions { 121 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 122 "mlir-print-elementsattrs-with-hex-if-larger", 123 llvm::cl::desc( 124 "Print DenseElementsAttrs with a hex string that have " 125 "more elements than the given upper limit (use -1 to disable)")}; 126 127 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 128 "mlir-elide-elementsattrs-if-larger", 129 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 130 "more elements than the given upper limit")}; 131 132 llvm::cl::opt<bool> printDebugInfoOpt{ 133 "mlir-print-debuginfo", llvm::cl::init(false), 134 llvm::cl::desc("Print debug info in MLIR output")}; 135 136 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 137 "mlir-pretty-debuginfo", llvm::cl::init(false), 138 llvm::cl::desc("Print pretty debug info in MLIR output")}; 139 140 // Use the generic op output form in the operation printer even if the custom 141 // form is defined. 
142 llvm::cl::opt<bool> printGenericOpFormOpt{ 143 "mlir-print-op-generic", llvm::cl::init(false), 144 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 145 146 llvm::cl::opt<bool> assumeVerifiedOpt{ 147 "mlir-print-assume-verified", llvm::cl::init(false), 148 llvm::cl::desc("Skip op verification when using custom printers"), 149 llvm::cl::Hidden}; 150 151 llvm::cl::opt<bool> printLocalScopeOpt{ 152 "mlir-print-local-scope", llvm::cl::init(false), 153 llvm::cl::desc("Print with local scope and inline information (eliding " 154 "aliases for attributes, types, and locations")}; 155 }; 156 } // namespace 157 158 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 159 160 /// Register a set of useful command-line options that can be used to configure 161 /// various flags within the AsmPrinter. 162 void mlir::registerAsmPrinterCLOptions() { 163 // Make sure that the options struct has been initialized. 164 *clOptions; 165 } 166 167 /// Initialize the printing flags with default supplied by the cl::opts above. 168 OpPrintingFlags::OpPrintingFlags() 169 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 170 printGenericOpFormFlag(false), assumeVerifiedFlag(false), 171 printLocalScope(false) { 172 // Initialize based upon command line options, if they are available. 173 if (!clOptions.isConstructed()) 174 return; 175 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 176 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 177 printDebugInfoFlag = clOptions->printDebugInfoOpt; 178 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 179 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 180 assumeVerifiedFlag = clOptions->assumeVerifiedOpt; 181 printLocalScope = clOptions->printLocalScopeOpt; 182 } 183 184 /// Enable the elision of large elements attributes, by printing a '...' 185 /// instead of the element data, when the number of elements is greater than 186 /// `largeElementLimit`. 
Note: The IR generated with this option is not 187 /// parsable. 188 OpPrintingFlags & 189 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 190 elementsAttrElementLimit = largeElementLimit; 191 return *this; 192 } 193 194 /// Enable printing of debug information. If 'prettyForm' is set to true, 195 /// debug information is printed in a more readable 'pretty' form. 196 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 197 printDebugInfoFlag = true; 198 printDebugInfoPrettyFormFlag = prettyForm; 199 return *this; 200 } 201 202 /// Always print operations in the generic form. 203 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 204 printGenericOpFormFlag = true; 205 return *this; 206 } 207 208 /// Do not verify the operation when using custom operation printers. 209 OpPrintingFlags &OpPrintingFlags::assumeVerified() { 210 assumeVerifiedFlag = true; 211 return *this; 212 } 213 214 /// Use local scope when printing the operation. This allows for using the 215 /// printer in a more localized and thread-safe setting, but may not necessarily 216 /// be identical of what the IR will look like when dumping the full module. 217 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 218 printLocalScope = true; 219 return *this; 220 } 221 222 /// Return if the given ElementsAttr should be elided. 223 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 224 return elementsAttrElementLimit.hasValue() && 225 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 226 !attr.isa<SplatElementsAttr>(); 227 } 228 229 /// Return the size limit for printing large ElementsAttr. 230 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 231 return elementsAttrElementLimit; 232 } 233 234 /// Return if debug information should be printed. 
235 bool OpPrintingFlags::shouldPrintDebugInfo() const { 236 return printDebugInfoFlag; 237 } 238 239 /// Return if debug information should be printed in the pretty form. 240 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 241 return printDebugInfoPrettyFormFlag; 242 } 243 244 /// Return if operations should be printed in the generic form. 245 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 246 return printGenericOpFormFlag; 247 } 248 249 /// Return if operation verification should be skipped. 250 bool OpPrintingFlags::shouldAssumeVerified() const { 251 return assumeVerifiedFlag; 252 } 253 254 /// Return if the printer should use local scope when dumping the IR. 255 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 256 257 /// Returns true if an ElementsAttr with the given number of elements should be 258 /// printed with hex. 259 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 260 // Check to see if a command line option was provided for the limit. 261 if (clOptions.isConstructed()) { 262 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 263 // -1 is used to disable hex printing. 264 if (clOptions->printElementsAttrWithHexIfLarger == -1) 265 return false; 266 return numElements > clOptions->printElementsAttrWithHexIfLarger; 267 } 268 } 269 270 // Otherwise, default to printing with hex if the number of elements is >100. 271 return numElements > 100; 272 } 273 274 //===----------------------------------------------------------------------===// 275 // NewLineCounter 276 //===----------------------------------------------------------------------===// 277 278 namespace { 279 /// This class is a simple formatter that emits a new line when inputted into a 280 /// stream, that enables counting the number of newlines emitted. This class 281 /// should be used whenever emitting newlines in the printer. 
282 struct NewLineCounter { 283 unsigned curLine = 1; 284 }; 285 286 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 287 ++newLine.curLine; 288 return os << '\n'; 289 } 290 } // namespace 291 292 //===----------------------------------------------------------------------===// 293 // AliasInitializer 294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 /// This class represents a specific instance of a symbol Alias. 298 class SymbolAlias { 299 public: 300 SymbolAlias(StringRef name, bool isDeferrable) 301 : name(name), suffixIndex(0), hasSuffixIndex(false), 302 isDeferrable(isDeferrable) {} 303 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) 304 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true), 305 isDeferrable(isDeferrable) {} 306 307 /// Print this alias to the given stream. 308 void print(raw_ostream &os) const { 309 os << name; 310 if (hasSuffixIndex) 311 os << suffixIndex; 312 } 313 314 /// Returns true if this alias supports deferred resolution when parsing. 315 bool canBeDeferred() const { return isDeferrable; } 316 317 private: 318 /// The main name of the alias. 319 StringRef name; 320 /// The optional suffix index of the alias, if multiple aliases had the same 321 /// name. 322 uint32_t suffixIndex : 30; 323 /// A flag indicating whether this alias has a suffix or not. 324 bool hasSuffixIndex : 1; 325 /// A flag indicating whether this alias may be deferred or not. 326 bool isDeferrable : 1; 327 }; 328 329 /// This class represents a utility that initializes the set of attribute and 330 /// type aliases, without the need to store the extra information within the 331 /// main AliasState class or pass it around via function arguments. 
class AliasInitializer {
public:
  /// Bind the initializer to the dialect asm interfaces it queries for aliases
  /// and to the allocator used for alias name storage. The alias stream is
  /// set up to write into `aliasBuffer`.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Collect aliases for the attributes and types reachable when printing
  /// `op` with `printerFlags`, and fill `attrToAlias` / `typeToAlias`.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
381 SmallString<32> aliasBuffer; 382 llvm::raw_svector_ostream aliasOS; 383 }; 384 385 /// This class implements a dummy OpAsmPrinter that doesn't print any output, 386 /// and merely collects the attributes and types that *would* be printed in a 387 /// normal print invocation so that we can generate proper aliases. This allows 388 /// for us to generate aliases only for the attributes and types that would be 389 /// in the output, and trims down unnecessary output. 390 class DummyAliasOperationPrinter : private OpAsmPrinter { 391 public: 392 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, 393 AliasInitializer &initializer) 394 : printerFlags(printerFlags), initializer(initializer) {} 395 396 /// Print the given operation. 397 void print(Operation *op) { 398 // Visit the operation location. 399 if (printerFlags.shouldPrintDebugInfo()) 400 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 401 402 // If requested, always print the generic form. 403 if (!printerFlags.shouldPrintGenericOpForm()) { 404 // Check to see if this is a known operation. If so, use the registered 405 // custom printer hook. 406 if (auto opInfo = op->getRegisteredInfo()) { 407 opInfo->printAssembly(op, *this, /*defaultDialect=*/""); 408 return; 409 } 410 } 411 412 // Otherwise print with the generic assembly form. 413 printGenericOp(op); 414 } 415 416 private: 417 /// Print the given operation in the generic form. 418 void printGenericOp(Operation *op, bool printOpName = true) override { 419 // Consider nested operations for aliases. 420 if (op->getNumRegions() != 0) { 421 for (Region ®ion : op->getRegions()) 422 printRegion(region, /*printEntryBlockArgs=*/true, 423 /*printBlockTerminators=*/true); 424 } 425 426 // Visit all the types used in the operation. 427 for (Type type : op->getOperandTypes()) 428 printType(type); 429 for (Type type : op->getResultTypes()) 430 printType(type); 431 432 // Consider the attributes of the operation for aliases. 
433 for (const NamedAttribute &attr : op->getAttrs()) 434 printAttribute(attr.getValue()); 435 } 436 437 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 438 /// block are not printed. If 'printBlockTerminator' is false, the terminator 439 /// operation of the block is not printed. 440 void print(Block *block, bool printBlockArgs = true, 441 bool printBlockTerminator = true) { 442 // Consider the types of the block arguments for aliases if 'printBlockArgs' 443 // is set to true. 444 if (printBlockArgs) { 445 for (BlockArgument arg : block->getArguments()) { 446 printType(arg.getType()); 447 448 // Visit the argument location. 449 if (printerFlags.shouldPrintDebugInfo()) 450 // TODO: Allow deferring argument locations. 451 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 452 } 453 } 454 455 // Consider the operations within this block, ignoring the terminator if 456 // requested. 457 bool hasTerminator = 458 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 459 auto range = llvm::make_range( 460 block->begin(), 461 std::prev(block->end(), 462 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 463 for (Operation &op : range) 464 print(&op); 465 } 466 467 /// Print the given region. 468 void printRegion(Region ®ion, bool printEntryBlockArgs, 469 bool printBlockTerminators, 470 bool printEmptyBlock = false) override { 471 if (region.empty()) 472 return; 473 474 auto *entryBlock = ®ion.front(); 475 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 476 for (Block &b : llvm::drop_begin(region, 1)) 477 print(&b); 478 } 479 480 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, 481 bool omitType) override { 482 printType(arg.getType()); 483 // Visit the argument location. 484 if (printerFlags.shouldPrintDebugInfo()) 485 // TODO: Allow deferring argument locations. 
486 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 487 } 488 489 /// Consider the given type to be printed for an alias. 490 void printType(Type type) override { initializer.visit(type); } 491 492 /// Consider the given attribute to be printed for an alias. 493 void printAttribute(Attribute attr) override { initializer.visit(attr); } 494 void printAttributeWithoutType(Attribute attr) override { 495 printAttribute(attr); 496 } 497 LogicalResult printAlias(Attribute attr) override { 498 initializer.visit(attr); 499 return success(); 500 } 501 LogicalResult printAlias(Type type) override { 502 initializer.visit(type); 503 return success(); 504 } 505 506 /// Print the given set of attributes with names not included within 507 /// 'elidedAttrs'. 508 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 509 ArrayRef<StringRef> elidedAttrs = {}) override { 510 if (attrs.empty()) 511 return; 512 if (elidedAttrs.empty()) { 513 for (const NamedAttribute &attr : attrs) 514 printAttribute(attr.getValue()); 515 return; 516 } 517 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 518 elidedAttrs.end()); 519 for (const NamedAttribute &attr : attrs) 520 if (!elidedAttrsSet.contains(attr.getName().strref())) 521 printAttribute(attr.getValue()); 522 } 523 void printOptionalAttrDictWithKeyword( 524 ArrayRef<NamedAttribute> attrs, 525 ArrayRef<StringRef> elidedAttrs = {}) override { 526 printOptionalAttrDict(attrs, elidedAttrs); 527 } 528 529 /// Return a null stream as the output stream, this will ignore any data fed 530 /// to it. 531 raw_ostream &getStream() const override { return os; } 532 533 /// The following are hooks of `OpAsmPrinter` that are not necessary for 534 /// determining potential aliases. 
535 void printFloat(const APFloat &value) override {} 536 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 537 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} 538 void printNewline() override {} 539 void printOperand(Value) override {} 540 void printOperand(Value, raw_ostream &os) override { 541 // Users expect the output string to have at least the prefixed % to signal 542 // a value name. To maintain this invariant, emit a name even if it is 543 // guaranteed to go unused. 544 os << "%"; 545 } 546 void printKeywordOrString(StringRef) override {} 547 void printSymbolName(StringRef) override {} 548 void printSuccessor(Block *) override {} 549 void printSuccessorAndUseList(Block *, ValueRange) override {} 550 void shadowRegionArgs(Region &, ValueRange) override {} 551 552 /// The printer flags to use when determining potential aliases. 553 const OpPrintingFlags &printerFlags; 554 555 /// The initializer to use when identifying aliases. 556 AliasInitializer &initializer; 557 558 /// A dummy output stream. 559 mutable llvm::raw_null_ostream os; 560 }; 561 } // namespace 562 563 /// Sanitize the given name such that it can be used as a valid identifier. If 564 /// the string needs to be modified in any way, the provided buffer is used to 565 /// store the new copy, 566 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 567 StringRef allowedPunctChars = "$._-", 568 bool allowTrailingDigit = true) { 569 assert(!name.empty() && "Shouldn't have an empty name here"); 570 571 auto copyNameToBuffer = [&] { 572 for (char ch : name) { 573 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) 574 buffer.push_back(ch); 575 else if (ch == ' ') 576 buffer.push_back('_'); 577 else 578 buffer.append(llvm::utohexstr((unsigned char)ch)); 579 } 580 }; 581 582 // Check to see if this name is valid. 
If it starts with a digit, then it 583 // could conflict with the autogenerated numeric ID's, so add an underscore 584 // prefix to avoid problems. 585 if (isdigit(name[0])) { 586 buffer.push_back('_'); 587 copyNameToBuffer(); 588 return buffer; 589 } 590 591 // If the name ends with a trailing digit, add a '_' to avoid potential 592 // conflicts with autogenerated ID's. 593 if (!allowTrailingDigit && isdigit(name.back())) { 594 copyNameToBuffer(); 595 buffer.push_back('_'); 596 return buffer; 597 } 598 599 // Check to see that the name consists of only valid identifier characters. 600 for (char ch : name) { 601 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { 602 copyNameToBuffer(); 603 return buffer; 604 } 605 } 606 607 // If there are no invalid characters, return the original name. 608 return name; 609 } 610 611 /// Given a collection of aliases and symbols, initialize a mapping from a 612 /// symbol to a given alias. 613 template <typename T> 614 static void 615 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol, 616 llvm::MapVector<T, SymbolAlias> &symbolToAlias, 617 DenseSet<T> *deferrableAliases = nullptr) { 618 std::vector<std::pair<StringRef, std::vector<T>>> aliases = 619 aliasToSymbol.takeVector(); 620 llvm::array_pod_sort(aliases.begin(), aliases.end(), 621 [](const auto *lhs, const auto *rhs) { 622 return lhs->first.compare(rhs->first); 623 }); 624 625 for (auto &it : aliases) { 626 // If there is only one instance for this alias, use the name directly. 627 if (it.second.size() == 1) { 628 T symbol = it.second.front(); 629 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 630 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)}); 631 continue; 632 } 633 // Otherwise, add the index to the name. 
634 for (int i = 0, e = it.second.size(); i < e; ++i) { 635 T symbol = it.second[i]; 636 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 637 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)}); 638 } 639 } 640 } 641 642 void AliasInitializer::initialize( 643 Operation *op, const OpPrintingFlags &printerFlags, 644 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 645 llvm::MapVector<Type, SymbolAlias> &typeToAlias) { 646 // Use a dummy printer when walking the IR so that we can collect the 647 // attributes/types that will actually be used during printing when 648 // considering aliases. 649 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); 650 aliasPrinter.print(op); 651 652 // Initialize the aliases sorted by name. 653 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); 654 initializeAliases(aliasToType, typeToAlias); 655 } 656 657 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { 658 if (!visitedAttributes.insert(attr).second) { 659 // If this attribute already has an alias and this instance can't be 660 // deferred, make sure that the alias isn't deferred. 661 if (!canBeDeferred) 662 deferrableAttributes.erase(attr); 663 return; 664 } 665 666 // Try to generate an alias for this attribute. 667 if (succeeded(generateAlias(attr, aliasToAttr))) { 668 if (canBeDeferred) 669 deferrableAttributes.insert(attr); 670 return; 671 } 672 673 // Check for any sub elements. 674 if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) { 675 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 676 [&](Type type) { visit(type); }); 677 } 678 } 679 680 void AliasInitializer::visit(Type type) { 681 if (!visitedTypes.insert(type).second) 682 return; 683 684 // Try to generate an alias for this type. 685 if (succeeded(generateAlias(type, aliasToType))) 686 return; 687 688 // Check for any sub elements. 
689 if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) { 690 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 691 [&](Type type) { visit(type); }); 692 } 693 } 694 695 template <typename T> 696 LogicalResult AliasInitializer::generateAlias( 697 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) { 698 SmallString<32> nameBuffer; 699 for (const auto &interface : interfaces) { 700 OpAsmDialectInterface::AliasResult result = 701 interface.getAlias(symbol, aliasOS); 702 if (result == OpAsmDialectInterface::AliasResult::NoAlias) 703 continue; 704 nameBuffer = std::move(aliasBuffer); 705 assert(!nameBuffer.empty() && "expected valid alias name"); 706 if (result == OpAsmDialectInterface::AliasResult::FinalAlias) 707 break; 708 } 709 710 if (nameBuffer.empty()) 711 return failure(); 712 713 SmallString<16> tempBuffer; 714 StringRef name = 715 sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", 716 /*allowTrailingDigit=*/false); 717 name = name.copy(aliasAllocator); 718 aliasToSymbol[name].push_back(symbol); 719 return success(); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // AliasState 724 //===----------------------------------------------------------------------===// 725 726 namespace { 727 /// This class manages the state for type and attribute aliases. 728 class AliasState { 729 public: 730 // Initialize the internal aliases. 731 void 732 initialize(Operation *op, const OpPrintingFlags &printerFlags, 733 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 734 735 /// Get an alias for the given attribute if it has one and print it in `os`. 736 /// Returns success if an alias was printed, failure otherwise. 737 LogicalResult getAlias(Attribute attr, raw_ostream &os) const; 738 739 /// Get an alias for the given type if it has one and print it in `os`. 740 /// Returns success if an alias was printed, failure otherwise. 
741 LogicalResult getAlias(Type ty, raw_ostream &os) const; 742 743 /// Print all of the referenced aliases that can not be resolved in a deferred 744 /// manner. 745 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 746 printAliases(os, newLine, /*isDeferred=*/false); 747 } 748 749 /// Print all of the referenced aliases that support deferred resolution. 750 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 751 printAliases(os, newLine, /*isDeferred=*/true); 752 } 753 754 private: 755 /// Print all of the referenced aliases that support the provided resolution 756 /// behavior. 757 void printAliases(raw_ostream &os, NewLineCounter &newLine, 758 bool isDeferred) const; 759 760 /// Mapping between attribute and alias. 761 llvm::MapVector<Attribute, SymbolAlias> attrToAlias; 762 /// Mapping between type and alias. 763 llvm::MapVector<Type, SymbolAlias> typeToAlias; 764 765 /// An allocator used for alias names. 766 llvm::BumpPtrAllocator aliasAllocator; 767 }; 768 } // namespace 769 770 void AliasState::initialize( 771 Operation *op, const OpPrintingFlags &printerFlags, 772 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 773 AliasInitializer initializer(interfaces, aliasAllocator); 774 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); 775 } 776 777 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { 778 auto it = attrToAlias.find(attr); 779 if (it == attrToAlias.end()) 780 return failure(); 781 it->second.print(os << '#'); 782 return success(); 783 } 784 785 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { 786 auto it = typeToAlias.find(ty); 787 if (it == typeToAlias.end()) 788 return failure(); 789 790 it->second.print(os << '!'); 791 return success(); 792 } 793 794 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine, 795 bool isDeferred) const { 796 auto filterFn = [=](const auto &aliasIt) { 797 return 
aliasIt.second.canBeDeferred() == isDeferred; 798 }; 799 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) { 800 it.second.print(os << '#'); 801 os << " = " << it.first << newLine; 802 } 803 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) { 804 it.second.print(os << '!'); 805 os << " = type " << it.first << newLine; 806 } 807 } 808 809 //===----------------------------------------------------------------------===// 810 // SSANameState 811 //===----------------------------------------------------------------------===// 812 813 namespace { 814 /// Info about block printing: a number which is its position in the visitation 815 /// order, and a name that is used to print reference to it, e.g. ^bb42. 816 struct BlockInfo { 817 int ordering; 818 StringRef name; 819 }; 820 821 /// This class manages the state of SSA value names. 822 class SSANameState { 823 public: 824 /// A sentinel value used for values with names set. 825 enum : unsigned { NameSentinel = ~0U }; 826 827 SSANameState(Operation *op, const OpPrintingFlags &printerFlags); 828 829 /// Print the SSA identifier for the given value to 'stream'. If 830 /// 'printResultNo' is true, it also presents the result number ('#' number) 831 /// of this value. 832 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; 833 834 /// Return the result indices for each of the result groups registered by this 835 /// operation, or empty if none exist. 836 ArrayRef<int> getOpResultGroups(Operation *op); 837 838 /// Get the info for the given block. 839 BlockInfo getBlockInfo(Block *block); 840 841 /// Renumber the arguments for the specified region to the same names as the 842 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for 843 /// details. 844 void shadowRegionArgs(Region ®ion, ValueRange namesToUse); 845 846 private: 847 /// Number the SSA values within the given IR unit. 
848 void numberValuesInRegion(Region ®ion); 849 void numberValuesInBlock(Block &block); 850 void numberValuesInOp(Operation &op); 851 852 /// Given a result of an operation 'result', find the result group head 853 /// 'lookupValue' and the result of 'result' within that group in 854 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group 855 /// has more than 1 result. 856 void getResultIDAndNumber(OpResult result, Value &lookupValue, 857 Optional<int> &lookupResultNo) const; 858 859 /// Set a special value name for the given value. 860 void setValueName(Value value, StringRef name); 861 862 /// Uniques the given value name within the printer. If the given name 863 /// conflicts, it is automatically renamed. 864 StringRef uniqueValueName(StringRef name); 865 866 /// This is the value ID for each SSA value. If this returns NameSentinel, 867 /// then the valueID has an entry in valueNames. 868 DenseMap<Value, unsigned> valueIDs; 869 DenseMap<Value, StringRef> valueNames; 870 871 /// This is a map of operations that contain multiple named result groups, 872 /// i.e. there may be multiple names for the results of the operation. The 873 /// value of this map are the result numbers that start a result group. 874 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; 875 876 /// This maps blocks to there visitation number in the current region as well 877 /// as the string representing their name. 878 DenseMap<Block *, BlockInfo> blockNames; 879 880 /// This keeps track of all of the non-numeric names that are in flight, 881 /// allowing us to check for duplicates. 882 /// Note: the value of the map is unused. 883 llvm::ScopedHashTable<StringRef, char> usedNames; 884 llvm::BumpPtrAllocator usedNameAllocator; 885 886 /// This is the next value ID to assign in numbering. 887 unsigned nextValueID = 0; 888 /// This is the next ID to assign to a region entry block argument. 
889 unsigned nextArgumentID = 0; 890 /// This is the next ID to assign when a name conflict is detected. 891 unsigned nextConflictID = 0; 892 893 /// These are the printing flags. They control, eg., whether to print in 894 /// generic form. 895 OpPrintingFlags printerFlags; 896 }; 897 } // namespace 898 899 SSANameState::SSANameState( 900 Operation *op, const OpPrintingFlags &printerFlags) 901 : printerFlags(printerFlags) { 902 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); 903 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); 904 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); 905 906 // The naming context includes `nextValueID`, `nextArgumentID`, 907 // `nextConflictID` and `usedNames` scoped HashTable. This information is 908 // carried from the parent region. 909 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; 910 using NamingContext = 911 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; 912 913 // Allocator for UsedNamesScopeTy 914 llvm::BumpPtrAllocator allocator; 915 916 // Add a scope for the top level operation. 917 auto *topLevelNamesScope = 918 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); 919 920 SmallVector<NamingContext, 8> nameContext; 921 for (Region ®ion : op->getRegions()) 922 nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, 923 nextConflictID, topLevelNamesScope)); 924 925 numberValuesInOp(*op); 926 927 while (!nameContext.empty()) { 928 Region *region; 929 UsedNamesScopeTy *parentScope; 930 std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) = 931 nameContext.pop_back_val(); 932 933 // When we switch from one subtree to another, pop the scopes(needless) 934 // until the parent scope. 
935 while (usedNames.getCurScope() != parentScope) { 936 usedNames.getCurScope()->~UsedNamesScopeTy(); 937 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && 938 "top level parentScope must be a nullptr"); 939 } 940 941 // Add a scope for the current region. 942 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) 943 UsedNamesScopeTy(usedNames); 944 945 numberValuesInRegion(*region); 946 947 for (Operation &op : region->getOps()) 948 for (Region ®ion : op.getRegions()) 949 nameContext.push_back(std::make_tuple(®ion, nextValueID, 950 nextArgumentID, nextConflictID, 951 curNamesScope)); 952 } 953 954 // Manually remove all the scopes. 955 while (usedNames.getCurScope() != nullptr) 956 usedNames.getCurScope()->~UsedNamesScopeTy(); 957 } 958 959 void SSANameState::printValueID(Value value, bool printResultNo, 960 raw_ostream &stream) const { 961 if (!value) { 962 stream << "<<NULL VALUE>>"; 963 return; 964 } 965 966 Optional<int> resultNo; 967 auto lookupValue = value; 968 969 // If this is an operation result, collect the head lookup value of the result 970 // group and the result number of 'result' within that group. 971 if (OpResult result = value.dyn_cast<OpResult>()) 972 getResultIDAndNumber(result, lookupValue, resultNo); 973 974 auto it = valueIDs.find(lookupValue); 975 if (it == valueIDs.end()) { 976 stream << "<<UNKNOWN SSA VALUE>>"; 977 return; 978 } 979 980 stream << '%'; 981 if (it->second != NameSentinel) { 982 stream << it->second; 983 } else { 984 auto nameIt = valueNames.find(lookupValue); 985 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 986 stream << nameIt->second; 987 } 988 989 if (resultNo.hasValue() && printResultNo) 990 stream << '#' << resultNo; 991 } 992 993 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 994 auto it = opResultGroups.find(op); 995 return it == opResultGroups.end() ? 
ArrayRef<int>() : it->second; 996 } 997 998 BlockInfo SSANameState::getBlockInfo(Block *block) { 999 auto it = blockNames.find(block); 1000 BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; 1001 return it != blockNames.end() ? it->second : invalidBlock; 1002 } 1003 1004 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 1005 assert(!region.empty() && "cannot shadow arguments of an empty region"); 1006 assert(region.getNumArguments() == namesToUse.size() && 1007 "incorrect number of names passed in"); 1008 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1009 "only KnownIsolatedFromAbove ops can shadow names"); 1010 1011 SmallVector<char, 16> nameStr; 1012 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 1013 auto nameToUse = namesToUse[i]; 1014 if (nameToUse == nullptr) 1015 continue; 1016 auto nameToReplace = region.getArgument(i); 1017 1018 nameStr.clear(); 1019 llvm::raw_svector_ostream nameStream(nameStr); 1020 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 1021 1022 // Entry block arguments should already have a pretty "arg" name. 1023 assert(valueIDs[nameToReplace] == NameSentinel); 1024 1025 // Use the name without the leading %. 1026 auto name = StringRef(nameStream.str()).drop_front(); 1027 1028 // Overwrite the name. 
1029 valueNames[nameToReplace] = name.copy(usedNameAllocator); 1030 } 1031 } 1032 1033 void SSANameState::numberValuesInRegion(Region ®ion) { 1034 auto setBlockArgNameFn = [&](Value arg, StringRef name) { 1035 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 1036 assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion && 1037 "arg not defined in current region"); 1038 setValueName(arg, name); 1039 }; 1040 1041 if (!printerFlags.shouldPrintGenericOpForm()) { 1042 if (Operation *op = region.getParentOp()) { 1043 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) 1044 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); 1045 } 1046 } 1047 1048 // Number the values within this region in a breadth-first order. 1049 unsigned nextBlockID = 0; 1050 for (auto &block : region) { 1051 // Each block gets a unique ID, and all of the operations within it get 1052 // numbered as well. 1053 auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); 1054 if (blockInfoIt.second) { 1055 // This block hasn't been named through `getAsmBlockArgumentNames`, use 1056 // default `^bbNNN` format. 1057 std::string name; 1058 llvm::raw_string_ostream(name) << "^bb" << nextBlockID; 1059 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); 1060 } 1061 blockInfoIt.first->second.ordering = nextBlockID++; 1062 1063 numberValuesInBlock(block); 1064 } 1065 } 1066 1067 void SSANameState::numberValuesInBlock(Block &block) { 1068 // Number the block arguments. We give entry block arguments a special name 1069 // 'arg'. 1070 bool isEntryBlock = block.isEntryBlock(); 1071 SmallString<32> specialNameBuffer(isEntryBlock ? 
"arg" : ""); 1072 llvm::raw_svector_ostream specialName(specialNameBuffer); 1073 for (auto arg : block.getArguments()) { 1074 if (valueIDs.count(arg)) 1075 continue; 1076 if (isEntryBlock) { 1077 specialNameBuffer.resize(strlen("arg")); 1078 specialName << nextArgumentID++; 1079 } 1080 setValueName(arg, specialName.str()); 1081 } 1082 1083 // Number the operations in this block. 1084 for (auto &op : block) 1085 numberValuesInOp(op); 1086 } 1087 1088 void SSANameState::numberValuesInOp(Operation &op) { 1089 // Function used to set the special result names for the operation. 1090 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1091 auto setResultNameFn = [&](Value result, StringRef name) { 1092 assert(!valueIDs.count(result) && "result numbered multiple times"); 1093 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1094 setValueName(result, name); 1095 1096 // Record the result number for groups not anchored at 0. 1097 if (int resultNo = result.cast<OpResult>().getResultNumber()) 1098 resultGroups.push_back(resultNo); 1099 }; 1100 // Operations can customize the printing of block names in OpAsmOpInterface. 
1101 auto setBlockNameFn = [&](Block *block, StringRef name) { 1102 assert(block->getParentOp() == &op && 1103 "getAsmBlockArgumentNames callback invoked on a block not directly " 1104 "nested under the current operation"); 1105 assert(!blockNames.count(block) && "block numbered multiple times"); 1106 SmallString<16> tmpBuffer{"^"}; 1107 name = sanitizeIdentifier(name, tmpBuffer); 1108 if (name.data() != tmpBuffer.data()) { 1109 tmpBuffer.append(name); 1110 name = tmpBuffer.str(); 1111 } 1112 name = name.copy(usedNameAllocator); 1113 blockNames[block] = {-1, name}; 1114 }; 1115 1116 if (!printerFlags.shouldPrintGenericOpForm()) { 1117 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { 1118 asmInterface.getAsmBlockNames(setBlockNameFn); 1119 asmInterface.getAsmResultNames(setResultNameFn); 1120 } 1121 } 1122 1123 unsigned numResults = op.getNumResults(); 1124 if (numResults == 0) 1125 return; 1126 Value resultBegin = op.getResult(0); 1127 1128 // If the first result wasn't numbered, give it a default number. 1129 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1130 ++nextValueID; 1131 1132 // If this operation has multiple result groups, mark it. 1133 if (resultGroups.size() != 1) { 1134 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1135 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1136 } 1137 } 1138 1139 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1140 Optional<int> &lookupResultNo) const { 1141 Operation *owner = result.getOwner(); 1142 if (owner->getNumResults() == 1) 1143 return; 1144 int resultNo = result.getResultNumber(); 1145 1146 // If this operation has multiple result groups, we will need to find the 1147 // one corresponding to this result. 1148 auto resultGroupIt = opResultGroups.find(owner); 1149 if (resultGroupIt == opResultGroups.end()) { 1150 // If not, just use the first result. 
1151 lookupResultNo = resultNo; 1152 lookupValue = owner->getResult(0); 1153 return; 1154 } 1155 1156 // Find the correct index using a binary search, as the groups are ordered. 1157 ArrayRef<int> resultGroups = resultGroupIt->second; 1158 const auto *it = llvm::upper_bound(resultGroups, resultNo); 1159 int groupResultNo = 0, groupSize = 0; 1160 1161 // If there are no smaller elements, the last result group is the lookup. 1162 if (it == resultGroups.end()) { 1163 groupResultNo = resultGroups.back(); 1164 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1165 } else { 1166 // Otherwise, the previous element is the lookup. 1167 groupResultNo = *std::prev(it); 1168 groupSize = *it - groupResultNo; 1169 } 1170 1171 // We only record the result number for a group of size greater than 1. 1172 if (groupSize != 1) 1173 lookupResultNo = resultNo - groupResultNo; 1174 lookupValue = owner->getResult(groupResultNo); 1175 } 1176 1177 void SSANameState::setValueName(Value value, StringRef name) { 1178 // If the name is empty, the value uses the default numbering. 1179 if (name.empty()) { 1180 valueIDs[value] = nextValueID++; 1181 return; 1182 } 1183 1184 valueIDs[value] = NameSentinel; 1185 valueNames[value] = uniqueValueName(name); 1186 } 1187 1188 StringRef SSANameState::uniqueValueName(StringRef name) { 1189 SmallString<16> tmpBuffer; 1190 name = sanitizeIdentifier(name, tmpBuffer); 1191 1192 // Check to see if this name is already unique. 1193 if (!usedNames.count(name)) { 1194 name = name.copy(usedNameAllocator); 1195 } else { 1196 // Otherwise, we had a conflict - probe until we find a unique name. This 1197 // is guaranteed to terminate (and usually in a single iteration) because it 1198 // generates new names by incrementing nextConflictID. 
1199 SmallString<64> probeName(name); 1200 probeName.push_back('_'); 1201 while (true) { 1202 probeName += llvm::utostr(nextConflictID++); 1203 if (!usedNames.count(probeName)) { 1204 name = probeName.str().copy(usedNameAllocator); 1205 break; 1206 } 1207 probeName.resize(name.size() + 1); 1208 } 1209 } 1210 1211 usedNames.insert(name, char()); 1212 return name; 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // AsmState 1217 //===----------------------------------------------------------------------===// 1218 1219 namespace mlir { 1220 namespace detail { 1221 class AsmStateImpl { 1222 public: 1223 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, 1224 AsmState::LocationMap *locationMap) 1225 : interfaces(op->getContext()), nameState(op, printerFlags), 1226 printerFlags(printerFlags), locationMap(locationMap) {} 1227 1228 /// Initialize the alias state to enable the printing of aliases. 1229 void initializeAliases(Operation *op) { 1230 aliasState.initialize(op, printerFlags, interfaces); 1231 } 1232 1233 /// Get the state used for aliases. 1234 AliasState &getAliasState() { return aliasState; } 1235 1236 /// Get the state used for SSA names. 1237 SSANameState &getSSANameState() { return nameState; } 1238 1239 /// Get the printer flags. 1240 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } 1241 1242 /// Register the location, line and column, within the buffer that the given 1243 /// operation was printed at. 1244 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1245 if (locationMap) 1246 (*locationMap)[op] = std::make_pair(line, col); 1247 } 1248 1249 private: 1250 /// Collection of OpAsm interfaces implemented in the context. 1251 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1252 1253 /// The state used for attribute and type aliases. 1254 AliasState aliasState; 1255 1256 /// The state used for SSA value names. 
1257 SSANameState nameState; 1258 1259 /// Flags that control op output. 1260 OpPrintingFlags printerFlags; 1261 1262 /// An optional location map to be populated. 1263 AsmState::LocationMap *locationMap; 1264 }; 1265 } // namespace detail 1266 } // namespace mlir 1267 1268 /// Verifies the operation and switches to generic op printing if verification 1269 /// fails. We need to do this because custom print functions may fail for 1270 /// invalid ops. 1271 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, 1272 OpPrintingFlags printerFlags) { 1273 if (printerFlags.shouldPrintGenericOpForm() || 1274 printerFlags.shouldAssumeVerified()) 1275 return printerFlags; 1276 1277 // Ignore errors emitted by the verifier. We check the thread id to avoid 1278 // consuming other threads' errors. 1279 auto parentThreadId = llvm::get_threadid(); 1280 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) { 1281 return success(parentThreadId == llvm::get_threadid()); 1282 }); 1283 if (failed(verify(op))) 1284 printerFlags.printGenericOpForm(); 1285 1286 return printerFlags; 1287 } 1288 1289 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, 1290 LocationMap *locationMap) 1291 : impl(std::make_unique<AsmStateImpl>( 1292 op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} 1293 AsmState::~AsmState() = default; 1294 1295 const OpPrintingFlags &AsmState::getPrinterFlags() const { 1296 return impl->getPrinterFlags(); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // AsmPrinter::Impl 1301 //===----------------------------------------------------------------------===// 1302 1303 namespace mlir { 1304 class AsmPrinter::Impl { 1305 public: 1306 Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1307 AsmStateImpl *state = nullptr) 1308 : os(os), printerFlags(flags), state(state) {} 1309 explicit Impl(Impl &other) 1310 : Impl(other.os, other.printerFlags, other.state) {} 
1311 1312 /// Returns the output stream of the printer. 1313 raw_ostream &getStream() { return os; } 1314 1315 template <typename Container, typename UnaryFunctor> 1316 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { 1317 llvm::interleaveComma(c, os, eachFn); 1318 } 1319 1320 /// This enum describes the different kinds of elision for the type of an 1321 /// attribute when printing it. 1322 enum class AttrTypeElision { 1323 /// The type must not be elided, 1324 Never, 1325 /// The type may be elided when it matches the default used in the parser 1326 /// (for example i64 is the default for integer attributes). 1327 May, 1328 /// The type must be elided. 1329 Must 1330 }; 1331 1332 /// Print the given attribute. 1333 void printAttribute(Attribute attr, 1334 AttrTypeElision typeElision = AttrTypeElision::Never); 1335 1336 /// Print the alias for the given attribute, return failure if no alias could 1337 /// be printed. 1338 LogicalResult printAlias(Attribute attr); 1339 1340 void printType(Type type); 1341 1342 /// Print the alias for the given type, return failure if no alias could 1343 /// be printed. 1344 LogicalResult printAlias(Type type); 1345 1346 /// Print the given location to the stream. If `allowAlias` is true, this 1347 /// allows for the internal location to use an attribute alias. 
1348 void printLocation(LocationAttr loc, bool allowAlias = false); 1349 1350 void printAffineMap(AffineMap map); 1351 void 1352 printAffineExpr(AffineExpr expr, 1353 function_ref<void(unsigned, bool)> printValueName = nullptr); 1354 void printAffineConstraint(AffineExpr expr, bool isEq); 1355 void printIntegerSet(IntegerSet set); 1356 1357 protected: 1358 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1359 ArrayRef<StringRef> elidedAttrs = {}, 1360 bool withKeyword = false); 1361 void printNamedAttribute(NamedAttribute attr); 1362 void printTrailingLocation(Location loc, bool allowAlias = true); 1363 void printLocationInternal(LocationAttr loc, bool pretty = false); 1364 1365 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1366 /// used instead of individual elements when the elements attr is large. 1367 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1368 1369 /// Print a dense string elements attribute. 1370 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1371 1372 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1373 /// used instead of individual elements when the elements attr is large. 1374 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1375 bool allowHex); 1376 1377 void printDialectAttribute(Attribute attr); 1378 void printDialectType(Type type); 1379 1380 /// This enum is used to represent the binding strength of the enclosing 1381 /// context that an AffineExprStorage is being printed in, so we can 1382 /// intelligently produce parens. 1383 enum class BindingStrength { 1384 Weak, // + and - 1385 Strong, // All other binary operators. 1386 }; 1387 void printAffineExprInternal( 1388 AffineExpr expr, BindingStrength enclosingTightness, 1389 function_ref<void(unsigned, bool)> printValueName = nullptr); 1390 1391 /// The output stream for the printer. 1392 raw_ostream &os; 1393 1394 /// A set of flags to control the printer's behavior. 
1395 OpPrintingFlags printerFlags; 1396 1397 /// An optional printer state for the module. 1398 AsmStateImpl *state; 1399 1400 /// A tracker for the number of new lines emitted during printing. 1401 NewLineCounter newLine; 1402 }; 1403 } // namespace mlir 1404 1405 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { 1406 // Check to see if we are printing debug information. 1407 if (!printerFlags.shouldPrintDebugInfo()) 1408 return; 1409 1410 os << " "; 1411 printLocation(loc, /*allowAlias=*/allowAlias); 1412 } 1413 1414 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) { 1415 TypeSwitch<LocationAttr>(loc) 1416 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1417 printLocationInternal(loc.getFallbackLocation(), pretty); 1418 }) 1419 .Case<UnknownLoc>([&](UnknownLoc loc) { 1420 if (pretty) 1421 os << "[unknown]"; 1422 else 1423 os << "unknown"; 1424 }) 1425 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1426 if (pretty) { 1427 os << loc.getFilename().getValue(); 1428 } else { 1429 os << "\""; 1430 printEscapedString(loc.getFilename(), os); 1431 os << "\""; 1432 } 1433 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1434 }) 1435 .Case<NameLoc>([&](NameLoc loc) { 1436 os << '\"'; 1437 printEscapedString(loc.getName(), os); 1438 os << '\"'; 1439 1440 // Print the child if it isn't unknown. 
1441 auto childLoc = loc.getChildLoc(); 1442 if (!childLoc.isa<UnknownLoc>()) { 1443 os << '('; 1444 printLocationInternal(childLoc, pretty); 1445 os << ')'; 1446 } 1447 }) 1448 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1449 Location caller = loc.getCaller(); 1450 Location callee = loc.getCallee(); 1451 if (!pretty) 1452 os << "callsite("; 1453 printLocationInternal(callee, pretty); 1454 if (pretty) { 1455 if (callee.isa<NameLoc>()) { 1456 if (caller.isa<FileLineColLoc>()) { 1457 os << " at "; 1458 } else { 1459 os << newLine << " at "; 1460 } 1461 } else { 1462 os << newLine << " at "; 1463 } 1464 } else { 1465 os << " at "; 1466 } 1467 printLocationInternal(caller, pretty); 1468 if (!pretty) 1469 os << ")"; 1470 }) 1471 .Case<FusedLoc>([&](FusedLoc loc) { 1472 if (!pretty) 1473 os << "fused"; 1474 if (Attribute metadata = loc.getMetadata()) 1475 os << '<' << metadata << '>'; 1476 os << '['; 1477 interleave( 1478 loc.getLocations(), 1479 [&](Location loc) { printLocationInternal(loc, pretty); }, 1480 [&]() { os << ", "; }); 1481 os << ']'; 1482 }); 1483 } 1484 1485 /// Print a floating point value in a way that the parser will be able to 1486 /// round-trip losslessly. 1487 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1488 // We would like to output the FP constant value in exponential notation, 1489 // but we cannot do this if doing so will lose precision. Check here to 1490 // make sure that we only output it in exponential format if we can parse 1491 // the value back and get the same value. 1492 bool isInf = apValue.isInfinity(); 1493 bool isNaN = apValue.isNaN(); 1494 if (!isInf && !isNaN) { 1495 SmallString<128> strValue; 1496 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1497 /*TruncateZero=*/false); 1498 1499 // Check to make sure that the stringized number is not some string like 1500 // "Inf" or NaN, that atof will accept, but the lexer will not. 
Check 1501 // that the string matches the "[-+]?[0-9]" regex. 1502 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1503 ((strValue[0] == '-' || strValue[0] == '+') && 1504 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1505 "[-+]?[0-9] regex does not match!"); 1506 1507 // Parse back the stringized version and check that the value is equal 1508 // (i.e., there is no precision loss). 1509 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1510 os << strValue; 1511 return; 1512 } 1513 1514 // If it is not, use the default format of APFloat instead of the 1515 // exponential notation. 1516 strValue.clear(); 1517 apValue.toString(strValue); 1518 1519 // Make sure that we can parse the default form as a float. 1520 if (strValue.str().contains('.')) { 1521 os << strValue; 1522 return; 1523 } 1524 } 1525 1526 // Print special values in hexadecimal format. The sign bit should be included 1527 // in the literal. 1528 SmallVector<char, 16> str; 1529 APInt apInt = apValue.bitcastToAPInt(); 1530 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1531 /*formatAsCLiteral=*/true); 1532 os << str; 1533 } 1534 1535 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { 1536 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1537 return printLocationInternal(loc, /*pretty=*/true); 1538 1539 os << "loc("; 1540 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1541 printLocationInternal(loc); 1542 os << ')'; 1543 } 1544 1545 /// Returns true if the given dialect symbol data is simple enough to print in 1546 /// the pretty form, i.e. without the enclosing "". 1547 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1548 // The name must start with an identifier. 1549 if (symName.empty() || !isalpha(symName.front())) 1550 return false; 1551 1552 // Ignore all the characters that are valid in an identifier in the symbol 1553 // name. 
1554 symName = symName.drop_while( 1555 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1556 if (symName.empty()) 1557 return true; 1558 1559 // If we got to an unexpected character, then it must be a <>. Check those 1560 // recursively. 1561 if (symName.front() != '<' || symName.back() != '>') 1562 return false; 1563 1564 SmallVector<char, 8> nestedPunctuation; 1565 do { 1566 // If we ran out of characters, then we had a punctuation mismatch. 1567 if (symName.empty()) 1568 return false; 1569 1570 auto c = symName.front(); 1571 symName = symName.drop_front(); 1572 1573 switch (c) { 1574 // We never allow null characters. This is an EOF indicator for the lexer 1575 // which we could handle, but isn't important for any known dialect. 1576 case '\0': 1577 return false; 1578 case '<': 1579 case '[': 1580 case '(': 1581 case '{': 1582 nestedPunctuation.push_back(c); 1583 continue; 1584 case '-': 1585 // Treat `->` as a special token. 1586 if (!symName.empty() && symName.front() == '>') { 1587 symName = symName.drop_front(); 1588 continue; 1589 } 1590 break; 1591 // Reject types with mismatched brackets. 1592 case '>': 1593 if (nestedPunctuation.pop_back_val() != '<') 1594 return false; 1595 break; 1596 case ']': 1597 if (nestedPunctuation.pop_back_val() != '[') 1598 return false; 1599 break; 1600 case ')': 1601 if (nestedPunctuation.pop_back_val() != '(') 1602 return false; 1603 break; 1604 case '}': 1605 if (nestedPunctuation.pop_back_val() != '{') 1606 return false; 1607 break; 1608 default: 1609 continue; 1610 } 1611 1612 // We're done when the punctuation is fully matched. 1613 } while (!nestedPunctuation.empty()); 1614 1615 // If there were extra characters, then we failed. 1616 return symName.empty(); 1617 } 1618 1619 /// Print the given dialect symbol to the stream. 
1620 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 1621 StringRef dialectName, StringRef symString) { 1622 os << symPrefix << dialectName; 1623 1624 // If this symbol name is simple enough, print it directly in pretty form, 1625 // otherwise, we print it as an escaped string. 1626 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 1627 os << '.' << symString; 1628 return; 1629 } 1630 1631 os << "<\""; 1632 llvm::printEscapedString(symString, os); 1633 os << "\">"; 1634 } 1635 1636 /// Returns true if the given string can be represented as a bare identifier. 1637 static bool isBareIdentifier(StringRef name) { 1638 // By making this unsigned, the value passed in to isalnum will always be 1639 // in the range 0-255. This is important when building with MSVC because 1640 // its implementation will assert. This situation can arise when dealing 1641 // with UTF-8 multibyte characters. 1642 if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) 1643 return false; 1644 return llvm::all_of(name.drop_front(), [](unsigned char c) { 1645 return isalnum(c) || c == '_' || c == '$' || c == '.'; 1646 }); 1647 } 1648 1649 /// Print the given string as a keyword, or a quoted and escaped string if it 1650 /// has any special or non-printable characters in it. 1651 static void printKeywordOrString(StringRef keyword, raw_ostream &os) { 1652 // If it can be represented as a bare identifier, write it directly. 1653 if (isBareIdentifier(keyword)) { 1654 os << keyword; 1655 return; 1656 } 1657 1658 // Otherwise, output the keyword wrapped in quotes with proper escaping. 1659 os << "\""; 1660 printEscapedString(keyword, os); 1661 os << '"'; 1662 } 1663 1664 /// Print the given string as a symbol reference. A symbol reference is 1665 /// represented as a string prefixed with '@'. The reference is surrounded with 1666 /// ""'s and escaped if it has any special or non-printable characters in it. 
/// Print `symbolRef` as a symbol reference: an '@' sigil followed by the name
/// as emitted by printKeywordOrString. Asserts that the name is non-empty.
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  assert(!symbolRef.empty() && "expected valid symbol reference");
  os << '@';
  printKeywordOrString(symbolRef, os);
}

// Print out a valid ElementsAttr that is succinct and can represent any
// potential shape/type, for use when eliding a large ElementsAttr.
//
// We choose to use an opaque ElementsAttr literal with conspicuous content to
// hopefully alert readers to the fact that this has been elided.
//
// Unfortunately, neither of the strings of an opaque ElementsAttr literal will
// accept the string "elided". The first string must be a registered dialect
// name and the latter must be a hex constant.
static void printElidedElementsAttr(raw_ostream &os) {
  os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
}

/// Try to print an alias for `attr`. Succeeds only when an AsmState is
/// attached and its alias state reports success from getAlias; fails
/// otherwise.
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
  return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
}

/// Try to print an alias for `type`. Succeeds only when an AsmState is
/// attached and its alias state reports success from getAlias; fails
/// otherwise.
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
  return success(state && succeeded(state->getAliasState().getAlias(type, os)));
}

/// Print `attr` to the output stream.
///
/// Null attributes print as "<<NULL ATTRIBUTE>>". If an alias can be printed
/// it is used instead of the full form, and attributes that do not belong to
/// the builtin dialect are forwarded to printDialectAttribute. Builtin
/// attributes are printed inline below. Unless a branch returns early, the
/// value is followed by " : <type>" when `typeElision` is not `Must` and the
/// attribute type is not NoneType.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  // Everything that is not a builtin attribute is printed by its dialect.
  if (!isa<BuiltinDialect>(attr.getDialect()))
    return printDialectAttribute(attr);

  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    // Unit attributes never print a trailing type.
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values. Indexes, signed values, and multi-bit signless
    // values print as signed.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Elements of an array are allowed to elide their own types.
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Print the root reference followed by each nested reference, separated
    // by "::".
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
         << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // The sparse attribute is elided as a whole if either of its component
    // attributes should be elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, nothing is printed between the angle brackets.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}

/// Print the integer element of a DenseElementsAttr. 1-bit values print as
/// "true"/"false"; all other widths print as a decimal integer, signed when
/// `isSigned` is set.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 bool isSigned) {
  if (value.getBitWidth() == 1)
    os << (value.getBoolValue() ? "true" : "false");
  else
    value.print(os, isSigned);
}

/// Print the elements of a dense elements attribute of shaped type `type`,
/// calling `printEltFn` with the index of each element to print. A splat
/// prints only element 0, and a type with zero elements prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.
1860 1861 // The mixed-radix counter, with radices in 'shape'. 1862 int64_t rank = type.getRank(); 1863 SmallVector<unsigned, 4> counter(rank, 0); 1864 // The number of brackets that have been opened and not closed. 1865 unsigned openBrackets = 0; 1866 1867 auto shape = type.getShape(); 1868 auto bumpCounter = [&] { 1869 // Bump the least significant digit. 1870 ++counter[rank - 1]; 1871 // Iterate backwards bubbling back the increment. 1872 for (unsigned i = rank - 1; i > 0; --i) 1873 if (counter[i] >= shape[i]) { 1874 // Index 'i' is rolled over. Bump (i-1) and close a bracket. 1875 counter[i] = 0; 1876 ++counter[i - 1]; 1877 --openBrackets; 1878 os << ']'; 1879 } 1880 }; 1881 1882 for (unsigned idx = 0, e = numElements; idx != e; ++idx) { 1883 if (idx != 0) 1884 os << ", "; 1885 while (openBrackets++ < rank) 1886 os << '['; 1887 openBrackets = rank; 1888 printEltFn(idx); 1889 bumpCounter(); 1890 } 1891 while (openBrackets-- > 0) 1892 os << ']'; 1893 } 1894 1895 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, 1896 bool allowHex) { 1897 if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) 1898 return printDenseStringElementsAttr(stringAttr); 1899 1900 printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(), 1901 allowHex); 1902 } 1903 1904 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( 1905 DenseIntOrFPElementsAttr attr, bool allowHex) { 1906 auto type = attr.getType(); 1907 auto elementType = type.getElementType(); 1908 1909 // Check to see if we should format this attribute as a hex string. 1910 auto numElements = type.getNumElements(); 1911 if (!attr.isSplat() && allowHex && 1912 shouldPrintElementsAttrWithHex(numElements)) { 1913 ArrayRef<char> rawData = attr.getRawData(); 1914 if (llvm::support::endian::system_endianness() == 1915 llvm::support::endianness::big) { 1916 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE 1917 // machines. It is converted here to print in LE format. 
1918 SmallVector<char, 64> outDataVec(rawData.size()); 1919 MutableArrayRef<char> convRawData(outDataVec); 1920 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1921 rawData, convRawData, type); 1922 os << '"' << "0x" 1923 << llvm::toHex(StringRef(convRawData.data(), convRawData.size())) 1924 << "\""; 1925 } else { 1926 os << '"' << "0x" 1927 << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\""; 1928 } 1929 1930 return; 1931 } 1932 1933 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1934 Type complexElementType = complexTy.getElementType(); 1935 // Note: The if and else below had a common lambda function which invoked 1936 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 1937 // and hence was replaced. 1938 if (complexElementType.isa<IntegerType>()) { 1939 bool isSigned = !complexElementType.isUnsignedInteger(); 1940 auto valueIt = attr.value_begin<std::complex<APInt>>(); 1941 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1942 auto complexValue = *(valueIt + index); 1943 os << "("; 1944 printDenseIntElement(complexValue.real(), os, isSigned); 1945 os << ","; 1946 printDenseIntElement(complexValue.imag(), os, isSigned); 1947 os << ")"; 1948 }); 1949 } else { 1950 auto valueIt = attr.value_begin<std::complex<APFloat>>(); 1951 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1952 auto complexValue = *(valueIt + index); 1953 os << "("; 1954 printFloatValue(complexValue.real(), os); 1955 os << ","; 1956 printFloatValue(complexValue.imag(), os); 1957 os << ")"; 1958 }); 1959 } 1960 } else if (elementType.isIntOrIndex()) { 1961 bool isSigned = !elementType.isUnsignedInteger(); 1962 auto valueIt = attr.value_begin<APInt>(); 1963 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1964 printDenseIntElement(*(valueIt + index), os, isSigned); 1965 }); 1966 } else { 1967 assert(elementType.isa<FloatType>() && "unexpected 
element type"); 1968 auto valueIt = attr.value_begin<APFloat>(); 1969 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1970 printFloatValue(*(valueIt + index), os); 1971 }); 1972 } 1973 } 1974 1975 void AsmPrinter::Impl::printDenseStringElementsAttr( 1976 DenseStringElementsAttr attr) { 1977 ArrayRef<StringRef> data = attr.getRawStringData(); 1978 auto printFn = [&](unsigned index) { 1979 os << "\""; 1980 printEscapedString(data[index], os); 1981 os << "\""; 1982 }; 1983 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); 1984 } 1985 1986 void AsmPrinter::Impl::printType(Type type) { 1987 if (!type) { 1988 os << "<<NULL TYPE>>"; 1989 return; 1990 } 1991 1992 // Try to print an alias for this type. 1993 if (state && succeeded(state->getAliasState().getAlias(type, os))) 1994 return; 1995 1996 TypeSwitch<Type>(type) 1997 .Case<OpaqueType>([&](OpaqueType opaqueTy) { 1998 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 1999 opaqueTy.getTypeData()); 2000 }) 2001 .Case<IndexType>([&](Type) { os << "index"; }) 2002 .Case<BFloat16Type>([&](Type) { os << "bf16"; }) 2003 .Case<Float16Type>([&](Type) { os << "f16"; }) 2004 .Case<Float32Type>([&](Type) { os << "f32"; }) 2005 .Case<Float64Type>([&](Type) { os << "f64"; }) 2006 .Case<Float80Type>([&](Type) { os << "f80"; }) 2007 .Case<Float128Type>([&](Type) { os << "f128"; }) 2008 .Case<IntegerType>([&](IntegerType integerTy) { 2009 if (integerTy.isSigned()) 2010 os << 's'; 2011 else if (integerTy.isUnsigned()) 2012 os << 'u'; 2013 os << 'i' << integerTy.getWidth(); 2014 }) 2015 .Case<FunctionType>([&](FunctionType funcTy) { 2016 os << '('; 2017 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); 2018 os << ") -> "; 2019 ArrayRef<Type> results = funcTy.getResults(); 2020 if (results.size() == 1 && !results[0].isa<FunctionType>()) { 2021 printType(results[0]); 2022 } else { 2023 os << '('; 2024 interleaveComma(results, [&](Type ty) { printType(ty); }); 
os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
        unsigned dimIdx = 0;
        // Leading fixed dimensions print as `Nx`.
        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
          os << vShape[dimIdx] << 'x';
        // The remaining dimensions of a scalable vector are grouped inside
        // square brackets, e.g. `[4x8]x`.
        if (vectorTy.isScalable()) {
          os << '[';
          unsigned secondToLastDim = lastDim - 1;
          for (; dimIdx < secondToLastDim; dimIdx++)
            os << vShape[dimIdx] << 'x';
          os << vShape[dimIdx] << "]x";
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        // Dynamic dimensions print as '?'.
        for (int64_t dim : tensorTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        // Dynamic dimensions print as '?'.
        for (int64_t dim : memrefTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(memrefTy.getElementType());
        // Only print the layout if it is not the identity.
        if (!memrefTy.getLayout().isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Anything that is not handled above is printed by its dialect.
      .Default([&](Type type) { return printDialectType(type); });
}

/// Print `attrs` as a brace-enclosed attribute dictionary preceded by a space,
/// skipping every attribute whose name appears in `elidedAttrs`. Nothing is
/// printed when `attrs` is empty or when every attribute is elided. When
/// `withKeyword` is set, the dictionary is preceded by " attributes".
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                             ArrayRef<StringRef> elidedAttrs,
                                             bool withKeyword) {
  // If there are no attributes, then there is nothing to be done.
  if (attrs.empty())
    return;

  // Functor used to print a filtered attribute list.
  auto printFilteredAttributesFn = [&](auto filteredAttrs) {
    // Print the 'attributes' keyword if necessary.
    if (withKeyword)
      os << " attributes";

    // Otherwise, print them all out in braces.
    os << " {";
    interleaveComma(filteredAttrs,
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';
  };

  // If no attributes are elided, we can directly print with no filtering.
  if (elidedAttrs.empty())
    return printFilteredAttributesFn(attrs);

  // Otherwise, filter out any attributes that shouldn't be included.
llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
                                                elidedAttrs.end());
  auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
    return !elidedAttrsSet.contains(attr.getName().strref());
  });
  // Print nothing at all if every attribute was filtered out.
  if (!filteredAttrs.empty())
    printFilteredAttributesFn(filteredAttrs);
}

/// Print a single named attribute as `name = value`. Unit attributes print
/// only their name.
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
  // Print the name without quotes if possible.
  ::printKeywordOrString(attr.getName().strref(), os);

  // Pretty printing elides the attribute value for unit attributes.
  if (attr.getValue().isa<UnitAttr>())
    return;

  os << " = ";
  printAttribute(attr.getValue());
}

/// Print an attribute owned by a dialect: the dialect serializes the attribute
/// into a temporary string, which is then emitted through printDialectSymbol
/// with the "#" prefix and the dialect namespace.
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
  auto &dialect = attr.getDialect();

  // Ask the dialect to serialize the attribute to a string.
  std::string attrName;
  {
    // The sub-printer writes into `attrName` and shares this printer's flags
    // and state.
    llvm::raw_string_ostream attrNameStr(attrName);
    Impl subPrinter(attrNameStr, printerFlags, state);
    DialectAsmPrinter printer(subPrinter);
    dialect.printAttribute(attr, printer);
  }
  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
}

/// Print a type owned by a dialect: the dialect serializes the type into a
/// temporary string, which is then emitted through printDialectSymbol with the
/// "!" prefix and the dialect namespace.
void AsmPrinter::Impl::printDialectType(Type type) {
  auto &dialect = type.getDialect();

  // Ask the dialect to serialize the type to a string.
  std::string typeName;
  {
    // The sub-printer writes into `typeName` and shares this printer's flags
    // and state.
    llvm::raw_string_ostream typeNameStr(typeName);
    Impl subPrinter(typeNameStr, printerFlags, state);
    DialectAsmPrinter printer(subPrinter);
    dialect.printType(type, printer);
  }
  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
}

//===--------------------------------------------------------------------===//
// AsmPrinter
//===--------------------------------------------------------------------===//

// Each method below asserts that `impl` is set and then forwards to it (or to
// a file-local helper operating on its stream).

AsmPrinter::~AsmPrinter() = default;

/// Return the stream of the underlying implementation.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  return impl->getStream();
}

/// Print the given floating point value in a stabilized form.
void AsmPrinter::printFloat(const APFloat &value) {
  assert(impl && "expected AsmPrinter::printFloat to be overriden");
  printFloatValue(value, impl->getStream());
}

/// Print the given type.
void AsmPrinter::printType(Type type) {
  assert(impl && "expected AsmPrinter::printType to be overriden");
  impl->printType(type);
}

/// Print the given attribute.
void AsmPrinter::printAttribute(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
  impl->printAttribute(attr);
}

/// Try to print an alias for the given attribute, returning the result of the
/// implementation's printAlias.
LogicalResult AsmPrinter::printAlias(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(attr);
}

/// Try to print an alias for the given type, returning the result of the
/// implementation's printAlias.
LogicalResult AsmPrinter::printAlias(Type type) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  return impl->printAlias(type);
}

/// Print the given attribute with type elision set to `Must`.
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
  assert(impl &&
         "expected AsmPrinter::printAttributeWithoutType to be overriden");
  impl->printAttribute(attr, Impl::AttrTypeElision::Must);
}

/// Print `keyword` through the file-local printKeywordOrString helper.
void AsmPrinter::printKeywordOrString(StringRef keyword) {
  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
  ::printKeywordOrString(keyword, impl->getStream());
}

/// Print `symbolRef` as an '@'-prefixed symbol reference.
void AsmPrinter::printSymbolName(StringRef symbolRef) {
  assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
  ::printSymbolReference(symbolRef, impl->getStream());
}

//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//

/// Print `expr` with weak binding strength, i.e. without enclosing
/// parentheses. `printValueName`, when provided, is called to print each
/// dimension and symbol identifier.
void AsmPrinter::Impl::printAffineExpr(
    AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}

/// Print `expr`, wrapping binary expressions in parentheses when
/// `enclosingTightness` is Strong. Dimensions and symbols are printed by
/// `printValueName(position, isSymbol)` when it is provided, and as `dN`/`sN`
/// otherwise.
void AsmPrinter::Impl::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  // Leaf expressions are printed and returned from directly; binary
  // expressions only record their operator spelling here.
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
AffineExpr lhsExpr = binOp.getLHS(); 2293 AffineExpr rhsExpr = binOp.getRHS(); 2294 2295 // Handle tightly binding binary operators. 2296 if (binOp.getKind() != AffineExprKind::Add) { 2297 if (enclosingTightness == BindingStrength::Strong) 2298 os << '('; 2299 2300 // Pretty print multiplication with -1. 2301 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2302 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2303 rhsConst.getValue() == -1) { 2304 os << "-"; 2305 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2306 if (enclosingTightness == BindingStrength::Strong) 2307 os << ')'; 2308 return; 2309 } 2310 2311 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2312 2313 os << binopSpelling; 2314 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2315 2316 if (enclosingTightness == BindingStrength::Strong) 2317 os << ')'; 2318 return; 2319 } 2320 2321 // Print out special "pretty" forms for add. 2322 if (enclosingTightness == BindingStrength::Strong) 2323 os << '('; 2324 2325 // Pretty print addition to a product that has a negative operand as a 2326 // subtraction. 
2327 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2328 if (rhs.getKind() == AffineExprKind::Mul) { 2329 AffineExpr rrhsExpr = rhs.getRHS(); 2330 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2331 if (rrhs.getValue() == -1) { 2332 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2333 printValueName); 2334 os << " - "; 2335 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2336 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2337 printValueName); 2338 } else { 2339 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2340 printValueName); 2341 } 2342 2343 if (enclosingTightness == BindingStrength::Strong) 2344 os << ')'; 2345 return; 2346 } 2347 2348 if (rrhs.getValue() < -1) { 2349 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2350 printValueName); 2351 os << " - "; 2352 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2353 printValueName); 2354 os << " * " << -rrhs.getValue(); 2355 if (enclosingTightness == BindingStrength::Strong) 2356 os << ')'; 2357 return; 2358 } 2359 } 2360 } 2361 } 2362 2363 // Pretty print addition to a negative number as a subtraction. 2364 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2365 if (rhsConst.getValue() < 0) { 2366 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2367 os << " - " << -rhsConst.getValue(); 2368 if (enclosingTightness == BindingStrength::Strong) 2369 os << ')'; 2370 return; 2371 } 2372 } 2373 2374 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2375 2376 os << " + "; 2377 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2378 2379 if (enclosingTightness == BindingStrength::Strong) 2380 os << ')'; 2381 } 2382 2383 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { 2384 printAffineExprInternal(expr, BindingStrength::Weak); 2385 isEq ? 
os << " == 0" : os << " >= 0"; 2386 } 2387 2388 void AsmPrinter::Impl::printAffineMap(AffineMap map) { 2389 // Dimension identifiers. 2390 os << '('; 2391 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2392 os << 'd' << i << ", "; 2393 if (map.getNumDims() >= 1) 2394 os << 'd' << map.getNumDims() - 1; 2395 os << ')'; 2396 2397 // Symbolic identifiers. 2398 if (map.getNumSymbols() != 0) { 2399 os << '['; 2400 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2401 os << 's' << i << ", "; 2402 if (map.getNumSymbols() >= 1) 2403 os << 's' << map.getNumSymbols() - 1; 2404 os << ']'; 2405 } 2406 2407 // Result affine expressions. 2408 os << " -> ("; 2409 interleaveComma(map.getResults(), 2410 [&](AffineExpr expr) { printAffineExpr(expr); }); 2411 os << ')'; 2412 } 2413 2414 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { 2415 // Dimension identifiers. 2416 os << '('; 2417 for (unsigned i = 1; i < set.getNumDims(); ++i) 2418 os << 'd' << i - 1 << ", "; 2419 if (set.getNumDims() >= 1) 2420 os << 'd' << set.getNumDims() - 1; 2421 os << ')'; 2422 2423 // Symbolic identifiers. 2424 if (set.getNumSymbols() != 0) { 2425 os << '['; 2426 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2427 os << 's' << i << ", "; 2428 if (set.getNumSymbols() >= 1) 2429 os << 's' << set.getNumSymbols() - 1; 2430 os << ']'; 2431 } 2432 2433 // Print constraints. 
2434 os << " : ("; 2435 int numConstraints = set.getNumConstraints(); 2436 for (int i = 1; i < numConstraints; ++i) { 2437 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2438 os << ", "; 2439 } 2440 if (numConstraints >= 1) 2441 printAffineConstraint(set.getConstraint(numConstraints - 1), 2442 set.isEq(numConstraints - 1)); 2443 os << ')'; 2444 } 2445 2446 //===----------------------------------------------------------------------===// 2447 // OperationPrinter 2448 //===----------------------------------------------------------------------===// 2449 2450 namespace { 2451 /// This class contains the logic for printing operations, regions, and blocks. 2452 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { 2453 public: 2454 using Impl = AsmPrinter::Impl; 2455 using Impl::printType; 2456 2457 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) 2458 : Impl(os, state.getPrinterFlags(), &state), 2459 OpAsmPrinter(static_cast<Impl &>(*this)) {} 2460 2461 /// Print the given top-level operation. 2462 void printTopLevelOperation(Operation *op); 2463 2464 /// Print the given operation with its indent and location. 2465 void print(Operation *op); 2466 /// Print the bare location, not including indentation/location/etc. 2467 void printOperation(Operation *op); 2468 /// Print the given operation in the generic form. 2469 void printGenericOp(Operation *op, bool printOpName) override; 2470 2471 /// Print the name of the given block. 2472 void printBlockName(Block *block); 2473 2474 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2475 /// block are not printed. If 'printBlockTerminator' is false, the terminator 2476 /// operation of the block is not printed. 2477 void print(Block *block, bool printBlockArgs = true, 2478 bool printBlockTerminator = true); 2479 2480 /// Print the ID of the given value, optionally with its result number. 
2481 void printValueID(Value value, bool printResultNo = true, 2482 raw_ostream *streamOverride = nullptr) const; 2483 2484 //===--------------------------------------------------------------------===// 2485 // OpAsmPrinter methods 2486 //===--------------------------------------------------------------------===// 2487 2488 /// Print a newline and indent the printer to the start of the current 2489 /// operation. 2490 void printNewline() override { 2491 os << newLine; 2492 os.indent(currentIndent); 2493 } 2494 2495 /// Print a block argument in the usual format of: 2496 /// %ssaName : type {attr1=42} loc("here") 2497 /// where location printing is controlled by the standard internal option. 2498 /// You may pass omitType=true to not print a type, and pass an empty 2499 /// attribute list if you don't care for attributes. 2500 void printRegionArgument(BlockArgument arg, 2501 ArrayRef<NamedAttribute> argAttrs = {}, 2502 bool omitType = false) override; 2503 2504 /// Print the ID for the given value. 2505 void printOperand(Value value) override { printValueID(value); } 2506 void printOperand(Value value, raw_ostream &os) override { 2507 printValueID(value, /*printResultNo=*/true, &os); 2508 } 2509 2510 /// Print an optional attribute dictionary with a given set of elided values. 2511 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2512 ArrayRef<StringRef> elidedAttrs = {}) override { 2513 Impl::printOptionalAttrDict(attrs, elidedAttrs); 2514 } 2515 void printOptionalAttrDictWithKeyword( 2516 ArrayRef<NamedAttribute> attrs, 2517 ArrayRef<StringRef> elidedAttrs = {}) override { 2518 Impl::printOptionalAttrDict(attrs, elidedAttrs, 2519 /*withKeyword=*/true); 2520 } 2521 2522 /// Print the given successor. 2523 void printSuccessor(Block *successor) override; 2524 2525 /// Print an operation successor with the operands used for the block 2526 /// arguments. 
2527 void printSuccessorAndUseList(Block *successor, 2528 ValueRange succOperands) override; 2529 2530 /// Print the given region. 2531 void printRegion(Region ®ion, bool printEntryBlockArgs, 2532 bool printBlockTerminators, bool printEmptyBlock) override; 2533 2534 /// Renumber the arguments for the specified region to the same names as the 2535 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2536 /// operations. If any entry in namesToUse is null, the corresponding 2537 /// argument name is left alone. 2538 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2539 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2540 } 2541 2542 /// Print the given affine map with the symbol and dimension operands printed 2543 /// inline with the map. 2544 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2545 ValueRange operands) override; 2546 2547 /// Print the given affine expression with the symbol and dimension operands 2548 /// printed inline with the expression. 2549 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 2550 ValueRange symOperands) override; 2551 2552 private: 2553 // Contains the stack of default dialects to use when printing regions. 2554 // A new dialect is pushed to the stack before parsing regions nested under an 2555 // operation implementing `OpAsmOpInterface`, and popped when done. At the 2556 // top-level we start with "builtin" as the default, so that the top-level 2557 // `module` operation prints as-is. 2558 SmallVector<StringRef> defaultDialectStack{"builtin"}; 2559 2560 /// The number of spaces used for indenting nested operations. 2561 const static unsigned indentWidth = 2; 2562 2563 // This is the current indentation level for nested structures. 2564 unsigned currentIndent = 0; 2565 }; 2566 } // namespace 2567 2568 void OperationPrinter::printTopLevelOperation(Operation *op) { 2569 // Output the aliases at the top level that can't be deferred. 
state->getAliasState().printNonDeferredAliases(os, newLine);

  // Print the module.
  print(op);
  os << newLine;

  // Output the aliases at the top level that can be deferred.
  state->getAliasState().printDeferredAliases(os, newLine);
}

/// Print a block argument in the usual format of:
///   %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
void OperationPrinter::printRegionArgument(BlockArgument arg,
                                           ArrayRef<NamedAttribute> argAttrs,
                                           bool omitType) {
  printOperand(arg);
  if (!omitType) {
    os << ": ";
    printType(arg.getType());
  }
  printOptionalAttrDict(argAttrs);
  // TODO: We should allow location aliases on block arguments.
  printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
}

/// Print `op` at the current indentation, followed by its trailing location.
/// The operation's position (current line and indent) is registered with the
/// state before anything is printed.
void OperationPrinter::print(Operation *op) {
  // Track the location of this operation.
  state->registerOperationLocation(op, newLine.curLine, currentIndent);

  os.indent(currentIndent);
  printOperation(op);
  printTrailingLocation(op->getLoc());
}

/// Print `op` without indentation or trailing location: the result list (if
/// any) followed by " = ", then the operation itself. Unless the generic form
/// was requested through the printer flags, the registered custom printer is
/// tried first, then the dialect's operation printer; the generic form is the
/// fallback.
void OperationPrinter::printOperation(Operation *op) {
  if (size_t numResults = op->getNumResults()) {
    // Print the ID of result `resultNo`, with a ":<count>" suffix when the
    // group it starts covers more than one result.
    auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
      printValueID(op->getResult(resultNo), /*printResultNo=*/false);
      if (resultCount > 1)
        os << ':' << resultCount;
    };

    // Check to see if this operation has multiple result groups.
    ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
    if (!resultGroups.empty()) {
      // Interleave the groups excluding the last one, this one will be handled
      // separately.
      interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
        // The size of a group is the distance to the start of the next one.
        printResultGroup(resultGroups[i],
                         resultGroups[i + 1] - resultGroups[i]);
      });
      os << ", ";
      // The last group extends to the end of the result list.
      printResultGroup(resultGroups.back(), numResults - resultGroups.back());

    } else {
      printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
    }

    os << " = ";
  }

  // If requested, always print the generic form.
  if (!printerFlags.shouldPrintGenericOpForm()) {
    // Check to see if this is a known operation. If so, use the registered
    // custom printer hook.
    if (auto opInfo = op->getRegisteredInfo()) {
      opInfo->printAssembly(op, *this, defaultDialectStack.back());
      return;
    }
    // Otherwise try to dispatch to the dialect, if available.
    if (Dialect *dialect = op->getDialect()) {
      if (auto opPrinter = dialect->getOperationPrinter(op)) {
        // Print the op name first, dropping the "<dialect>." prefix when it
        // matches the current default dialect.
        StringRef name = op->getName().getStringRef();
        name.consume_front((defaultDialectStack.back() + ".").str());
        printEscapedString(name, os);
        // Print the rest of the op now.
        opPrinter(op, *this);
        return;
      }
    }
  }

  // Otherwise print with the generic assembly form.
  printGenericOp(op, /*printOpName=*/true);
}

/// Print `op` in the generic form: the optional quoted operation name, the
/// operand list, the successor list (if any), the regions (if any), the
/// attribute dictionary, and finally the functional type.
void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
  if (printOpName) {
    os << '"';
    printEscapedString(op->getName().getStringRef(), os);
    os << '"';
  }
  os << '(';
  interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
  os << ')';

  // For terminators, print the list of successors and their operands.
  if (op->getNumSuccessors() != 0) {
    os << '[';
    interleaveComma(op->getSuccessors(),
                    [&](Block *successor) { printBlockName(successor); });
    os << ']';
  }

  // Print regions.
2679 if (op->getNumRegions() != 0) { 2680 os << " ("; 2681 interleaveComma(op->getRegions(), [&](Region ®ion) { 2682 printRegion(region, /*printEntryBlockArgs=*/true, 2683 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 2684 }); 2685 os << ')'; 2686 } 2687 2688 auto attrs = op->getAttrs(); 2689 printOptionalAttrDict(attrs); 2690 2691 // Print the type signature of the operation. 2692 os << " : "; 2693 printFunctionalType(op); 2694 } 2695 2696 void OperationPrinter::printBlockName(Block *block) { 2697 os << state->getSSANameState().getBlockInfo(block).name; 2698 } 2699 2700 void OperationPrinter::print(Block *block, bool printBlockArgs, 2701 bool printBlockTerminator) { 2702 // Print the block label and argument list if requested. 2703 if (printBlockArgs) { 2704 os.indent(currentIndent); 2705 printBlockName(block); 2706 2707 // Print the argument list if non-empty. 2708 if (!block->args_empty()) { 2709 os << '('; 2710 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 2711 printValueID(arg); 2712 os << ": "; 2713 printType(arg.getType()); 2714 // TODO: We should allow location aliases on block arguments. 2715 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2716 }); 2717 os << ')'; 2718 } 2719 os << ':'; 2720 2721 // Print out some context information about the predecessors of this block. 2722 if (!block->getParent()) { 2723 os << " // block is not in a region!"; 2724 } else if (block->hasNoPredecessors()) { 2725 if (!block->isEntryBlock()) 2726 os << " // no predecessors"; 2727 } else if (auto *pred = block->getSinglePredecessor()) { 2728 os << " // pred: "; 2729 printBlockName(pred); 2730 } else { 2731 // We want to print the predecessors in a stable order, not in 2732 // whatever order the use-list is in, so gather and sort them. 
2733 SmallVector<BlockInfo, 4> predIDs; 2734 for (auto *pred : block->getPredecessors()) 2735 predIDs.push_back(state->getSSANameState().getBlockInfo(pred)); 2736 llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { 2737 return lhs.ordering < rhs.ordering; 2738 }); 2739 2740 os << " // " << predIDs.size() << " preds: "; 2741 2742 interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); 2743 } 2744 os << newLine; 2745 } 2746 2747 currentIndent += indentWidth; 2748 bool hasTerminator = 2749 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 2750 auto range = llvm::make_range( 2751 block->begin(), 2752 std::prev(block->end(), 2753 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 2754 for (auto &op : range) { 2755 print(&op); 2756 os << newLine; 2757 } 2758 currentIndent -= indentWidth; 2759 } 2760 2761 void OperationPrinter::printValueID(Value value, bool printResultNo, 2762 raw_ostream *streamOverride) const { 2763 state->getSSANameState().printValueID(value, printResultNo, 2764 streamOverride ? 
*streamOverride : os); 2765 } 2766 2767 void OperationPrinter::printSuccessor(Block *successor) { 2768 printBlockName(successor); 2769 } 2770 2771 void OperationPrinter::printSuccessorAndUseList(Block *successor, 2772 ValueRange succOperands) { 2773 printBlockName(successor); 2774 if (succOperands.empty()) 2775 return; 2776 2777 os << '('; 2778 interleaveComma(succOperands, 2779 [this](Value operand) { printValueID(operand); }); 2780 os << " : "; 2781 interleaveComma(succOperands, 2782 [this](Value operand) { printType(operand.getType()); }); 2783 os << ')'; 2784 } 2785 2786 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 2787 bool printBlockTerminators, 2788 bool printEmptyBlock) { 2789 os << "{" << newLine; 2790 if (!region.empty()) { 2791 auto restoreDefaultDialect = 2792 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); 2793 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) 2794 defaultDialectStack.push_back(iface.getDefaultDialect()); 2795 else 2796 defaultDialectStack.push_back(""); 2797 2798 auto *entryBlock = ®ion.front(); 2799 // Force printing the block header if printEmptyBlock is set and the block 2800 // is empty or if printEntryBlockArgs is set and there are arguments to 2801 // print. 2802 bool shouldAlwaysPrintBlockHeader = 2803 (printEmptyBlock && entryBlock->empty()) || 2804 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 2805 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 2806 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 2807 print(&b); 2808 } 2809 os.indent(currentIndent) << "}"; 2810 } 2811 2812 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2813 ValueRange operands) { 2814 AffineMap map = mapAttr.getValue(); 2815 unsigned numDims = map.getNumDims(); 2816 auto printValueName = [&](unsigned pos, bool isSymbol) { 2817 unsigned index = isSymbol ? 
numDims + pos : pos; 2818 assert(index < operands.size()); 2819 if (isSymbol) 2820 os << "symbol("; 2821 printValueID(operands[index]); 2822 if (isSymbol) 2823 os << ')'; 2824 }; 2825 2826 interleaveComma(map.getResults(), [&](AffineExpr expr) { 2827 printAffineExpr(expr, printValueName); 2828 }); 2829 } 2830 2831 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 2832 ValueRange dimOperands, 2833 ValueRange symOperands) { 2834 auto printValueName = [&](unsigned pos, bool isSymbol) { 2835 if (!isSymbol) 2836 return printValueID(dimOperands[pos]); 2837 os << "symbol("; 2838 printValueID(symOperands[pos]); 2839 os << ')'; 2840 }; 2841 printAffineExpr(expr, printValueName); 2842 } 2843 2844 //===----------------------------------------------------------------------===// 2845 // print and dump methods 2846 //===----------------------------------------------------------------------===// 2847 2848 void Attribute::print(raw_ostream &os) const { 2849 AsmPrinter::Impl(os).printAttribute(*this); 2850 } 2851 2852 void Attribute::dump() const { 2853 print(llvm::errs()); 2854 llvm::errs() << "\n"; 2855 } 2856 2857 void Type::print(raw_ostream &os) const { 2858 AsmPrinter::Impl(os).printType(*this); 2859 } 2860 2861 void Type::dump() const { print(llvm::errs()); } 2862 2863 void AffineMap::dump() const { 2864 print(llvm::errs()); 2865 llvm::errs() << "\n"; 2866 } 2867 2868 void IntegerSet::dump() const { 2869 print(llvm::errs()); 2870 llvm::errs() << "\n"; 2871 } 2872 2873 void AffineExpr::print(raw_ostream &os) const { 2874 if (!expr) { 2875 os << "<<NULL AFFINE EXPR>>"; 2876 return; 2877 } 2878 AsmPrinter::Impl(os).printAffineExpr(*this); 2879 } 2880 2881 void AffineExpr::dump() const { 2882 print(llvm::errs()); 2883 llvm::errs() << "\n"; 2884 } 2885 2886 void AffineMap::print(raw_ostream &os) const { 2887 if (!map) { 2888 os << "<<NULL AFFINE MAP>>"; 2889 return; 2890 } 2891 AsmPrinter::Impl(os).printAffineMap(*this); 2892 } 2893 2894 void 
IntegerSet::print(raw_ostream &os) const { 2895 AsmPrinter::Impl(os).printIntegerSet(*this); 2896 } 2897 2898 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); } 2899 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) { 2900 if (!impl) { 2901 os << "<<NULL VALUE>>"; 2902 return; 2903 } 2904 2905 if (auto *op = getDefiningOp()) 2906 return op->print(os, flags); 2907 // TODO: Improve BlockArgument print'ing. 2908 BlockArgument arg = this->cast<BlockArgument>(); 2909 os << "<block argument> of type '" << arg.getType() 2910 << "' at index: " << arg.getArgNumber(); 2911 } 2912 void Value::print(raw_ostream &os, AsmState &state) { 2913 if (!impl) { 2914 os << "<<NULL VALUE>>"; 2915 return; 2916 } 2917 2918 if (auto *op = getDefiningOp()) 2919 return op->print(os, state); 2920 2921 // TODO: Improve BlockArgument print'ing. 2922 BlockArgument arg = this->cast<BlockArgument>(); 2923 os << "<block argument> of type '" << arg.getType() 2924 << "' at index: " << arg.getArgNumber(); 2925 } 2926 2927 void Value::dump() { 2928 print(llvm::errs()); 2929 llvm::errs() << "\n"; 2930 } 2931 2932 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 2933 // TODO: This doesn't necessarily capture all potential cases. 2934 // Currently, region arguments can be shadowed when printing the main 2935 // operation. If the IR hasn't been printed, this will produce the old SSA 2936 // name and not the shadowed name. 2937 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 2938 os); 2939 } 2940 2941 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { 2942 // If this is a top level operation, we also print aliases. 2943 if (!getParent() && !printerFlags.shouldUseLocalScope()) { 2944 AsmState state(this, printerFlags); 2945 state.getImpl().initializeAliases(this); 2946 print(os, state); 2947 return; 2948 } 2949 2950 // Find the operation to number from based upon the provided flags. 
2951 Operation *op = this; 2952 bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); 2953 do { 2954 // If we are printing local scope, stop at the first operation that is 2955 // isolated from above. 2956 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 2957 break; 2958 2959 // Otherwise, traverse up to the next parent. 2960 Operation *parentOp = op->getParentOp(); 2961 if (!parentOp) 2962 break; 2963 op = parentOp; 2964 } while (true); 2965 2966 AsmState state(op, printerFlags); 2967 print(os, state); 2968 } 2969 void Operation::print(raw_ostream &os, AsmState &state) { 2970 OperationPrinter printer(os, state.getImpl()); 2971 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) 2972 printer.printTopLevelOperation(this); 2973 else 2974 printer.print(this); 2975 } 2976 2977 void Operation::dump() { 2978 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 2979 llvm::errs() << "\n"; 2980 } 2981 2982 void Block::print(raw_ostream &os) { 2983 Operation *parentOp = getParentOp(); 2984 if (!parentOp) { 2985 os << "<<UNLINKED BLOCK>>\n"; 2986 return; 2987 } 2988 // Get the top-level op. 2989 while (auto *nextOp = parentOp->getParentOp()) 2990 parentOp = nextOp; 2991 2992 AsmState state(parentOp); 2993 print(os, state); 2994 } 2995 void Block::print(raw_ostream &os, AsmState &state) { 2996 OperationPrinter(os, state.getImpl()).print(this); 2997 } 2998 2999 void Block::dump() { print(llvm::errs()); } 3000 3001 /// Print out the name of the block without printing its body. 3002 void Block::printAsOperand(raw_ostream &os, bool printType) { 3003 Operation *parentOp = getParentOp(); 3004 if (!parentOp) { 3005 os << "<<UNLINKED BLOCK>>\n"; 3006 return; 3007 } 3008 AsmState state(parentOp); 3009 printAsOperand(os, state); 3010 } 3011 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 3012 OperationPrinter printer(os, state.getImpl()); 3013 printer.printBlockName(this); 3014 } 3015