1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/IR/SubElementInterfaces.h" 28 #include "mlir/IR/Verifier.h" 29 #include "llvm/ADT/APFloat.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/MapVector.h" 32 #include "llvm/ADT/STLExtras.h" 33 #include "llvm/ADT/ScopeExit.h" 34 #include "llvm/ADT/ScopedHashTable.h" 35 #include "llvm/ADT/SetVector.h" 36 #include "llvm/ADT/SmallString.h" 37 #include "llvm/ADT/StringExtras.h" 38 #include "llvm/ADT/StringSet.h" 39 #include "llvm/ADT/TypeSwitch.h" 40 #include "llvm/Support/CommandLine.h" 41 #include "llvm/Support/Debug.h" 42 #include "llvm/Support/Endian.h" 43 #include "llvm/Support/Regex.h" 44 #include "llvm/Support/SaveAndRestore.h" 45 #include "llvm/Support/Threading.h" 46 47 #include <tuple> 48 49 using namespace mlir; 50 using namespace mlir::detail; 51 52 #define DEBUG_TYPE "mlir-asm-printer" 53 54 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 
55 56 void OperationName::dump() const { print(llvm::errs()); } 57 58 //===--------------------------------------------------------------------===// 59 // AsmParser 60 //===--------------------------------------------------------------------===// 61 62 AsmParser::~AsmParser() = default; 63 DialectAsmParser::~DialectAsmParser() = default; 64 OpAsmParser::~OpAsmParser() = default; 65 66 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } 67 68 //===----------------------------------------------------------------------===// 69 // DialectAsmPrinter 70 //===----------------------------------------------------------------------===// 71 72 DialectAsmPrinter::~DialectAsmPrinter() = default; 73 74 //===----------------------------------------------------------------------===// 75 // OpAsmPrinter 76 //===----------------------------------------------------------------------===// 77 78 OpAsmPrinter::~OpAsmPrinter() = default; 79 80 void OpAsmPrinter::printFunctionalType(Operation *op) { 81 auto &os = getStream(); 82 os << '('; 83 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 84 // Print the types of null values as <<NULL TYPE>>. 85 *this << (operand ? operand.getType() : Type()); 86 }); 87 os << ") -> "; 88 89 // Print the result list. We don't parenthesize single result types unless 90 // it is a function (avoiding a grammar ambiguity). 91 bool wrapped = op->getNumResults() != 1; 92 if (!wrapped && op->getResult(0).getType() && 93 op->getResult(0).getType().isa<FunctionType>()) 94 wrapped = true; 95 96 if (wrapped) 97 os << '('; 98 99 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 100 // Print the types of null values as <<NULL TYPE>>. 101 *this << (result ? result.getType() : Type()); 102 }); 103 104 if (wrapped) 105 os << ')'; 106 } 107 108 //===----------------------------------------------------------------------===// 109 // Operation OpAsm interface. 
110 //===----------------------------------------------------------------------===// 111 112 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 113 #include "mlir/IR/OpAsmInterface.cpp.inc" 114 115 LogicalResult 116 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { 117 return entry.emitError() << "unknown 'resource' key '" << entry.getKey() 118 << "' for dialect '" << getDialect()->getNamespace() 119 << "'"; 120 } 121 122 //===----------------------------------------------------------------------===// 123 // OpPrintingFlags 124 //===----------------------------------------------------------------------===// 125 126 namespace { 127 /// This struct contains command line options that can be used to initialize 128 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 129 /// for global command line options. 130 struct AsmPrinterOptions { 131 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 132 "mlir-print-elementsattrs-with-hex-if-larger", 133 llvm::cl::desc( 134 "Print DenseElementsAttrs with a hex string that have " 135 "more elements than the given upper limit (use -1 to disable)")}; 136 137 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 138 "mlir-elide-elementsattrs-if-larger", 139 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 140 "more elements than the given upper limit")}; 141 142 llvm::cl::opt<bool> printDebugInfoOpt{ 143 "mlir-print-debuginfo", llvm::cl::init(false), 144 llvm::cl::desc("Print debug info in MLIR output")}; 145 146 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 147 "mlir-pretty-debuginfo", llvm::cl::init(false), 148 llvm::cl::desc("Print pretty debug info in MLIR output")}; 149 150 // Use the generic op output form in the operation printer even if the custom 151 // form is defined. 
152 llvm::cl::opt<bool> printGenericOpFormOpt{ 153 "mlir-print-op-generic", llvm::cl::init(false), 154 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 155 156 llvm::cl::opt<bool> assumeVerifiedOpt{ 157 "mlir-print-assume-verified", llvm::cl::init(false), 158 llvm::cl::desc("Skip op verification when using custom printers"), 159 llvm::cl::Hidden}; 160 161 llvm::cl::opt<bool> printLocalScopeOpt{ 162 "mlir-print-local-scope", llvm::cl::init(false), 163 llvm::cl::desc("Print with local scope and inline information (eliding " 164 "aliases for attributes, types, and locations")}; 165 166 llvm::cl::opt<bool> printValueUsers{ 167 "mlir-print-value-users", llvm::cl::init(false), 168 llvm::cl::desc( 169 "Print users of operation results and block arguments as a comment")}; 170 }; 171 } // namespace 172 173 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 174 175 /// Register a set of useful command-line options that can be used to configure 176 /// various flags within the AsmPrinter. 177 void mlir::registerAsmPrinterCLOptions() { 178 // Make sure that the options struct has been initialized. 179 *clOptions; 180 } 181 182 /// Initialize the printing flags with default supplied by the cl::opts above. 183 OpPrintingFlags::OpPrintingFlags() 184 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 185 printGenericOpFormFlag(false), assumeVerifiedFlag(false), 186 printLocalScope(false), printValueUsersFlag(false) { 187 // Initialize based upon command line options, if they are available. 
188 if (!clOptions.isConstructed()) 189 return; 190 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 191 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 192 printDebugInfoFlag = clOptions->printDebugInfoOpt; 193 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 194 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 195 assumeVerifiedFlag = clOptions->assumeVerifiedOpt; 196 printLocalScope = clOptions->printLocalScopeOpt; 197 printValueUsersFlag = clOptions->printValueUsers; 198 } 199 200 /// Enable the elision of large elements attributes, by printing a '...' 201 /// instead of the element data, when the number of elements is greater than 202 /// `largeElementLimit`. Note: The IR generated with this option is not 203 /// parsable. 204 OpPrintingFlags & 205 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 206 elementsAttrElementLimit = largeElementLimit; 207 return *this; 208 } 209 210 /// Enable printing of debug information. If 'prettyForm' is set to true, 211 /// debug information is printed in a more readable 'pretty' form. 212 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 213 printDebugInfoFlag = true; 214 printDebugInfoPrettyFormFlag = prettyForm; 215 return *this; 216 } 217 218 /// Always print operations in the generic form. 219 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 220 printGenericOpFormFlag = true; 221 return *this; 222 } 223 224 /// Do not verify the operation when using custom operation printers. 225 OpPrintingFlags &OpPrintingFlags::assumeVerified() { 226 assumeVerifiedFlag = true; 227 return *this; 228 } 229 230 /// Use local scope when printing the operation. This allows for using the 231 /// printer in a more localized and thread-safe setting, but may not necessarily 232 /// be identical of what the IR will look like when dumping the full module. 
233 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 234 printLocalScope = true; 235 return *this; 236 } 237 238 /// Print users of values as comments. 239 OpPrintingFlags &OpPrintingFlags::printValueUsers() { 240 printValueUsersFlag = true; 241 return *this; 242 } 243 244 /// Return if the given ElementsAttr should be elided. 245 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 246 return elementsAttrElementLimit && 247 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 248 !attr.isa<SplatElementsAttr>(); 249 } 250 251 /// Return the size limit for printing large ElementsAttr. 252 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 253 return elementsAttrElementLimit; 254 } 255 256 /// Return if debug information should be printed. 257 bool OpPrintingFlags::shouldPrintDebugInfo() const { 258 return printDebugInfoFlag; 259 } 260 261 /// Return if debug information should be printed in the pretty form. 262 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 263 return printDebugInfoPrettyFormFlag; 264 } 265 266 /// Return if operations should be printed in the generic form. 267 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 268 return printGenericOpFormFlag; 269 } 270 271 /// Return if operation verification should be skipped. 272 bool OpPrintingFlags::shouldAssumeVerified() const { 273 return assumeVerifiedFlag; 274 } 275 276 /// Return if the printer should use local scope when dumping the IR. 277 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 278 279 /// Return if the printer should print users of values. 280 bool OpPrintingFlags::shouldPrintValueUsers() const { 281 return printValueUsersFlag; 282 } 283 284 /// Returns true if an ElementsAttr with the given number of elements should be 285 /// printed with hex. 286 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 287 // Check to see if a command line option was provided for the limit. 
288 if (clOptions.isConstructed()) { 289 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 290 // -1 is used to disable hex printing. 291 if (clOptions->printElementsAttrWithHexIfLarger == -1) 292 return false; 293 return numElements > clOptions->printElementsAttrWithHexIfLarger; 294 } 295 } 296 297 // Otherwise, default to printing with hex if the number of elements is >100. 298 return numElements > 100; 299 } 300 301 //===----------------------------------------------------------------------===// 302 // NewLineCounter 303 //===----------------------------------------------------------------------===// 304 305 namespace { 306 /// This class is a simple formatter that emits a new line when inputted into a 307 /// stream, that enables counting the number of newlines emitted. This class 308 /// should be used whenever emitting newlines in the printer. 309 struct NewLineCounter { 310 unsigned curLine = 1; 311 }; 312 313 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 314 ++newLine.curLine; 315 return os << '\n'; 316 } 317 } // namespace 318 319 //===----------------------------------------------------------------------===// 320 // AliasInitializer 321 //===----------------------------------------------------------------------===// 322 323 namespace { 324 /// This class represents a specific instance of a symbol Alias. 325 class SymbolAlias { 326 public: 327 SymbolAlias(StringRef name, bool isDeferrable) 328 : name(name), suffixIndex(0), hasSuffixIndex(false), 329 isDeferrable(isDeferrable) {} 330 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) 331 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true), 332 isDeferrable(isDeferrable) {} 333 334 /// Print this alias to the given stream. 335 void print(raw_ostream &os) const { 336 os << name; 337 if (hasSuffixIndex) 338 os << suffixIndex; 339 } 340 341 /// Returns true if this alias supports deferred resolution when parsing. 
342 bool canBeDeferred() const { return isDeferrable; } 343 344 private: 345 /// The main name of the alias. 346 StringRef name; 347 /// The optional suffix index of the alias, if multiple aliases had the same 348 /// name. 349 uint32_t suffixIndex : 30; 350 /// A flag indicating whether this alias has a suffix or not. 351 bool hasSuffixIndex : 1; 352 /// A flag indicating whether this alias may be deferred or not. 353 bool isDeferrable : 1; 354 }; 355 356 /// This class represents a utility that initializes the set of attribute and 357 /// type aliases, without the need to store the extra information within the 358 /// main AliasState class or pass it around via function arguments. 359 class AliasInitializer { 360 public: 361 AliasInitializer( 362 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, 363 llvm::BumpPtrAllocator &aliasAllocator) 364 : interfaces(interfaces), aliasAllocator(aliasAllocator), 365 aliasOS(aliasBuffer) {} 366 367 void initialize(Operation *op, const OpPrintingFlags &printerFlags, 368 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 369 llvm::MapVector<Type, SymbolAlias> &typeToAlias); 370 371 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is 372 /// set to true if the originator of this attribute can resolve the alias 373 /// after parsing has completed (e.g. in the case of operation locations). 374 void visit(Attribute attr, bool canBeDeferred = false); 375 376 /// Visit the given type to see if it has an alias. 377 void visit(Type type); 378 379 private: 380 /// Try to generate an alias for the provided symbol. If an alias is 381 /// generated, the provided alias mapping and reverse mapping are updated. 382 /// Returns success if an alias was generated, failure otherwise. 383 template <typename T> 384 LogicalResult 385 generateAlias(T symbol, 386 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol); 387 388 /// The set of asm interfaces within the context. 
389 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; 390 391 /// Mapping between an alias and the set of symbols mapped to it. 392 llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr; 393 llvm::MapVector<StringRef, std::vector<Type>> aliasToType; 394 395 /// An allocator used for alias names. 396 llvm::BumpPtrAllocator &aliasAllocator; 397 398 /// The set of visited attributes. 399 DenseSet<Attribute> visitedAttributes; 400 401 /// The set of attributes that have aliases *and* can be deferred. 402 DenseSet<Attribute> deferrableAttributes; 403 404 /// The set of visited types. 405 DenseSet<Type> visitedTypes; 406 407 /// Storage and stream used when generating an alias. 408 SmallString<32> aliasBuffer; 409 llvm::raw_svector_ostream aliasOS; 410 }; 411 412 /// This class implements a dummy OpAsmPrinter that doesn't print any output, 413 /// and merely collects the attributes and types that *would* be printed in a 414 /// normal print invocation so that we can generate proper aliases. This allows 415 /// for us to generate aliases only for the attributes and types that would be 416 /// in the output, and trims down unnecessary output. 417 class DummyAliasOperationPrinter : private OpAsmPrinter { 418 public: 419 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, 420 AliasInitializer &initializer) 421 : printerFlags(printerFlags), initializer(initializer) {} 422 423 /// Print the given operation. 424 void print(Operation *op) { 425 // Visit the operation location. 426 if (printerFlags.shouldPrintDebugInfo()) 427 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 428 429 // If requested, always print the generic form. 430 if (!printerFlags.shouldPrintGenericOpForm()) { 431 // Check to see if this is a known operation. If so, use the registered 432 // custom printer hook. 
433 if (auto opInfo = op->getRegisteredInfo()) { 434 opInfo->printAssembly(op, *this, /*defaultDialect=*/""); 435 return; 436 } 437 } 438 439 // Otherwise print with the generic assembly form. 440 printGenericOp(op); 441 } 442 443 private: 444 /// Print the given operation in the generic form. 445 void printGenericOp(Operation *op, bool printOpName = true) override { 446 // Consider nested operations for aliases. 447 if (op->getNumRegions() != 0) { 448 for (Region ®ion : op->getRegions()) 449 printRegion(region, /*printEntryBlockArgs=*/true, 450 /*printBlockTerminators=*/true); 451 } 452 453 // Visit all the types used in the operation. 454 for (Type type : op->getOperandTypes()) 455 printType(type); 456 for (Type type : op->getResultTypes()) 457 printType(type); 458 459 // Consider the attributes of the operation for aliases. 460 for (const NamedAttribute &attr : op->getAttrs()) 461 printAttribute(attr.getValue()); 462 } 463 464 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 465 /// block are not printed. If 'printBlockTerminator' is false, the terminator 466 /// operation of the block is not printed. 467 void print(Block *block, bool printBlockArgs = true, 468 bool printBlockTerminator = true) { 469 // Consider the types of the block arguments for aliases if 'printBlockArgs' 470 // is set to true. 471 if (printBlockArgs) { 472 for (BlockArgument arg : block->getArguments()) { 473 printType(arg.getType()); 474 475 // Visit the argument location. 476 if (printerFlags.shouldPrintDebugInfo()) 477 // TODO: Allow deferring argument locations. 478 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 479 } 480 } 481 482 // Consider the operations within this block, ignoring the terminator if 483 // requested. 484 bool hasTerminator = 485 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 486 auto range = llvm::make_range( 487 block->begin(), 488 std::prev(block->end(), 489 (!hasTerminator || printBlockTerminator) ? 
0 : 1)); 490 for (Operation &op : range) 491 print(&op); 492 } 493 494 /// Print the given region. 495 void printRegion(Region ®ion, bool printEntryBlockArgs, 496 bool printBlockTerminators, 497 bool printEmptyBlock = false) override { 498 if (region.empty()) 499 return; 500 501 auto *entryBlock = ®ion.front(); 502 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 503 for (Block &b : llvm::drop_begin(region, 1)) 504 print(&b); 505 } 506 507 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, 508 bool omitType) override { 509 printType(arg.getType()); 510 // Visit the argument location. 511 if (printerFlags.shouldPrintDebugInfo()) 512 // TODO: Allow deferring argument locations. 513 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 514 } 515 516 /// Consider the given type to be printed for an alias. 517 void printType(Type type) override { initializer.visit(type); } 518 519 /// Consider the given attribute to be printed for an alias. 520 void printAttribute(Attribute attr) override { initializer.visit(attr); } 521 void printAttributeWithoutType(Attribute attr) override { 522 printAttribute(attr); 523 } 524 LogicalResult printAlias(Attribute attr) override { 525 initializer.visit(attr); 526 return success(); 527 } 528 LogicalResult printAlias(Type type) override { 529 initializer.visit(type); 530 return success(); 531 } 532 533 /// Print the given set of attributes with names not included within 534 /// 'elidedAttrs'. 
535 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 536 ArrayRef<StringRef> elidedAttrs = {}) override { 537 if (attrs.empty()) 538 return; 539 if (elidedAttrs.empty()) { 540 for (const NamedAttribute &attr : attrs) 541 printAttribute(attr.getValue()); 542 return; 543 } 544 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 545 elidedAttrs.end()); 546 for (const NamedAttribute &attr : attrs) 547 if (!elidedAttrsSet.contains(attr.getName().strref())) 548 printAttribute(attr.getValue()); 549 } 550 void printOptionalAttrDictWithKeyword( 551 ArrayRef<NamedAttribute> attrs, 552 ArrayRef<StringRef> elidedAttrs = {}) override { 553 printOptionalAttrDict(attrs, elidedAttrs); 554 } 555 556 /// Return a null stream as the output stream, this will ignore any data fed 557 /// to it. 558 raw_ostream &getStream() const override { return os; } 559 560 /// The following are hooks of `OpAsmPrinter` that are not necessary for 561 /// determining potential aliases. 562 void printFloat(const APFloat &value) override {} 563 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 564 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} 565 void printNewline() override {} 566 void printOperand(Value) override {} 567 void printOperand(Value, raw_ostream &os) override { 568 // Users expect the output string to have at least the prefixed % to signal 569 // a value name. To maintain this invariant, emit a name even if it is 570 // guaranteed to go unused. 571 os << "%"; 572 } 573 void printKeywordOrString(StringRef) override {} 574 void printSymbolName(StringRef) override {} 575 void printSuccessor(Block *) override {} 576 void printSuccessorAndUseList(Block *, ValueRange) override {} 577 void shadowRegionArgs(Region &, ValueRange) override {} 578 579 /// The printer flags to use when determining potential aliases. 580 const OpPrintingFlags &printerFlags; 581 582 /// The initializer to use when identifying aliases. 
583 AliasInitializer &initializer; 584 585 /// A dummy output stream. 586 mutable llvm::raw_null_ostream os; 587 }; 588 } // namespace 589 590 /// Sanitize the given name such that it can be used as a valid identifier. If 591 /// the string needs to be modified in any way, the provided buffer is used to 592 /// store the new copy, 593 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 594 StringRef allowedPunctChars = "$._-", 595 bool allowTrailingDigit = true) { 596 assert(!name.empty() && "Shouldn't have an empty name here"); 597 598 auto copyNameToBuffer = [&] { 599 for (char ch : name) { 600 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) 601 buffer.push_back(ch); 602 else if (ch == ' ') 603 buffer.push_back('_'); 604 else 605 buffer.append(llvm::utohexstr((unsigned char)ch)); 606 } 607 }; 608 609 // Check to see if this name is valid. If it starts with a digit, then it 610 // could conflict with the autogenerated numeric ID's, so add an underscore 611 // prefix to avoid problems. 612 if (isdigit(name[0])) { 613 buffer.push_back('_'); 614 copyNameToBuffer(); 615 return buffer; 616 } 617 618 // If the name ends with a trailing digit, add a '_' to avoid potential 619 // conflicts with autogenerated ID's. 620 if (!allowTrailingDigit && isdigit(name.back())) { 621 copyNameToBuffer(); 622 buffer.push_back('_'); 623 return buffer; 624 } 625 626 // Check to see that the name consists of only valid identifier characters. 627 for (char ch : name) { 628 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { 629 copyNameToBuffer(); 630 return buffer; 631 } 632 } 633 634 // If there are no invalid characters, return the original name. 635 return name; 636 } 637 638 /// Given a collection of aliases and symbols, initialize a mapping from a 639 /// symbol to a given alias. 
template <typename T>
static void
initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
                  llvm::MapVector<T, SymbolAlias> &symbolToAlias,
                  DenseSet<T> *deferrableAliases = nullptr) {
  // Move the collected (alias name -> symbols) entries into a local vector.
  std::vector<std::pair<StringRef, std::vector<T>>> aliases =
      aliasToSymbol.takeVector();
  // Order the entries by alias name so aliases are assigned in name order.
  llvm::array_pod_sort(aliases.begin(), aliases.end(),
                       [](const auto *lhs, const auto *rhs) {
                         return lhs->first.compare(rhs->first);
                       });

  for (auto &it : aliases) {
    // If there is only one instance for this alias, use the name directly.
    if (it.second.size() == 1) {
      T symbol = it.second.front();
      // A symbol is deferrable only if a deferrable set was supplied and it
      // contains the symbol.
      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
      symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
      continue;
    }
    // Otherwise, add the index to the name.
    for (int i = 0, e = it.second.size(); i < e; ++i) {
      T symbol = it.second[i];
      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
      symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
    }
  }
}

/// Collect the attribute and type aliases needed to print `op` with the given
/// flags, and fill `attrToAlias` / `typeToAlias` with them.
void AliasInitializer::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
    llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
  // Use a dummy printer when walking the IR so that we can collect the
  // attributes/types that will actually be used during printing when
  // considering aliases.
  DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
  aliasPrinter.print(op);

  // Initialize the aliases sorted by name. Only attribute aliases are given a
  // deferrable set; type aliases are never marked deferrable here.
  initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
  initializeAliases(aliasToType, typeToAlias);
}

void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
  if (!visitedAttributes.insert(attr).second) {
    // If this attribute already has an alias and this instance can't be
    // deferred, make sure that the alias isn't deferred.
    if (!canBeDeferred)
      deferrableAttributes.erase(attr);
    return;
  }

  // Try to generate an alias for this attribute.
  if (succeeded(generateAlias(attr, aliasToAttr))) {
    if (canBeDeferred)
      deferrableAttributes.insert(attr);
    return;
  }

  // Check for any sub elements. Nested elements are only visited when the
  // attribute itself did not receive an alias.
  if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
                                        [&](Type type) { visit(type); });
  }
}

void AliasInitializer::visit(Type type) {
  // Each type is considered at most once.
  if (!visitedTypes.insert(type).second)
    return;

  // Try to generate an alias for this type.
  if (succeeded(generateAlias(type, aliasToType)))
    return;

  // Check for any sub elements.
  if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
                                        [&](Type type) { visit(type); });
  }
}

template <typename T>
LogicalResult AliasInitializer::generateAlias(
    T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
  // Ask each dialect interface in turn. Every non-NoAlias answer replaces the
  // previous candidate name; a FinalAlias answer stops the search.
  SmallString<32> nameBuffer;
  for (const auto &interface : interfaces) {
    OpAsmDialectInterface::AliasResult result =
        interface.getAlias(symbol, aliasOS);
    if (result == OpAsmDialectInterface::AliasResult::NoAlias)
      continue;
    nameBuffer = std::move(aliasBuffer);
    assert(!nameBuffer.empty() && "expected valid alias name");
    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
      break;
  }

  if (nameBuffer.empty())
    return failure();

  // Alias names keep only '$', '_' and '-' punctuation and may not end in a
  // digit; the sanitized name is copied into the alias allocator.
  SmallString<16> tempBuffer;
  StringRef name =
      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
                         /*allowTrailingDigit=*/false);
  name = name.copy(aliasAllocator);
  aliasToSymbol[name].push_back(symbol);
  return success();
}

//===----------------------------------------------------------------------===//
// AliasState
//===----------------------------------------------------------------------===//

namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
  // Initialize the internal aliases.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias.
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // namespace

/// Populate the alias maps for `op` using a transient AliasInitializer that
/// allocates alias names from this state's allocator.
void AliasState::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  AliasInitializer initializer(interfaces, aliasAllocator);
  initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
}

/// Attribute aliases are printed with a '#' prefix.
LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
  auto it = attrToAlias.find(attr);
  if (it == attrToAlias.end())
    return failure();
  it->second.print(os << '#');
  return success();
}

/// Type aliases are printed with a '!' prefix.
LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
  auto it = typeToAlias.find(ty);
  if (it == typeToAlias.end())
    return failure();

  it->second.print(os << '!');
  return success();
}

/// Emit one `#alias = <attr>` line per matching attribute alias, then one
/// `!alias = <type>` line per matching type alias.
void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
                              bool isDeferred) const {
  // Keep only the aliases whose deferrability matches the request.
  auto filterFn = [=](const auto &aliasIt) {
    return aliasIt.second.canBeDeferred() == isDeferred;
  };
  for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
    it.second.print(os << '#');
    os << " = " << it.first << newLine;
  }
  for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
    it.second.print(os << '!');
    os << " = " << it.first << newLine;
  }
}

//===----------------------------------------------------------------------===//
// SSANameState
//===----------------------------------------------------------------------===//

namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
struct BlockInfo {
  int ordering;
  StringRef name;
};

/// This class manages the state of SSA value names.
class SSANameState {
public:
  /// A sentinel value used for values with names set.
  enum : unsigned { NameSentinel = ~0U };

  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Print the operation identifier.
  void printOperationID(Operation *op, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// When printing users of values, an operation without a result might
  /// be the user. This map holds ids for such operations.
  DenseMap<Operation *, unsigned> operationIDs;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
} // namespace

SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
951 auto *topLevelNamesScope = 952 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); 953 954 SmallVector<NamingContext, 8> nameContext; 955 for (Region ®ion : op->getRegions()) 956 nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, 957 nextConflictID, topLevelNamesScope)); 958 959 numberValuesInOp(*op); 960 961 while (!nameContext.empty()) { 962 Region *region; 963 UsedNamesScopeTy *parentScope; 964 std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) = 965 nameContext.pop_back_val(); 966 967 // When we switch from one subtree to another, pop the scopes(needless) 968 // until the parent scope. 969 while (usedNames.getCurScope() != parentScope) { 970 usedNames.getCurScope()->~UsedNamesScopeTy(); 971 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && 972 "top level parentScope must be a nullptr"); 973 } 974 975 // Add a scope for the current region. 976 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) 977 UsedNamesScopeTy(usedNames); 978 979 numberValuesInRegion(*region); 980 981 for (Operation &op : region->getOps()) 982 for (Region ®ion : op.getRegions()) 983 nameContext.push_back(std::make_tuple(®ion, nextValueID, 984 nextArgumentID, nextConflictID, 985 curNamesScope)); 986 } 987 988 // Manually remove all the scopes. 989 while (usedNames.getCurScope() != nullptr) 990 usedNames.getCurScope()->~UsedNamesScopeTy(); 991 } 992 993 void SSANameState::printValueID(Value value, bool printResultNo, 994 raw_ostream &stream) const { 995 if (!value) { 996 stream << "<<NULL VALUE>>"; 997 return; 998 } 999 1000 Optional<int> resultNo; 1001 auto lookupValue = value; 1002 1003 // If this is an operation result, collect the head lookup value of the result 1004 // group and the result number of 'result' within that group. 
1005 if (OpResult result = value.dyn_cast<OpResult>()) 1006 getResultIDAndNumber(result, lookupValue, resultNo); 1007 1008 auto it = valueIDs.find(lookupValue); 1009 if (it == valueIDs.end()) { 1010 stream << "<<UNKNOWN SSA VALUE>>"; 1011 return; 1012 } 1013 1014 stream << '%'; 1015 if (it->second != NameSentinel) { 1016 stream << it->second; 1017 } else { 1018 auto nameIt = valueNames.find(lookupValue); 1019 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 1020 stream << nameIt->second; 1021 } 1022 1023 if (resultNo && printResultNo) 1024 stream << '#' << resultNo; 1025 } 1026 1027 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { 1028 auto it = operationIDs.find(op); 1029 if (it == operationIDs.end()) { 1030 stream << "<<UNKOWN OPERATION>>"; 1031 } else { 1032 stream << '%' << it->second; 1033 } 1034 } 1035 1036 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 1037 auto it = opResultGroups.find(op); 1038 return it == opResultGroups.end() ? ArrayRef<int>() : it->second; 1039 } 1040 1041 BlockInfo SSANameState::getBlockInfo(Block *block) { 1042 auto it = blockNames.find(block); 1043 BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; 1044 return it != blockNames.end() ? 
it->second : invalidBlock; 1045 } 1046 1047 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 1048 assert(!region.empty() && "cannot shadow arguments of an empty region"); 1049 assert(region.getNumArguments() == namesToUse.size() && 1050 "incorrect number of names passed in"); 1051 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 1052 "only KnownIsolatedFromAbove ops can shadow names"); 1053 1054 SmallVector<char, 16> nameStr; 1055 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 1056 auto nameToUse = namesToUse[i]; 1057 if (nameToUse == nullptr) 1058 continue; 1059 auto nameToReplace = region.getArgument(i); 1060 1061 nameStr.clear(); 1062 llvm::raw_svector_ostream nameStream(nameStr); 1063 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 1064 1065 // Entry block arguments should already have a pretty "arg" name. 1066 assert(valueIDs[nameToReplace] == NameSentinel); 1067 1068 // Use the name without the leading %. 1069 auto name = StringRef(nameStream.str()).drop_front(); 1070 1071 // Overwrite the name. 1072 valueNames[nameToReplace] = name.copy(usedNameAllocator); 1073 } 1074 } 1075 1076 void SSANameState::numberValuesInRegion(Region ®ion) { 1077 auto setBlockArgNameFn = [&](Value arg, StringRef name) { 1078 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 1079 assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion && 1080 "arg not defined in current region"); 1081 setValueName(arg, name); 1082 }; 1083 1084 if (!printerFlags.shouldPrintGenericOpForm()) { 1085 if (Operation *op = region.getParentOp()) { 1086 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) 1087 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); 1088 } 1089 } 1090 1091 // Number the values within this region in a breadth-first order. 
1092 unsigned nextBlockID = 0; 1093 for (auto &block : region) { 1094 // Each block gets a unique ID, and all of the operations within it get 1095 // numbered as well. 1096 auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); 1097 if (blockInfoIt.second) { 1098 // This block hasn't been named through `getAsmBlockArgumentNames`, use 1099 // default `^bbNNN` format. 1100 std::string name; 1101 llvm::raw_string_ostream(name) << "^bb" << nextBlockID; 1102 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); 1103 } 1104 blockInfoIt.first->second.ordering = nextBlockID++; 1105 1106 numberValuesInBlock(block); 1107 } 1108 } 1109 1110 void SSANameState::numberValuesInBlock(Block &block) { 1111 // Number the block arguments. We give entry block arguments a special name 1112 // 'arg'. 1113 bool isEntryBlock = block.isEntryBlock(); 1114 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); 1115 llvm::raw_svector_ostream specialName(specialNameBuffer); 1116 for (auto arg : block.getArguments()) { 1117 if (valueIDs.count(arg)) 1118 continue; 1119 if (isEntryBlock) { 1120 specialNameBuffer.resize(strlen("arg")); 1121 specialName << nextArgumentID++; 1122 } 1123 setValueName(arg, specialName.str()); 1124 } 1125 1126 // Number the operations in this block. 1127 for (auto &op : block) 1128 numberValuesInOp(op); 1129 } 1130 1131 void SSANameState::numberValuesInOp(Operation &op) { 1132 // Function used to set the special result names for the operation. 1133 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1134 auto setResultNameFn = [&](Value result, StringRef name) { 1135 assert(!valueIDs.count(result) && "result numbered multiple times"); 1136 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1137 setValueName(result, name); 1138 1139 // Record the result number for groups not anchored at 0. 
1140 if (int resultNo = result.cast<OpResult>().getResultNumber()) 1141 resultGroups.push_back(resultNo); 1142 }; 1143 // Operations can customize the printing of block names in OpAsmOpInterface. 1144 auto setBlockNameFn = [&](Block *block, StringRef name) { 1145 assert(block->getParentOp() == &op && 1146 "getAsmBlockArgumentNames callback invoked on a block not directly " 1147 "nested under the current operation"); 1148 assert(!blockNames.count(block) && "block numbered multiple times"); 1149 SmallString<16> tmpBuffer{"^"}; 1150 name = sanitizeIdentifier(name, tmpBuffer); 1151 if (name.data() != tmpBuffer.data()) { 1152 tmpBuffer.append(name); 1153 name = tmpBuffer.str(); 1154 } 1155 name = name.copy(usedNameAllocator); 1156 blockNames[block] = {-1, name}; 1157 }; 1158 1159 if (!printerFlags.shouldPrintGenericOpForm()) { 1160 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { 1161 asmInterface.getAsmBlockNames(setBlockNameFn); 1162 asmInterface.getAsmResultNames(setResultNameFn); 1163 } 1164 } 1165 1166 unsigned numResults = op.getNumResults(); 1167 if (numResults == 0) { 1168 // If value users should be printed, operations with no result need an id. 1169 if (printerFlags.shouldPrintValueUsers()) { 1170 if (operationIDs.try_emplace(&op, nextValueID).second) 1171 ++nextValueID; 1172 } 1173 return; 1174 } 1175 Value resultBegin = op.getResult(0); 1176 1177 // If the first result wasn't numbered, give it a default number. 1178 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1179 ++nextValueID; 1180 1181 // If this operation has multiple result groups, mark it. 
1182 if (resultGroups.size() != 1) { 1183 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1184 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1185 } 1186 } 1187 1188 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1189 Optional<int> &lookupResultNo) const { 1190 Operation *owner = result.getOwner(); 1191 if (owner->getNumResults() == 1) 1192 return; 1193 int resultNo = result.getResultNumber(); 1194 1195 // If this operation has multiple result groups, we will need to find the 1196 // one corresponding to this result. 1197 auto resultGroupIt = opResultGroups.find(owner); 1198 if (resultGroupIt == opResultGroups.end()) { 1199 // If not, just use the first result. 1200 lookupResultNo = resultNo; 1201 lookupValue = owner->getResult(0); 1202 return; 1203 } 1204 1205 // Find the correct index using a binary search, as the groups are ordered. 1206 ArrayRef<int> resultGroups = resultGroupIt->second; 1207 const auto *it = llvm::upper_bound(resultGroups, resultNo); 1208 int groupResultNo = 0, groupSize = 0; 1209 1210 // If there are no smaller elements, the last result group is the lookup. 1211 if (it == resultGroups.end()) { 1212 groupResultNo = resultGroups.back(); 1213 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1214 } else { 1215 // Otherwise, the previous element is the lookup. 1216 groupResultNo = *std::prev(it); 1217 groupSize = *it - groupResultNo; 1218 } 1219 1220 // We only record the result number for a group of size greater than 1. 1221 if (groupSize != 1) 1222 lookupResultNo = resultNo - groupResultNo; 1223 lookupValue = owner->getResult(groupResultNo); 1224 } 1225 1226 void SSANameState::setValueName(Value value, StringRef name) { 1227 // If the name is empty, the value uses the default numbering. 
1228 if (name.empty()) { 1229 valueIDs[value] = nextValueID++; 1230 return; 1231 } 1232 1233 valueIDs[value] = NameSentinel; 1234 valueNames[value] = uniqueValueName(name); 1235 } 1236 1237 StringRef SSANameState::uniqueValueName(StringRef name) { 1238 SmallString<16> tmpBuffer; 1239 name = sanitizeIdentifier(name, tmpBuffer); 1240 1241 // Check to see if this name is already unique. 1242 if (!usedNames.count(name)) { 1243 name = name.copy(usedNameAllocator); 1244 } else { 1245 // Otherwise, we had a conflict - probe until we find a unique name. This 1246 // is guaranteed to terminate (and usually in a single iteration) because it 1247 // generates new names by incrementing nextConflictID. 1248 SmallString<64> probeName(name); 1249 probeName.push_back('_'); 1250 while (true) { 1251 probeName += llvm::utostr(nextConflictID++); 1252 if (!usedNames.count(probeName)) { 1253 name = probeName.str().copy(usedNameAllocator); 1254 break; 1255 } 1256 probeName.resize(name.size() + 1); 1257 } 1258 } 1259 1260 usedNames.insert(name, char()); 1261 return name; 1262 } 1263 1264 //===----------------------------------------------------------------------===// 1265 // Resources 1266 //===----------------------------------------------------------------------===// 1267 1268 AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; 1269 AsmResourceBuilder::~AsmResourceBuilder() = default; 1270 AsmResourceParser::~AsmResourceParser() = default; 1271 AsmResourcePrinter::~AsmResourcePrinter() = default; 1272 1273 //===----------------------------------------------------------------------===// 1274 // AsmState 1275 //===----------------------------------------------------------------------===// 1276 1277 namespace mlir { 1278 namespace detail { 1279 class AsmStateImpl { 1280 public: 1281 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, 1282 AsmState::LocationMap *locationMap) 1283 : interfaces(op->getContext()), nameState(op, printerFlags), 1284 
printerFlags(printerFlags), locationMap(locationMap) {} 1285 1286 /// Initialize the alias state to enable the printing of aliases. 1287 void initializeAliases(Operation *op) { 1288 aliasState.initialize(op, printerFlags, interfaces); 1289 } 1290 1291 /// Get the state used for aliases. 1292 AliasState &getAliasState() { return aliasState; } 1293 1294 /// Get the state used for SSA names. 1295 SSANameState &getSSANameState() { return nameState; } 1296 1297 /// Return the dialects within the context that implement 1298 /// OpAsmDialectInterface. 1299 DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { 1300 return interfaces; 1301 } 1302 1303 /// Return the non-dialect resource printers. 1304 auto getResourcePrinters() { 1305 return llvm::make_pointee_range(externalResourcePrinters); 1306 } 1307 1308 /// Get the printer flags. 1309 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } 1310 1311 /// Register the location, line and column, within the buffer that the given 1312 /// operation was printed at. 1313 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1314 if (locationMap) 1315 (*locationMap)[op] = std::make_pair(line, col); 1316 } 1317 1318 private: 1319 /// Collection of OpAsm interfaces implemented in the context. 1320 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1321 1322 /// A collection of non-dialect resource printers. 1323 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; 1324 1325 /// The state used for attribute and type aliases. 1326 AliasState aliasState; 1327 1328 /// The state used for SSA value names. 1329 SSANameState nameState; 1330 1331 /// Flags that control op output. 1332 OpPrintingFlags printerFlags; 1333 1334 /// An optional location map to be populated. 1335 AsmState::LocationMap *locationMap; 1336 1337 // Allow direct access to the impl fields. 
1338 friend AsmState; 1339 }; 1340 } // namespace detail 1341 } // namespace mlir 1342 1343 /// Verifies the operation and switches to generic op printing if verification 1344 /// fails. We need to do this because custom print functions may fail for 1345 /// invalid ops. 1346 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, 1347 OpPrintingFlags printerFlags) { 1348 if (printerFlags.shouldPrintGenericOpForm() || 1349 printerFlags.shouldAssumeVerified()) 1350 return printerFlags; 1351 1352 LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: " 1353 << op->getName() << "\n"); 1354 1355 // Ignore errors emitted by the verifier. We check the thread id to avoid 1356 // consuming other threads' errors. 1357 auto parentThreadId = llvm::get_threadid(); 1358 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { 1359 if (parentThreadId == llvm::get_threadid()) { 1360 LLVM_DEBUG({ 1361 diag.print(llvm::dbgs()); 1362 llvm::dbgs() << "\n"; 1363 }); 1364 return success(); 1365 } 1366 return failure(); 1367 }); 1368 if (failed(verify(op))) { 1369 LLVM_DEBUG(llvm::dbgs() 1370 << DEBUG_TYPE << ": '" << op->getName() 1371 << "' failed to verify and will be printed in generic form\n"); 1372 printerFlags.printGenericOpForm(); 1373 } 1374 1375 return printerFlags; 1376 } 1377 1378 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, 1379 LocationMap *locationMap) 1380 : impl(std::make_unique<AsmStateImpl>( 1381 op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} 1382 AsmState::~AsmState() = default; 1383 1384 const OpPrintingFlags &AsmState::getPrinterFlags() const { 1385 return impl->getPrinterFlags(); 1386 } 1387 1388 void AsmState::attachResourcePrinter( 1389 std::unique_ptr<AsmResourcePrinter> printer) { 1390 impl->externalResourcePrinters.emplace_back(std::move(printer)); 1391 } 1392 1393 //===----------------------------------------------------------------------===// 1394 // AsmPrinter::Impl 1395 
//===----------------------------------------------------------------------===// 1396 1397 namespace mlir { 1398 class AsmPrinter::Impl { 1399 public: 1400 Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1401 AsmStateImpl *state = nullptr) 1402 : os(os), printerFlags(flags), state(state) {} 1403 explicit Impl(Impl &other) 1404 : Impl(other.os, other.printerFlags, other.state) {} 1405 1406 /// Returns the output stream of the printer. 1407 raw_ostream &getStream() { return os; } 1408 1409 template <typename Container, typename UnaryFunctor> 1410 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { 1411 llvm::interleaveComma(c, os, eachFn); 1412 } 1413 1414 /// This enum describes the different kinds of elision for the type of an 1415 /// attribute when printing it. 1416 enum class AttrTypeElision { 1417 /// The type must not be elided, 1418 Never, 1419 /// The type may be elided when it matches the default used in the parser 1420 /// (for example i64 is the default for integer attributes). 1421 May, 1422 /// The type must be elided. 1423 Must 1424 }; 1425 1426 /// Print the given attribute. 1427 void printAttribute(Attribute attr, 1428 AttrTypeElision typeElision = AttrTypeElision::Never); 1429 1430 /// Print the alias for the given attribute, return failure if no alias could 1431 /// be printed. 1432 LogicalResult printAlias(Attribute attr); 1433 1434 void printType(Type type); 1435 1436 /// Print the alias for the given type, return failure if no alias could 1437 /// be printed. 1438 LogicalResult printAlias(Type type); 1439 1440 /// Print the given location to the stream. If `allowAlias` is true, this 1441 /// allows for the internal location to use an attribute alias. 1442 void printLocation(LocationAttr loc, bool allowAlias = false); 1443 1444 /// Print a reference to the given resource that is owned by the given 1445 /// dialect. 
1446 void printResourceHandle(const AsmDialectResourceHandle &resource) { 1447 auto *interface = cast<OpAsmDialectInterface>(resource.getDialect()); 1448 os << interface->getResourceKey(resource); 1449 dialectResources[resource.getDialect()].insert(resource); 1450 } 1451 1452 void printAffineMap(AffineMap map); 1453 void 1454 printAffineExpr(AffineExpr expr, 1455 function_ref<void(unsigned, bool)> printValueName = nullptr); 1456 void printAffineConstraint(AffineExpr expr, bool isEq); 1457 void printIntegerSet(IntegerSet set); 1458 1459 protected: 1460 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1461 ArrayRef<StringRef> elidedAttrs = {}, 1462 bool withKeyword = false); 1463 void printNamedAttribute(NamedAttribute attr); 1464 void printTrailingLocation(Location loc, bool allowAlias = true); 1465 void printLocationInternal(LocationAttr loc, bool pretty = false); 1466 1467 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1468 /// used instead of individual elements when the elements attr is large. 1469 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1470 1471 /// Print a dense string elements attribute. 1472 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1473 1474 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1475 /// used instead of individual elements when the elements attr is large. 1476 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1477 bool allowHex); 1478 1479 void printDialectAttribute(Attribute attr); 1480 void printDialectType(Type type); 1481 1482 /// Print an escaped string, wrapped with "". 1483 void printEscapedString(StringRef str); 1484 1485 /// Print a hex string, wrapped with "". 
1486 void printHexString(StringRef str); 1487 void printHexString(ArrayRef<char> data); 1488 1489 /// This enum is used to represent the binding strength of the enclosing 1490 /// context that an AffineExprStorage is being printed in, so we can 1491 /// intelligently produce parens. 1492 enum class BindingStrength { 1493 Weak, // + and - 1494 Strong, // All other binary operators. 1495 }; 1496 void printAffineExprInternal( 1497 AffineExpr expr, BindingStrength enclosingTightness, 1498 function_ref<void(unsigned, bool)> printValueName = nullptr); 1499 1500 /// The output stream for the printer. 1501 raw_ostream &os; 1502 1503 /// A set of flags to control the printer's behavior. 1504 OpPrintingFlags printerFlags; 1505 1506 /// An optional printer state for the module. 1507 AsmStateImpl *state; 1508 1509 /// A tracker for the number of new lines emitted during printing. 1510 NewLineCounter newLine; 1511 1512 /// A set of dialect resources that were referenced during printing. 1513 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; 1514 }; 1515 } // namespace mlir 1516 1517 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { 1518 // Check to see if we are printing debug information. 
1519 if (!printerFlags.shouldPrintDebugInfo()) 1520 return; 1521 1522 os << " "; 1523 printLocation(loc, /*allowAlias=*/allowAlias); 1524 } 1525 1526 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) { 1527 TypeSwitch<LocationAttr>(loc) 1528 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1529 printLocationInternal(loc.getFallbackLocation(), pretty); 1530 }) 1531 .Case<UnknownLoc>([&](UnknownLoc loc) { 1532 if (pretty) 1533 os << "[unknown]"; 1534 else 1535 os << "unknown"; 1536 }) 1537 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1538 if (pretty) 1539 os << loc.getFilename().getValue(); 1540 else 1541 printEscapedString(loc.getFilename()); 1542 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1543 }) 1544 .Case<NameLoc>([&](NameLoc loc) { 1545 printEscapedString(loc.getName()); 1546 1547 // Print the child if it isn't unknown. 1548 auto childLoc = loc.getChildLoc(); 1549 if (!childLoc.isa<UnknownLoc>()) { 1550 os << '('; 1551 printLocationInternal(childLoc, pretty); 1552 os << ')'; 1553 } 1554 }) 1555 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1556 Location caller = loc.getCaller(); 1557 Location callee = loc.getCallee(); 1558 if (!pretty) 1559 os << "callsite("; 1560 printLocationInternal(callee, pretty); 1561 if (pretty) { 1562 if (callee.isa<NameLoc>()) { 1563 if (caller.isa<FileLineColLoc>()) { 1564 os << " at "; 1565 } else { 1566 os << newLine << " at "; 1567 } 1568 } else { 1569 os << newLine << " at "; 1570 } 1571 } else { 1572 os << " at "; 1573 } 1574 printLocationInternal(caller, pretty); 1575 if (!pretty) 1576 os << ")"; 1577 }) 1578 .Case<FusedLoc>([&](FusedLoc loc) { 1579 if (!pretty) 1580 os << "fused"; 1581 if (Attribute metadata = loc.getMetadata()) 1582 os << '<' << metadata << '>'; 1583 os << '['; 1584 interleave( 1585 loc.getLocations(), 1586 [&](Location loc) { printLocationInternal(loc, pretty); }, 1587 [&]() { os << ", "; }); 1588 os << ']'; 1589 }); 1590 } 1591 1592 /// Print a floating point value in a way that the 
parser will be able to 1593 /// round-trip losslessly. 1594 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1595 // We would like to output the FP constant value in exponential notation, 1596 // but we cannot do this if doing so will lose precision. Check here to 1597 // make sure that we only output it in exponential format if we can parse 1598 // the value back and get the same value. 1599 bool isInf = apValue.isInfinity(); 1600 bool isNaN = apValue.isNaN(); 1601 if (!isInf && !isNaN) { 1602 SmallString<128> strValue; 1603 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1604 /*TruncateZero=*/false); 1605 1606 // Check to make sure that the stringized number is not some string like 1607 // "Inf" or NaN, that atof will accept, but the lexer will not. Check 1608 // that the string matches the "[-+]?[0-9]" regex. 1609 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1610 ((strValue[0] == '-' || strValue[0] == '+') && 1611 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1612 "[-+]?[0-9] regex does not match!"); 1613 1614 // Parse back the stringized version and check that the value is equal 1615 // (i.e., there is no precision loss). 1616 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1617 os << strValue; 1618 return; 1619 } 1620 1621 // If it is not, use the default format of APFloat instead of the 1622 // exponential notation. 1623 strValue.clear(); 1624 apValue.toString(strValue); 1625 1626 // Make sure that we can parse the default form as a float. 1627 if (strValue.str().contains('.')) { 1628 os << strValue; 1629 return; 1630 } 1631 } 1632 1633 // Print special values in hexadecimal format. The sign bit should be included 1634 // in the literal. 
1635 SmallVector<char, 16> str; 1636 APInt apInt = apValue.bitcastToAPInt(); 1637 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1638 /*formatAsCLiteral=*/true); 1639 os << str; 1640 } 1641 1642 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { 1643 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1644 return printLocationInternal(loc, /*pretty=*/true); 1645 1646 os << "loc("; 1647 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1648 printLocationInternal(loc); 1649 os << ')'; 1650 } 1651 1652 /// Returns true if the given dialect symbol data is simple enough to print in 1653 /// the pretty form. This is essentially when the symbol takes the form: 1654 /// identifier (`<` body `>`)? 1655 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1656 // The name must start with an identifier. 1657 if (symName.empty() || !isalpha(symName.front())) 1658 return false; 1659 1660 // Ignore all the characters that are valid in an identifier in the symbol 1661 // name. 1662 symName = symName.drop_while( 1663 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1664 if (symName.empty()) 1665 return true; 1666 1667 // If we got to an unexpected character, then it must be a <>. Check that the 1668 // rest of the symbol is wrapped within <>. 1669 return symName.front() == '<' && symName.back() == '>'; 1670 } 1671 1672 /// Print the given dialect symbol to the stream. 1673 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 1674 StringRef dialectName, StringRef symString) { 1675 os << symPrefix << dialectName; 1676 1677 // If this symbol name is simple enough, print it directly in pretty form, 1678 // otherwise, we print it as an escaped string. 1679 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 1680 os << '.' 
<< symString; 1681 return; 1682 } 1683 1684 os << '<' << symString << '>'; 1685 } 1686 1687 /// Returns true if the given string can be represented as a bare identifier. 1688 static bool isBareIdentifier(StringRef name) { 1689 // By making this unsigned, the value passed in to isalnum will always be 1690 // in the range 0-255. This is important when building with MSVC because 1691 // its implementation will assert. This situation can arise when dealing 1692 // with UTF-8 multibyte characters. 1693 if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) 1694 return false; 1695 return llvm::all_of(name.drop_front(), [](unsigned char c) { 1696 return isalnum(c) || c == '_' || c == '$' || c == '.'; 1697 }); 1698 } 1699 1700 /// Print the given string as a keyword, or a quoted and escaped string if it 1701 /// has any special or non-printable characters in it. 1702 static void printKeywordOrString(StringRef keyword, raw_ostream &os) { 1703 // If it can be represented as a bare identifier, write it directly. 1704 if (isBareIdentifier(keyword)) { 1705 os << keyword; 1706 return; 1707 } 1708 1709 // Otherwise, output the keyword wrapped in quotes with proper escaping. 1710 os << "\""; 1711 printEscapedString(keyword, os); 1712 os << '"'; 1713 } 1714 1715 /// Print the given string as a symbol reference. A symbol reference is 1716 /// represented as a string prefixed with '@'. The reference is surrounded with 1717 /// ""'s and escaped if it has any special or non-printable characters in it. 1718 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { 1719 assert(!symbolRef.empty() && "expected valid symbol reference"); 1720 os << '@'; 1721 printKeywordOrString(symbolRef, os); 1722 } 1723 1724 // Print out a valid ElementsAttr that is succinct and can represent any 1725 // potential shape/type, for use when eliding a large ElementsAttr. 
//
// We choose to use an opaque ElementsAttr literal with conspicuous content to
// hopefully alert readers to the fact that this has been elided.
//
// Unfortunately, neither of the strings of an opaque ElementsAttr literal will
// accept the string "elided". The first string must be a registered dialect
// name and the latter must be a hex constant.
static void printElidedElementsAttr(raw_ostream &os) {
  os << R"(opaque<"elided_large_const", "0xDEADBEEF">)";
}

/// Try to print an alias for `attr` through the printer state's alias state.
/// Returns failure if this printer has no state or no alias was printed.
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
  return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
}

/// Try to print an alias for `type` through the printer state's alias state.
/// Returns failure if this printer has no state or no alias was printed.
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
  return success(state && succeeded(state->getAliasState().getAlias(type, os)));
}

/// Print `attr` in its textual form.
/// - A null attribute prints as `<<NULL ATTRIBUTE>>`.
/// - An alias is printed instead, when one is available.
/// - Non-builtin attributes are delegated to their dialect.
///
/// `typeElision` controls the trailing ` : type` suffix:
/// - `Must` never prints it.
/// - `May` lets signless i64 integers and f64 floats omit it.
/// - Some kinds never print it: unit, i1 booleans, affine maps/sets, and
///   dense arrays.
/// - A NoneType attribute type is never printed.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (succeeded(printAlias(attr)))
    return;

  auto attrType = attr.getType();
  if (!isa<BuiltinDialect>(attr.getDialect())) {
    printDialectAttribute(attr);
  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values. Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE: signless 1-bit values already returned above, so in practice only
    // explicitly unsigned types take the unsigned path here.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    printEscapedString(strAttr.getValue());

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Root reference first, then each nested reference separated by "::".
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", ";
      printHexString(opaqueAttr.getValue());
      os << ">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide the whole sparse attribute if either component is too large.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, the body is left empty: `sparse<>`.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }
  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
    // The element type is spelled inside the brackets (`[:i32 ...]`), so the
    // trailing ` : type` suffix is always suppressed.
    typeElision = AttrTypeElision::Must;
    switch (denseArrayAttr.getElementType()) {
    case DenseArrayBaseAttr::EltType::I8:
      os << "[:i8";
      break;
    case DenseArrayBaseAttr::EltType::I16:
      os << "[:i16";
      break;
    case DenseArrayBaseAttr::EltType::I32:
      os << "[:i32";
      break;
    case DenseArrayBaseAttr::EltType::I64:
      os << "[:i64";
      break;
    case DenseArrayBaseAttr::EltType::F32:
      os << "[:f32";
      break;
    case DenseArrayBaseAttr::EltType::F64:
      os << "[:f64";
      break;
    }
    // NOTE(review): this tests the rank, not the element count; confirm
    // whether the separator is meant to be skipped for empty arrays.
    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
      os << " ";
    denseArrayAttr.printWithoutBraces(os);
    os << "]";
  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  } else {
    llvm::report_fatal_error("Unknown builtin attribute");
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}

/// Print the integer element of a DenseElementsAttr. 1-bit values print as
/// `true`/`false`; wider values print through APInt::print with the given
/// signedness.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 bool isSigned) {
  if (value.getBitWidth() == 1)
    os << (value.getBoolValue() ? "true" : "false");
  else
    value.print(os, isSigned);
}

/// Print the elements of a dense attribute with the given shaped `type`, as
/// comma-separated, bracket-nested lists that follow the shape.
/// `printEltFn(i)` is called for the i-th element in row-major order (the last
/// dimension varies fastest).
/// - A splat prints only element 0, with no brackets.
/// - A type with zero elements prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  // NOTE(review): only `isSplat` is checked. A non-splat rank-0 type would
  // index counter[-1] below. This presumably relies on 0-d attributes always
  // reporting isSplat -- confirm.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed since the previous element.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close the outermost brackets that are still open.
  while (openBrackets-- > 0)
    os << ']';
}

/// Print the body of a dense elements attribute (without the `dense<...>`
/// wrapper). String attributes are dispatched to the string printer. All
/// others are treated as int/float attributes; `allowHex` applies only to
/// those.
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
                                              bool allowHex) {
  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
    return printDenseStringElementsAttr(stringAttr);

  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
                                allowHex);
}

/// Print the body of an int/float (or complex) dense attribute without the
/// `dense<...>` wrapper. The raw data is printed as a single hex string when
/// all of the following hold; otherwise each element is printed individually:
/// - `allowHex` is set,
/// - the attribute is not a splat,
/// - shouldPrintElementsAttrWithHex accepts the element count.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(convRawData);
    } else {
      printHexString(rawData);
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    // Complex elements print as `(real,imag)`.
    if (complexElementType.isa<IntegerType>()) {
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    // Everything except explicitly unsigned integers prints as signed.
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}

/// Print the body of a dense string attribute. Each element is quoted and
/// escaped, and elements are nested according to the attribute's shape.
void AsmPrinter::Impl::printDenseStringElementsAttr(
    DenseStringElementsAttr attr) {
  ArrayRef<StringRef> data = attr.getRawStringData();
  auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}

/// Print `type`, in order of preference:
/// - `<<NULL TYPE>>` for a null type,
/// - an alias when one is available,
/// - builtin syntax for builtin types,
/// - otherwise the owning dialect's printer (via printDialectType).
void AsmPrinter::Impl::printType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Try to print an alias for this type.
  if (state && succeeded(state->getAliasState().getAlias(type, os)))
    return;

  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // Signedness prefix: `s` for signed, `u` for unsigned, none for
        // signless.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        // A single non-function result is printed without parentheses.
        ArrayRef<Type> results = funcTy.getResults();
        if (results.size() == 1 && !results[0].isa<FunctionType>()) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        // Fixed leading dimensions print as `NxMx...`. The trailing scalable
        // dimensions are grouped inside `[...]x`.
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
        unsigned dimIdx = 0;
        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
          os << vShape[dimIdx] << 'x';
        if (vectorTy.isScalable()) {
          os << '[';
          unsigned secondToLastDim = lastDim - 1;
          for (; dimIdx < secondToLastDim; dimIdx++)
            os << vShape[dimIdx] << 'x';
          os << vShape[dimIdx] << "]x";
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        for (int64_t dim : tensorTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        for (int64_t dim : memrefTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(memrefTy.getElementType());
        // Identity layouts are implicit and not printed.
        if (!memrefTy.getLayout().isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      .Default([&](Type type) { return printDialectType(type); });
}

/// Print ` {name = value, ...}` for the entries of `attrs` whose names are not
/// listed in `elidedAttrs`. When `withKeyword` is true, the dictionary is
/// preceded by ` attributes`. Nothing is printed if no attribute remains.
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                             ArrayRef<StringRef> elidedAttrs,
                                             bool withKeyword) {
  // If there are no attributes, then there is nothing to be done.
  if (attrs.empty())
    return;

  // Functor used to print a filtered attribute list.
  auto printFilteredAttributesFn = [&](auto filteredAttrs) {
    // Print the 'attributes' keyword if necessary.
    if (withKeyword)
      os << " attributes";

    // Print the attributes themselves in braces.
    os << " {";
    interleaveComma(filteredAttrs,
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';
  };

  // If no attributes are elided, we can directly print with no filtering.
  if (elidedAttrs.empty())
    return printFilteredAttributesFn(attrs);

  // Otherwise, filter out any attributes that shouldn't be included.
  llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
                                                elidedAttrs.end());
  auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
    return !elidedAttrsSet.contains(attr.getName().strref());
  });
  if (!filteredAttrs.empty())
    printFilteredAttributesFn(filteredAttrs);
}

/// Print a single dictionary entry as `name = value`, or as just `name` when
/// the value is a UnitAttr.
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
  // Print the name without quotes if possible.
  ::printKeywordOrString(attr.getName().strref(), os);

  // Pretty printing elides the attribute value for unit attributes.
  if (attr.getValue().isa<UnitAttr>())
    return;

  os << " = ";
  printAttribute(attr.getValue());
}

/// Print a non-builtin attribute. The owning dialect prints it into a
/// temporary string through a nested printer. The result is then emitted via
/// printDialectSymbol with the `#` prefix.
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
  auto &dialect = attr.getDialect();

  // Ask the dialect to serialize the attribute to a string. The nested scope
  // ends the sub-printer and its stream before `attrName` is used below.
  std::string attrName;
  {
    llvm::raw_string_ostream attrNameStr(attrName);
    Impl subPrinter(attrNameStr, printerFlags, state);
    DialectAsmPrinter printer(subPrinter);
    dialect.printAttribute(attr, printer);

    // Propagate dialect resources referenced by the nested printer.
    // FIXME: Delete this when we no longer require a nested printer.
    for (auto &it : subPrinter.dialectResources)
      for (const auto &resource : it.second)
        dialectResources[it.first].insert(resource);
  }
  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
}

/// Print a non-builtin type. The owning dialect prints it into a temporary
/// string through a nested printer. The result is then emitted via
/// printDialectSymbol with the `!` prefix.
void AsmPrinter::Impl::printDialectType(Type type) {
  auto &dialect = type.getDialect();

  // Ask the dialect to serialize the type to a string. The nested scope ends
  // the sub-printer and its stream before `typeName` is used below.
  std::string typeName;
  {
    llvm::raw_string_ostream typeNameStr(typeName);
    Impl subPrinter(typeNameStr, printerFlags, state);
    DialectAsmPrinter printer(subPrinter);
    dialect.printType(type, printer);

    // Propagate dialect resources referenced by the nested printer.
    // FIXME: Delete this when we no longer require a nested printer.
    for (auto &it : subPrinter.dialectResources)
      for (const auto &resource : it.second)
        dialectResources[it.first].insert(resource);
  }
  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
}

/// Print `str` wrapped in double quotes, escaped by llvm::printEscapedString.
void AsmPrinter::Impl::printEscapedString(StringRef str) {
  os << "\"";
  llvm::printEscapedString(str, os);
  os << "\"";
}

/// Print the bytes of `str` as a quoted hex literal: `"0x<hex digits>"`.
void AsmPrinter::Impl::printHexString(StringRef str) {
  os << "\"0x" << llvm::toHex(str) << "\"";
}
/// Print the bytes of `data` as a quoted hex literal: `"0x<hex digits>"`.
void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
  printHexString(StringRef(data.data(), data.size()));
}

//===--------------------------------------------------------------------===//
// AsmPrinter
//===--------------------------------------------------------------------===//

AsmPrinter::~AsmPrinter() = default;

/// Return the stream of the underlying implementation. An AsmPrinter created
/// without an `impl` is expected to override this method.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  return impl->getStream();
}

/// Print the given floating point value in a stabilized form.
2292 void AsmPrinter::printFloat(const APFloat &value) { 2293 assert(impl && "expected AsmPrinter::printFloat to be overriden"); 2294 printFloatValue(value, impl->getStream()); 2295 } 2296 2297 void AsmPrinter::printType(Type type) { 2298 assert(impl && "expected AsmPrinter::printType to be overriden"); 2299 impl->printType(type); 2300 } 2301 2302 void AsmPrinter::printAttribute(Attribute attr) { 2303 assert(impl && "expected AsmPrinter::printAttribute to be overriden"); 2304 impl->printAttribute(attr); 2305 } 2306 2307 LogicalResult AsmPrinter::printAlias(Attribute attr) { 2308 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2309 return impl->printAlias(attr); 2310 } 2311 2312 LogicalResult AsmPrinter::printAlias(Type type) { 2313 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2314 return impl->printAlias(type); 2315 } 2316 2317 void AsmPrinter::printAttributeWithoutType(Attribute attr) { 2318 assert(impl && 2319 "expected AsmPrinter::printAttributeWithoutType to be overriden"); 2320 impl->printAttribute(attr, Impl::AttrTypeElision::Must); 2321 } 2322 2323 void AsmPrinter::printKeywordOrString(StringRef keyword) { 2324 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); 2325 ::printKeywordOrString(keyword, impl->getStream()); 2326 } 2327 2328 void AsmPrinter::printSymbolName(StringRef symbolRef) { 2329 assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); 2330 ::printSymbolReference(symbolRef, impl->getStream()); 2331 } 2332 2333 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { 2334 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden"); 2335 impl->printResourceHandle(resource); 2336 } 2337 2338 //===----------------------------------------------------------------------===// 2339 // Affine expressions and maps 2340 //===----------------------------------------------------------------------===// 2341 2342 void 
AsmPrinter::Impl::printAffineExpr( 2343 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2344 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2345 } 2346 2347 void AsmPrinter::Impl::printAffineExprInternal( 2348 AffineExpr expr, BindingStrength enclosingTightness, 2349 function_ref<void(unsigned, bool)> printValueName) { 2350 const char *binopSpelling = nullptr; 2351 switch (expr.getKind()) { 2352 case AffineExprKind::SymbolId: { 2353 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 2354 if (printValueName) 2355 printValueName(pos, /*isSymbol=*/true); 2356 else 2357 os << 's' << pos; 2358 return; 2359 } 2360 case AffineExprKind::DimId: { 2361 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 2362 if (printValueName) 2363 printValueName(pos, /*isSymbol=*/false); 2364 else 2365 os << 'd' << pos; 2366 return; 2367 } 2368 case AffineExprKind::Constant: 2369 os << expr.cast<AffineConstantExpr>().getValue(); 2370 return; 2371 case AffineExprKind::Add: 2372 binopSpelling = " + "; 2373 break; 2374 case AffineExprKind::Mul: 2375 binopSpelling = " * "; 2376 break; 2377 case AffineExprKind::FloorDiv: 2378 binopSpelling = " floordiv "; 2379 break; 2380 case AffineExprKind::CeilDiv: 2381 binopSpelling = " ceildiv "; 2382 break; 2383 case AffineExprKind::Mod: 2384 binopSpelling = " mod "; 2385 break; 2386 } 2387 2388 auto binOp = expr.cast<AffineBinaryOpExpr>(); 2389 AffineExpr lhsExpr = binOp.getLHS(); 2390 AffineExpr rhsExpr = binOp.getRHS(); 2391 2392 // Handle tightly binding binary operators. 2393 if (binOp.getKind() != AffineExprKind::Add) { 2394 if (enclosingTightness == BindingStrength::Strong) 2395 os << '('; 2396 2397 // Pretty print multiplication with -1. 
2398 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2399 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2400 rhsConst.getValue() == -1) { 2401 os << "-"; 2402 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2403 if (enclosingTightness == BindingStrength::Strong) 2404 os << ')'; 2405 return; 2406 } 2407 2408 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2409 2410 os << binopSpelling; 2411 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2412 2413 if (enclosingTightness == BindingStrength::Strong) 2414 os << ')'; 2415 return; 2416 } 2417 2418 // Print out special "pretty" forms for add. 2419 if (enclosingTightness == BindingStrength::Strong) 2420 os << '('; 2421 2422 // Pretty print addition to a product that has a negative operand as a 2423 // subtraction. 2424 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2425 if (rhs.getKind() == AffineExprKind::Mul) { 2426 AffineExpr rrhsExpr = rhs.getRHS(); 2427 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2428 if (rrhs.getValue() == -1) { 2429 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2430 printValueName); 2431 os << " - "; 2432 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2433 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2434 printValueName); 2435 } else { 2436 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2437 printValueName); 2438 } 2439 2440 if (enclosingTightness == BindingStrength::Strong) 2441 os << ')'; 2442 return; 2443 } 2444 2445 if (rrhs.getValue() < -1) { 2446 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2447 printValueName); 2448 os << " - "; 2449 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2450 printValueName); 2451 os << " * " << -rrhs.getValue(); 2452 if (enclosingTightness == BindingStrength::Strong) 2453 os << ')'; 2454 return; 2455 } 2456 } 2457 } 2458 } 2459 2460 // Pretty print addition to a 
negative number as a subtraction. 2461 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2462 if (rhsConst.getValue() < 0) { 2463 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2464 os << " - " << -rhsConst.getValue(); 2465 if (enclosingTightness == BindingStrength::Strong) 2466 os << ')'; 2467 return; 2468 } 2469 } 2470 2471 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2472 2473 os << " + "; 2474 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2475 2476 if (enclosingTightness == BindingStrength::Strong) 2477 os << ')'; 2478 } 2479 2480 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { 2481 printAffineExprInternal(expr, BindingStrength::Weak); 2482 isEq ? os << " == 0" : os << " >= 0"; 2483 } 2484 2485 void AsmPrinter::Impl::printAffineMap(AffineMap map) { 2486 // Dimension identifiers. 2487 os << '('; 2488 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2489 os << 'd' << i << ", "; 2490 if (map.getNumDims() >= 1) 2491 os << 'd' << map.getNumDims() - 1; 2492 os << ')'; 2493 2494 // Symbolic identifiers. 2495 if (map.getNumSymbols() != 0) { 2496 os << '['; 2497 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2498 os << 's' << i << ", "; 2499 if (map.getNumSymbols() >= 1) 2500 os << 's' << map.getNumSymbols() - 1; 2501 os << ']'; 2502 } 2503 2504 // Result affine expressions. 2505 os << " -> ("; 2506 interleaveComma(map.getResults(), 2507 [&](AffineExpr expr) { printAffineExpr(expr); }); 2508 os << ')'; 2509 } 2510 2511 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { 2512 // Dimension identifiers. 2513 os << '('; 2514 for (unsigned i = 1; i < set.getNumDims(); ++i) 2515 os << 'd' << i - 1 << ", "; 2516 if (set.getNumDims() >= 1) 2517 os << 'd' << set.getNumDims() - 1; 2518 os << ')'; 2519 2520 // Symbolic identifiers. 
2521 if (set.getNumSymbols() != 0) { 2522 os << '['; 2523 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2524 os << 's' << i << ", "; 2525 if (set.getNumSymbols() >= 1) 2526 os << 's' << set.getNumSymbols() - 1; 2527 os << ']'; 2528 } 2529 2530 // Print constraints. 2531 os << " : ("; 2532 int numConstraints = set.getNumConstraints(); 2533 for (int i = 1; i < numConstraints; ++i) { 2534 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2535 os << ", "; 2536 } 2537 if (numConstraints >= 1) 2538 printAffineConstraint(set.getConstraint(numConstraints - 1), 2539 set.isEq(numConstraints - 1)); 2540 os << ')'; 2541 } 2542 2543 //===----------------------------------------------------------------------===// 2544 // OperationPrinter 2545 //===----------------------------------------------------------------------===// 2546 2547 namespace { 2548 /// This class contains the logic for printing operations, regions, and blocks. 2549 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { 2550 public: 2551 using Impl = AsmPrinter::Impl; 2552 using Impl::printType; 2553 2554 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) 2555 : Impl(os, state.getPrinterFlags(), &state), 2556 OpAsmPrinter(static_cast<Impl &>(*this)) {} 2557 2558 /// Print the given top-level operation. 2559 void printTopLevelOperation(Operation *op); 2560 2561 /// Print the given operation with its indent and location. 2562 void print(Operation *op); 2563 /// Print the bare location, not including indentation/location/etc. 2564 void printOperation(Operation *op); 2565 /// Print the given operation in the generic form. 2566 void printGenericOp(Operation *op, bool printOpName) override; 2567 2568 /// Print the name of the given block. 2569 void printBlockName(Block *block); 2570 2571 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2572 /// block are not printed. 
If 'printBlockTerminator' is false, the terminator 2573 /// operation of the block is not printed. 2574 void print(Block *block, bool printBlockArgs = true, 2575 bool printBlockTerminator = true); 2576 2577 /// Print the ID of the given value, optionally with its result number. 2578 void printValueID(Value value, bool printResultNo = true, 2579 raw_ostream *streamOverride = nullptr) const; 2580 2581 /// Print the ID of the given operation. 2582 void printOperationID(Operation *op, 2583 raw_ostream *streamOverride = nullptr) const; 2584 2585 //===--------------------------------------------------------------------===// 2586 // OpAsmPrinter methods 2587 //===--------------------------------------------------------------------===// 2588 2589 /// Print a newline and indent the printer to the start of the current 2590 /// operation. 2591 void printNewline() override { 2592 os << newLine; 2593 os.indent(currentIndent); 2594 } 2595 2596 /// Print a block argument in the usual format of: 2597 /// %ssaName : type {attr1=42} loc("here") 2598 /// where location printing is controlled by the standard internal option. 2599 /// You may pass omitType=true to not print a type, and pass an empty 2600 /// attribute list if you don't care for attributes. 2601 void printRegionArgument(BlockArgument arg, 2602 ArrayRef<NamedAttribute> argAttrs = {}, 2603 bool omitType = false) override; 2604 2605 /// Print the ID for the given value. 2606 void printOperand(Value value) override { printValueID(value); } 2607 void printOperand(Value value, raw_ostream &os) override { 2608 printValueID(value, /*printResultNo=*/true, &os); 2609 } 2610 2611 /// Print an optional attribute dictionary with a given set of elided values. 
2612 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2613 ArrayRef<StringRef> elidedAttrs = {}) override { 2614 Impl::printOptionalAttrDict(attrs, elidedAttrs); 2615 } 2616 void printOptionalAttrDictWithKeyword( 2617 ArrayRef<NamedAttribute> attrs, 2618 ArrayRef<StringRef> elidedAttrs = {}) override { 2619 Impl::printOptionalAttrDict(attrs, elidedAttrs, 2620 /*withKeyword=*/true); 2621 } 2622 2623 /// Print the given successor. 2624 void printSuccessor(Block *successor) override; 2625 2626 /// Print an operation successor with the operands used for the block 2627 /// arguments. 2628 void printSuccessorAndUseList(Block *successor, 2629 ValueRange succOperands) override; 2630 2631 /// Print the given region. 2632 void printRegion(Region ®ion, bool printEntryBlockArgs, 2633 bool printBlockTerminators, bool printEmptyBlock) override; 2634 2635 /// Renumber the arguments for the specified region to the same names as the 2636 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2637 /// operations. If any entry in namesToUse is null, the corresponding 2638 /// argument name is left alone. 2639 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2640 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2641 } 2642 2643 /// Print the given affine map with the symbol and dimension operands printed 2644 /// inline with the map. 2645 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2646 ValueRange operands) override; 2647 2648 /// Print the given affine expression with the symbol and dimension operands 2649 /// printed inline with the expression. 2650 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 2651 ValueRange symOperands) override; 2652 2653 /// Print users of this operation or id of this operation if it has no result. 2654 void printUsersComment(Operation *op); 2655 2656 /// Print users of this block arg. 
2657 void printUsersComment(BlockArgument arg); 2658 2659 /// Print the users of a value. 2660 void printValueUsers(Value value); 2661 2662 /// Print either the ids of the result values or the id of the operation if 2663 /// the operation has no results. 2664 void printUserIDs(Operation *user, bool prefixComma = false); 2665 2666 private: 2667 /// This class represents a resource builder implementation for the MLIR 2668 /// textual assembly format. 2669 class ResourceBuilder : public AsmResourceBuilder { 2670 public: 2671 using ValueFn = function_ref<void(raw_ostream &)>; 2672 using PrintFn = function_ref<void(StringRef, ValueFn)>; 2673 2674 ResourceBuilder(OperationPrinter &p, PrintFn printFn) 2675 : p(p), printFn(printFn) {} 2676 ~ResourceBuilder() override = default; 2677 2678 void buildBool(StringRef key, bool data) final { 2679 printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); }); 2680 } 2681 2682 void buildString(StringRef key, StringRef data) final { 2683 printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); }); 2684 } 2685 2686 void buildBlob(StringRef key, ArrayRef<char> data, 2687 uint32_t dataAlignment) final { 2688 printFn(key, [&](raw_ostream &os) { 2689 // Store the blob in a hex string containing the alignment and the data. 2690 llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); 2691 os << "\"0x" 2692 << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE), 2693 sizeof(dataAlignment))) 2694 << llvm::toHex(StringRef(data.data(), data.size())) << "\""; 2695 }); 2696 } 2697 2698 private: 2699 OperationPrinter &p; 2700 PrintFn printFn; 2701 }; 2702 2703 /// Print the metadata dictionary for the file, eliding it if it is empty. 2704 void printFileMetadataDictionary(Operation *op); 2705 2706 /// Print the resource sections for the file metadata dictionary. 
2707 /// `checkAddMetadataDict` is used to indicate that metadata is going to be 2708 /// added, and the file metadata dictionary should be started if it hasn't 2709 /// yet. 2710 void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, 2711 Operation *op); 2712 2713 // Contains the stack of default dialects to use when printing regions. 2714 // A new dialect is pushed to the stack before parsing regions nested under an 2715 // operation implementing `OpAsmOpInterface`, and popped when done. At the 2716 // top-level we start with "builtin" as the default, so that the top-level 2717 // `module` operation prints as-is. 2718 SmallVector<StringRef> defaultDialectStack{"builtin"}; 2719 2720 /// The number of spaces used for indenting nested operations. 2721 const static unsigned indentWidth = 2; 2722 2723 // This is the current indentation level for nested structures. 2724 unsigned currentIndent = 0; 2725 }; 2726 } // namespace 2727 2728 void OperationPrinter::printTopLevelOperation(Operation *op) { 2729 // Output the aliases at the top level that can't be deferred. 2730 state->getAliasState().printNonDeferredAliases(os, newLine); 2731 2732 // Print the module. 2733 print(op); 2734 os << newLine; 2735 2736 // Output the aliases at the top level that can be deferred. 2737 state->getAliasState().printDeferredAliases(os, newLine); 2738 2739 // Output any file level metadata. 2740 printFileMetadataDictionary(op); 2741 } 2742 2743 void OperationPrinter::printFileMetadataDictionary(Operation *op) { 2744 bool sawMetadataEntry = false; 2745 auto checkAddMetadataDict = [&] { 2746 if (!std::exchange(sawMetadataEntry, true)) 2747 os << newLine << "{-#" << newLine; 2748 }; 2749 2750 // Add the various types of metadata. 2751 printResourceFileMetadata(checkAddMetadataDict, op); 2752 2753 // If the file dictionary exists, close it. 
2754 if (sawMetadataEntry) 2755 os << newLine << "#-}" << newLine; 2756 } 2757 2758 void OperationPrinter::printResourceFileMetadata( 2759 function_ref<void()> checkAddMetadataDict, Operation *op) { 2760 // Functor used to add data entries to the file metadata dictionary. 2761 bool hadResource = false; 2762 auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, 2763 auto &&...providerArgs) { 2764 bool hadEntry = false; 2765 auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { 2766 checkAddMetadataDict(); 2767 2768 // Emit the top-level resource entry if we haven't yet. 2769 if (!std::exchange(hadResource, true)) 2770 os << " " << dictName << "_resources: {" << newLine; 2771 // Emit the parent resource entry if we haven't yet. 2772 if (!std::exchange(hadEntry, true)) 2773 os << " " << name << ": {" << newLine; 2774 else 2775 os << "," << newLine; 2776 2777 os << " " << key << ": "; 2778 valueFn(os); 2779 }; 2780 ResourceBuilder entryBuilder(*this, printFn); 2781 provider.buildResources(op, providerArgs..., entryBuilder); 2782 2783 if (hadEntry) 2784 os << newLine << " }"; 2785 }; 2786 2787 // Print the `dialect_resources` section if we have any dialects with 2788 // resources. 2789 for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) { 2790 StringRef name = interface.getDialect()->getNamespace(); 2791 auto it = dialectResources.find(interface.getDialect()); 2792 if (it != dialectResources.end()) 2793 processProvider("dialect", name, interface, it->second); 2794 else 2795 processProvider("dialect", name, interface, 2796 SetVector<AsmDialectResourceHandle>()); 2797 } 2798 if (hadResource) 2799 os << newLine << " }"; 2800 2801 // Print the `external_resources` section if we have any external clients with 2802 // resources. 
2803 hadResource = false; 2804 for (const auto &printer : state->getResourcePrinters()) 2805 processProvider("external", printer.getName(), printer); 2806 if (hadResource) 2807 os << newLine << " }"; 2808 } 2809 2810 /// Print a block argument in the usual format of: 2811 /// %ssaName : type {attr1=42} loc("here") 2812 /// where location printing is controlled by the standard internal option. 2813 /// You may pass omitType=true to not print a type, and pass an empty 2814 /// attribute list if you don't care for attributes. 2815 void OperationPrinter::printRegionArgument(BlockArgument arg, 2816 ArrayRef<NamedAttribute> argAttrs, 2817 bool omitType) { 2818 printOperand(arg); 2819 if (!omitType) { 2820 os << ": "; 2821 printType(arg.getType()); 2822 } 2823 printOptionalAttrDict(argAttrs); 2824 // TODO: We should allow location aliases on block arguments. 2825 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2826 } 2827 2828 void OperationPrinter::print(Operation *op) { 2829 // Track the location of this operation. 2830 state->registerOperationLocation(op, newLine.curLine, currentIndent); 2831 2832 os.indent(currentIndent); 2833 printOperation(op); 2834 printTrailingLocation(op->getLoc()); 2835 if (printerFlags.shouldPrintValueUsers()) 2836 printUsersComment(op); 2837 } 2838 2839 void OperationPrinter::printOperation(Operation *op) { 2840 if (size_t numResults = op->getNumResults()) { 2841 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 2842 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 2843 if (resultCount > 1) 2844 os << ':' << resultCount; 2845 }; 2846 2847 // Check to see if this operation has multiple result groups. 2848 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); 2849 if (!resultGroups.empty()) { 2850 // Interleave the groups excluding the last one, this one will be handled 2851 // separately. 
2852 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 2853 printResultGroup(resultGroups[i], 2854 resultGroups[i + 1] - resultGroups[i]); 2855 }); 2856 os << ", "; 2857 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 2858 2859 } else { 2860 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 2861 } 2862 2863 os << " = "; 2864 } 2865 2866 // If requested, always print the generic form. 2867 if (!printerFlags.shouldPrintGenericOpForm()) { 2868 // Check to see if this is a known operation. If so, use the registered 2869 // custom printer hook. 2870 if (auto opInfo = op->getRegisteredInfo()) { 2871 opInfo->printAssembly(op, *this, defaultDialectStack.back()); 2872 return; 2873 } 2874 // Otherwise try to dispatch to the dialect, if available. 2875 if (Dialect *dialect = op->getDialect()) { 2876 if (auto opPrinter = dialect->getOperationPrinter(op)) { 2877 // Print the op name first. 2878 StringRef name = op->getName().getStringRef(); 2879 // Only drop the default dialect prefix when it cannot lead to 2880 // ambiguities. 2881 if (name.count('.') == 1) 2882 name.consume_front((defaultDialectStack.back() + ".").str()); 2883 os << name; 2884 2885 // Print the rest of the op now. 2886 opPrinter(op, *this); 2887 return; 2888 } 2889 } 2890 } 2891 2892 // Otherwise print with the generic assembly form. 2893 printGenericOp(op, /*printOpName=*/true); 2894 } 2895 2896 void OperationPrinter::printUsersComment(Operation *op) { 2897 unsigned numResults = op->getNumResults(); 2898 if (!numResults && op->getNumOperands()) { 2899 os << " // id: "; 2900 printOperationID(op); 2901 } else if (numResults && op->use_empty()) { 2902 os << " // unused"; 2903 } else if (numResults && !op->use_empty()) { 2904 // Print "user" if the operation has one result used to compute one other 2905 // result, or is used in one operation with no result. 
2906 unsigned usedInNResults = 0; 2907 unsigned usedInNOperations = 0; 2908 SmallPtrSet<Operation *, 1> userSet; 2909 for (Operation *user : op->getUsers()) { 2910 if (userSet.insert(user).second) { 2911 ++usedInNOperations; 2912 usedInNResults += user->getNumResults(); 2913 } 2914 } 2915 2916 // We already know that users is not empty. 2917 bool exactlyOneUniqueUse = 2918 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; 2919 os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": "; 2920 bool shouldPrintBrackets = numResults > 1; 2921 auto printOpResult = [&](OpResult opResult) { 2922 if (shouldPrintBrackets) 2923 os << "("; 2924 printValueUsers(opResult); 2925 if (shouldPrintBrackets) 2926 os << ")"; 2927 }; 2928 2929 interleaveComma(op->getResults(), printOpResult); 2930 } 2931 } 2932 2933 void OperationPrinter::printUsersComment(BlockArgument arg) { 2934 os << "// "; 2935 printValueID(arg); 2936 if (arg.use_empty()) { 2937 os << " is unused"; 2938 } else { 2939 os << " is used by "; 2940 printValueUsers(arg); 2941 } 2942 os << newLine; 2943 } 2944 2945 void OperationPrinter::printValueUsers(Value value) { 2946 if (value.use_empty()) 2947 os << "unused"; 2948 2949 // One value might be used as the operand of an operation more than once. 2950 // Only print the operations results once in that case. 
2951 SmallPtrSet<Operation *, 1> userSet; 2952 for (auto &indexedUser : enumerate(value.getUsers())) { 2953 if (userSet.insert(indexedUser.value()).second) 2954 printUserIDs(indexedUser.value(), indexedUser.index()); 2955 } 2956 } 2957 2958 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { 2959 if (prefixComma) 2960 os << ", "; 2961 2962 if (!user->getNumResults()) { 2963 printOperationID(user); 2964 } else { 2965 interleaveComma(user->getResults(), 2966 [this](Value result) { printValueID(result); }); 2967 } 2968 } 2969 2970 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { 2971 if (printOpName) 2972 printEscapedString(op->getName().getStringRef()); 2973 os << '('; 2974 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 2975 os << ')'; 2976 2977 // For terminators, print the list of successors and their operands. 2978 if (op->getNumSuccessors() != 0) { 2979 os << '['; 2980 interleaveComma(op->getSuccessors(), 2981 [&](Block *successor) { printBlockName(successor); }); 2982 os << ']'; 2983 } 2984 2985 // Print regions. 2986 if (op->getNumRegions() != 0) { 2987 os << " ("; 2988 interleaveComma(op->getRegions(), [&](Region ®ion) { 2989 printRegion(region, /*printEntryBlockArgs=*/true, 2990 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 2991 }); 2992 os << ')'; 2993 } 2994 2995 auto attrs = op->getAttrs(); 2996 printOptionalAttrDict(attrs); 2997 2998 // Print the type signature of the operation. 2999 os << " : "; 3000 printFunctionalType(op); 3001 } 3002 3003 void OperationPrinter::printBlockName(Block *block) { 3004 os << state->getSSANameState().getBlockInfo(block).name; 3005 } 3006 3007 void OperationPrinter::print(Block *block, bool printBlockArgs, 3008 bool printBlockTerminator) { 3009 // Print the block label and argument list if requested. 3010 if (printBlockArgs) { 3011 os.indent(currentIndent); 3012 printBlockName(block); 3013 3014 // Print the argument list if non-empty. 
3015 if (!block->args_empty()) { 3016 os << '('; 3017 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 3018 printValueID(arg); 3019 os << ": "; 3020 printType(arg.getType()); 3021 // TODO: We should allow location aliases on block arguments. 3022 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 3023 }); 3024 os << ')'; 3025 } 3026 os << ':'; 3027 3028 // Print out some context information about the predecessors of this block. 3029 if (!block->getParent()) { 3030 os << " // block is not in a region!"; 3031 } else if (block->hasNoPredecessors()) { 3032 if (!block->isEntryBlock()) 3033 os << " // no predecessors"; 3034 } else if (auto *pred = block->getSinglePredecessor()) { 3035 os << " // pred: "; 3036 printBlockName(pred); 3037 } else { 3038 // We want to print the predecessors in a stable order, not in 3039 // whatever order the use-list is in, so gather and sort them. 3040 SmallVector<BlockInfo, 4> predIDs; 3041 for (auto *pred : block->getPredecessors()) 3042 predIDs.push_back(state->getSSANameState().getBlockInfo(pred)); 3043 llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { 3044 return lhs.ordering < rhs.ordering; 3045 }); 3046 3047 os << " // " << predIDs.size() << " preds: "; 3048 3049 interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); 3050 } 3051 os << newLine; 3052 } 3053 3054 currentIndent += indentWidth; 3055 3056 if (printerFlags.shouldPrintValueUsers()) { 3057 for (BlockArgument arg : block->getArguments()) { 3058 os.indent(currentIndent); 3059 printUsersComment(arg); 3060 } 3061 } 3062 3063 bool hasTerminator = 3064 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 3065 auto range = llvm::make_range( 3066 block->begin(), 3067 std::prev(block->end(), 3068 (!hasTerminator || printBlockTerminator) ? 
0 : 1)); 3069 for (auto &op : range) { 3070 print(&op); 3071 os << newLine; 3072 } 3073 currentIndent -= indentWidth; 3074 } 3075 3076 void OperationPrinter::printValueID(Value value, bool printResultNo, 3077 raw_ostream *streamOverride) const { 3078 state->getSSANameState().printValueID(value, printResultNo, 3079 streamOverride ? *streamOverride : os); 3080 } 3081 3082 void OperationPrinter::printOperationID(Operation *op, 3083 raw_ostream *streamOverride) const { 3084 state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride 3085 : os); 3086 } 3087 3088 void OperationPrinter::printSuccessor(Block *successor) { 3089 printBlockName(successor); 3090 } 3091 3092 void OperationPrinter::printSuccessorAndUseList(Block *successor, 3093 ValueRange succOperands) { 3094 printBlockName(successor); 3095 if (succOperands.empty()) 3096 return; 3097 3098 os << '('; 3099 interleaveComma(succOperands, 3100 [this](Value operand) { printValueID(operand); }); 3101 os << " : "; 3102 interleaveComma(succOperands, 3103 [this](Value operand) { printType(operand.getType()); }); 3104 os << ')'; 3105 } 3106 3107 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 3108 bool printBlockTerminators, 3109 bool printEmptyBlock) { 3110 os << "{" << newLine; 3111 if (!region.empty()) { 3112 auto restoreDefaultDialect = 3113 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); 3114 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) 3115 defaultDialectStack.push_back(iface.getDefaultDialect()); 3116 else 3117 defaultDialectStack.push_back(""); 3118 3119 auto *entryBlock = ®ion.front(); 3120 // Force printing the block header if printEmptyBlock is set and the block 3121 // is empty or if printEntryBlockArgs is set and there are arguments to 3122 // print. 
3123 bool shouldAlwaysPrintBlockHeader = 3124 (printEmptyBlock && entryBlock->empty()) || 3125 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 3126 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 3127 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 3128 print(&b); 3129 } 3130 os.indent(currentIndent) << "}"; 3131 } 3132 3133 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 3134 ValueRange operands) { 3135 AffineMap map = mapAttr.getValue(); 3136 unsigned numDims = map.getNumDims(); 3137 auto printValueName = [&](unsigned pos, bool isSymbol) { 3138 unsigned index = isSymbol ? numDims + pos : pos; 3139 assert(index < operands.size()); 3140 if (isSymbol) 3141 os << "symbol("; 3142 printValueID(operands[index]); 3143 if (isSymbol) 3144 os << ')'; 3145 }; 3146 3147 interleaveComma(map.getResults(), [&](AffineExpr expr) { 3148 printAffineExpr(expr, printValueName); 3149 }); 3150 } 3151 3152 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 3153 ValueRange dimOperands, 3154 ValueRange symOperands) { 3155 auto printValueName = [&](unsigned pos, bool isSymbol) { 3156 if (!isSymbol) 3157 return printValueID(dimOperands[pos]); 3158 os << "symbol("; 3159 printValueID(symOperands[pos]); 3160 os << ')'; 3161 }; 3162 printAffineExpr(expr, printValueName); 3163 } 3164 3165 //===----------------------------------------------------------------------===// 3166 // print and dump methods 3167 //===----------------------------------------------------------------------===// 3168 3169 void Attribute::print(raw_ostream &os) const { 3170 AsmPrinter::Impl(os).printAttribute(*this); 3171 } 3172 3173 void Attribute::dump() const { 3174 print(llvm::errs()); 3175 llvm::errs() << "\n"; 3176 } 3177 3178 void Type::print(raw_ostream &os) const { 3179 AsmPrinter::Impl(os).printType(*this); 3180 } 3181 3182 void Type::dump() const { print(llvm::errs()); } 3183 3184 void AffineMap::dump() const { 3185 print(llvm::errs()); 
3186 llvm::errs() << "\n"; 3187 } 3188 3189 void IntegerSet::dump() const { 3190 print(llvm::errs()); 3191 llvm::errs() << "\n"; 3192 } 3193 3194 void AffineExpr::print(raw_ostream &os) const { 3195 if (!expr) { 3196 os << "<<NULL AFFINE EXPR>>"; 3197 return; 3198 } 3199 AsmPrinter::Impl(os).printAffineExpr(*this); 3200 } 3201 3202 void AffineExpr::dump() const { 3203 print(llvm::errs()); 3204 llvm::errs() << "\n"; 3205 } 3206 3207 void AffineMap::print(raw_ostream &os) const { 3208 if (!map) { 3209 os << "<<NULL AFFINE MAP>>"; 3210 return; 3211 } 3212 AsmPrinter::Impl(os).printAffineMap(*this); 3213 } 3214 3215 void IntegerSet::print(raw_ostream &os) const { 3216 AsmPrinter::Impl(os).printIntegerSet(*this); 3217 } 3218 3219 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); } 3220 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) { 3221 if (!impl) { 3222 os << "<<NULL VALUE>>"; 3223 return; 3224 } 3225 3226 if (auto *op = getDefiningOp()) 3227 return op->print(os, flags); 3228 // TODO: Improve BlockArgument print'ing. 3229 BlockArgument arg = this->cast<BlockArgument>(); 3230 os << "<block argument> of type '" << arg.getType() 3231 << "' at index: " << arg.getArgNumber(); 3232 } 3233 void Value::print(raw_ostream &os, AsmState &state) { 3234 if (!impl) { 3235 os << "<<NULL VALUE>>"; 3236 return; 3237 } 3238 3239 if (auto *op = getDefiningOp()) 3240 return op->print(os, state); 3241 3242 // TODO: Improve BlockArgument print'ing. 3243 BlockArgument arg = this->cast<BlockArgument>(); 3244 os << "<block argument> of type '" << arg.getType() 3245 << "' at index: " << arg.getArgNumber(); 3246 } 3247 3248 void Value::dump() { 3249 print(llvm::errs()); 3250 llvm::errs() << "\n"; 3251 } 3252 3253 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 3254 // TODO: This doesn't necessarily capture all potential cases. 3255 // Currently, region arguments can be shadowed when printing the main 3256 // operation. 
If the IR hasn't been printed, this will produce the old SSA 3257 // name and not the shadowed name. 3258 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 3259 os); 3260 } 3261 3262 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { 3263 // If this is a top level operation, we also print aliases. 3264 if (!getParent() && !printerFlags.shouldUseLocalScope()) { 3265 AsmState state(this, printerFlags); 3266 state.getImpl().initializeAliases(this); 3267 print(os, state); 3268 return; 3269 } 3270 3271 // Find the operation to number from based upon the provided flags. 3272 Operation *op = this; 3273 bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); 3274 do { 3275 // If we are printing local scope, stop at the first operation that is 3276 // isolated from above. 3277 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 3278 break; 3279 3280 // Otherwise, traverse up to the next parent. 3281 Operation *parentOp = op->getParentOp(); 3282 if (!parentOp) 3283 break; 3284 op = parentOp; 3285 } while (true); 3286 3287 AsmState state(op, printerFlags); 3288 print(os, state); 3289 } 3290 void Operation::print(raw_ostream &os, AsmState &state) { 3291 OperationPrinter printer(os, state.getImpl()); 3292 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) 3293 printer.printTopLevelOperation(this); 3294 else 3295 printer.print(this); 3296 } 3297 3298 void Operation::dump() { 3299 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 3300 llvm::errs() << "\n"; 3301 } 3302 3303 void Block::print(raw_ostream &os) { 3304 Operation *parentOp = getParentOp(); 3305 if (!parentOp) { 3306 os << "<<UNLINKED BLOCK>>\n"; 3307 return; 3308 } 3309 // Get the top-level op. 
3310 while (auto *nextOp = parentOp->getParentOp()) 3311 parentOp = nextOp; 3312 3313 AsmState state(parentOp); 3314 print(os, state); 3315 } 3316 void Block::print(raw_ostream &os, AsmState &state) { 3317 OperationPrinter(os, state.getImpl()).print(this); 3318 } 3319 3320 void Block::dump() { print(llvm::errs()); } 3321 3322 /// Print out the name of the block without printing its body. 3323 void Block::printAsOperand(raw_ostream &os, bool printType) { 3324 Operation *parentOp = getParentOp(); 3325 if (!parentOp) { 3326 os << "<<UNLINKED BLOCK>>\n"; 3327 return; 3328 } 3329 AsmState state(parentOp); 3330 printAsOperand(os, state); 3331 } 3332 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 3333 OperationPrinter printer(os, state.getImpl()); 3334 printer.printBlockName(this); 3335 } 3336