1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Operation.h" 25 #include "mlir/IR/SubElementInterfaces.h" 26 #include "llvm/ADT/APFloat.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/MapVector.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/ScopedHashTable.h" 31 #include "llvm/ADT/SetVector.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/StringExtras.h" 34 #include "llvm/ADT/StringSet.h" 35 #include "llvm/ADT/TypeSwitch.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Endian.h" 38 #include "llvm/Support/Regex.h" 39 #include "llvm/Support/SaveAndRestore.h" 40 41 #include <tuple> 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 void Identifier::print(raw_ostream &os) const { os << str(); } 47 48 void Identifier::dump() const { print(llvm::errs()); } 49 50 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 51 52 void OperationName::dump() const { print(llvm::errs()); } 53 54 DialectAsmPrinter::~DialectAsmPrinter() {} 55 56 //===--------------------------------------------------------------------===// 57 // OpAsmPrinter 58 //===--------------------------------------------------------------------===// 59 60 OpAsmPrinter::~OpAsmPrinter() {} 61 62 void OpAsmPrinter::printFunctionalType(Operation *op) { 63 auto &os = getStream(); 64 os << '('; 65 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 66 // Print the types of null values as <<NULL TYPE>>. 67 *this << (operand ? operand.getType() : Type()); 68 }); 69 os << ") -> "; 70 71 // Print the result list. We don't parenthesize single result types unless 72 // it is a function (avoiding a grammar ambiguity). 73 bool wrapped = op->getNumResults() != 1; 74 if (!wrapped && op->getResult(0).getType() && 75 op->getResult(0).getType().isa<FunctionType>()) 76 wrapped = true; 77 78 if (wrapped) 79 os << '('; 80 81 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 82 // Print the types of null values as <<NULL TYPE>>. 83 *this << (result ? result.getType() : Type()); 84 }); 85 86 if (wrapped) 87 os << ')'; 88 } 89 90 //===--------------------------------------------------------------------===// 91 // Operation OpAsm interface. 92 //===--------------------------------------------------------------------===// 93 94 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 95 #include "mlir/IR/OpAsmInterface.cpp.inc" 96 97 //===----------------------------------------------------------------------===// 98 // OpPrintingFlags 99 //===----------------------------------------------------------------------===// 100 101 namespace { 102 /// This struct contains command line options that can be used to initialize 103 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 104 /// for global command line options. 105 struct AsmPrinterOptions { 106 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 107 "mlir-print-elementsattrs-with-hex-if-larger", 108 llvm::cl::desc( 109 "Print DenseElementsAttrs with a hex string that have " 110 "more elements than the given upper limit (use -1 to disable)")}; 111 112 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 113 "mlir-elide-elementsattrs-if-larger", 114 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 115 "more elements than the given upper limit")}; 116 117 llvm::cl::opt<bool> printDebugInfoOpt{ 118 "mlir-print-debuginfo", llvm::cl::init(false), 119 llvm::cl::desc("Print debug info in MLIR output")}; 120 121 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 122 "mlir-pretty-debuginfo", llvm::cl::init(false), 123 llvm::cl::desc("Print pretty debug info in MLIR output")}; 124 125 // Use the generic op output form in the operation printer even if the custom 126 // form is defined. 127 llvm::cl::opt<bool> printGenericOpFormOpt{ 128 "mlir-print-op-generic", llvm::cl::init(false), 129 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 130 131 llvm::cl::opt<bool> printLocalScopeOpt{ 132 "mlir-print-local-scope", llvm::cl::init(false), 133 llvm::cl::desc("Print assuming in local scope by default"), 134 llvm::cl::Hidden}; 135 }; 136 } // end anonymous namespace 137 138 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 139 140 /// Register a set of useful command-line options that can be used to configure 141 /// various flags within the AsmPrinter. 142 void mlir::registerAsmPrinterCLOptions() { 143 // Make sure that the options struct has been initialized. 144 *clOptions; 145 } 146 147 /// Initialize the printing flags with default supplied by the cl::opts above. 148 OpPrintingFlags::OpPrintingFlags() 149 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 150 printGenericOpFormFlag(false), printLocalScope(false) { 151 // Initialize based upon command line options, if they are available. 152 if (!clOptions.isConstructed()) 153 return; 154 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 155 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 156 printDebugInfoFlag = clOptions->printDebugInfoOpt; 157 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 158 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 159 printLocalScope = clOptions->printLocalScopeOpt; 160 } 161 162 /// Enable the elision of large elements attributes, by printing a '...' 163 /// instead of the element data, when the number of elements is greater than 164 /// `largeElementLimit`. Note: The IR generated with this option is not 165 /// parsable. 166 OpPrintingFlags & 167 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 168 elementsAttrElementLimit = largeElementLimit; 169 return *this; 170 } 171 172 /// Enable printing of debug information. If 'prettyForm' is set to true, 173 /// debug information is printed in a more readable 'pretty' form. 174 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 175 printDebugInfoFlag = true; 176 printDebugInfoPrettyFormFlag = prettyForm; 177 return *this; 178 } 179 180 /// Always print operations in the generic form. 181 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 182 printGenericOpFormFlag = true; 183 return *this; 184 } 185 186 /// Use local scope when printing the operation. This allows for using the 187 /// printer in a more localized and thread-safe setting, but may not necessarily 188 /// be identical of what the IR will look like when dumping the full module. 189 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 190 printLocalScope = true; 191 return *this; 192 } 193 194 /// Return if the given ElementsAttr should be elided. 195 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 196 return elementsAttrElementLimit.hasValue() && 197 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 198 !attr.isa<SplatElementsAttr>(); 199 } 200 201 /// Return the size limit for printing large ElementsAttr. 202 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 203 return elementsAttrElementLimit; 204 } 205 206 /// Return if debug information should be printed. 207 bool OpPrintingFlags::shouldPrintDebugInfo() const { 208 return printDebugInfoFlag; 209 } 210 211 /// Return if debug information should be printed in the pretty form. 212 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 213 return printDebugInfoPrettyFormFlag; 214 } 215 216 /// Return if operations should be printed in the generic form. 217 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 218 return printGenericOpFormFlag; 219 } 220 221 /// Return if the printer should use local scope when dumping the IR. 222 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 223 224 /// Returns true if an ElementsAttr with the given number of elements should be 225 /// printed with hex. 226 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 227 // Check to see if a command line option was provided for the limit. 228 if (clOptions.isConstructed()) { 229 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 230 // -1 is used to disable hex printing. 231 if (clOptions->printElementsAttrWithHexIfLarger == -1) 232 return false; 233 return numElements > clOptions->printElementsAttrWithHexIfLarger; 234 } 235 } 236 237 // Otherwise, default to printing with hex if the number of elements is >100. 238 return numElements > 100; 239 } 240 241 //===----------------------------------------------------------------------===// 242 // NewLineCounter 243 //===----------------------------------------------------------------------===// 244 245 namespace { 246 /// This class is a simple formatter that emits a new line when inputted into a 247 /// stream, that enables counting the number of newlines emitted. This class 248 /// should be used whenever emitting newlines in the printer. 249 struct NewLineCounter { 250 unsigned curLine = 1; 251 }; 252 } // end anonymous namespace 253 254 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 255 ++newLine.curLine; 256 return os << '\n'; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // AliasInitializer 261 //===----------------------------------------------------------------------===// 262 263 namespace { 264 /// This class represents a specific instance of a symbol Alias. 265 class SymbolAlias { 266 public: 267 SymbolAlias(StringRef name, bool isDeferrable) 268 : name(name), suffixIndex(0), hasSuffixIndex(false), 269 isDeferrable(isDeferrable) {} 270 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) 271 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true), 272 isDeferrable(isDeferrable) {} 273 274 /// Print this alias to the given stream. 275 void print(raw_ostream &os) const { 276 os << name; 277 if (hasSuffixIndex) 278 os << suffixIndex; 279 } 280 281 /// Returns true if this alias supports deferred resolution when parsing. 282 bool canBeDeferred() const { return isDeferrable; } 283 284 private: 285 /// The main name of the alias. 286 StringRef name; 287 /// The optional suffix index of the alias, if multiple aliases had the same 288 /// name. 289 uint32_t suffixIndex : 30; 290 /// A flag indicating whether this alias has a suffix or not. 291 bool hasSuffixIndex : 1; 292 /// A flag indicating whether this alias may be deferred or not. 293 bool isDeferrable : 1; 294 }; 295 296 /// This class represents a utility that initializes the set of attribute and 297 /// type aliases, without the need to store the extra information within the 298 /// main AliasState class or pass it around via function arguments. 299 class AliasInitializer { 300 public: 301 AliasInitializer( 302 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, 303 llvm::BumpPtrAllocator &aliasAllocator) 304 : interfaces(interfaces), aliasAllocator(aliasAllocator), 305 aliasOS(aliasBuffer) {} 306 307 void initialize(Operation *op, const OpPrintingFlags &printerFlags, 308 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 309 llvm::MapVector<Type, SymbolAlias> &typeToAlias); 310 311 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is 312 /// set to true if the originator of this attribute can resolve the alias 313 /// after parsing has completed (e.g. in the case of operation locations). 314 void visit(Attribute attr, bool canBeDeferred = false); 315 316 /// Visit the given type to see if it has an alias. 317 void visit(Type type); 318 319 private: 320 /// Try to generate an alias for the provided symbol. If an alias is 321 /// generated, the provided alias mapping and reverse mapping are updated. 322 /// Returns success if an alias was generated, failure otherwise. 323 template <typename T> 324 LogicalResult 325 generateAlias(T symbol, 326 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol); 327 328 /// The set of asm interfaces within the context. 329 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; 330 331 /// Mapping between an alias and the set of symbols mapped to it. 332 llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr; 333 llvm::MapVector<StringRef, std::vector<Type>> aliasToType; 334 335 /// An allocator used for alias names. 336 llvm::BumpPtrAllocator &aliasAllocator; 337 338 /// The set of visited attributes. 339 DenseSet<Attribute> visitedAttributes; 340 341 /// The set of attributes that have aliases *and* can be deferred. 342 DenseSet<Attribute> deferrableAttributes; 343 344 /// The set of visited types. 345 DenseSet<Type> visitedTypes; 346 347 /// Storage and stream used when generating an alias. 348 SmallString<32> aliasBuffer; 349 llvm::raw_svector_ostream aliasOS; 350 }; 351 352 /// This class implements a dummy OpAsmPrinter that doesn't print any output, 353 /// and merely collects the attributes and types that *would* be printed in a 354 /// normal print invocation so that we can generate proper aliases. This allows 355 /// for us to generate aliases only for the attributes and types that would be 356 /// in the output, and trims down unnecessary output. 357 class DummyAliasOperationPrinter : private OpAsmPrinter { 358 public: 359 explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags, 360 AliasInitializer &initializer) 361 : printerFlags(flags), initializer(initializer) {} 362 363 /// Print the given operation. 364 void print(Operation *op) { 365 // Visit the operation location. 366 if (printerFlags.shouldPrintDebugInfo()) 367 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 368 369 // If requested, always print the generic form. 370 if (!printerFlags.shouldPrintGenericOpForm()) { 371 // Check to see if this is a known operation. If so, use the registered 372 // custom printer hook. 373 if (auto *opInfo = op->getAbstractOperation()) { 374 opInfo->printAssembly(op, *this); 375 return; 376 } 377 } 378 379 // Otherwise print with the generic assembly form. 380 printGenericOp(op); 381 } 382 383 private: 384 /// Print the given operation in the generic form. 385 void printGenericOp(Operation *op) override { 386 // Consider nested operations for aliases. 387 if (op->getNumRegions() != 0) { 388 for (Region ®ion : op->getRegions()) 389 printRegion(region, /*printEntryBlockArgs=*/true, 390 /*printBlockTerminators=*/true); 391 } 392 393 // Visit all the types used in the operation. 394 for (Type type : op->getOperandTypes()) 395 printType(type); 396 for (Type type : op->getResultTypes()) 397 printType(type); 398 399 // Consider the attributes of the operation for aliases. 400 for (const NamedAttribute &attr : op->getAttrs()) 401 printAttribute(attr.second); 402 } 403 404 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 405 /// block are not printed. If 'printBlockTerminator' is false, the terminator 406 /// operation of the block is not printed. 407 void print(Block *block, bool printBlockArgs = true, 408 bool printBlockTerminator = true) { 409 // Consider the types of the block arguments for aliases if 'printBlockArgs' 410 // is set to true. 411 if (printBlockArgs) { 412 for (BlockArgument arg : block->getArguments()) { 413 printType(arg.getType()); 414 415 // Visit the argument location. 416 if (printerFlags.shouldPrintDebugInfo()) 417 // TODO: Allow deferring argument locations. 418 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 419 } 420 } 421 422 // Consider the operations within this block, ignoring the terminator if 423 // requested. 424 bool hasTerminator = 425 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 426 auto range = llvm::make_range( 427 block->begin(), 428 std::prev(block->end(), 429 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 430 for (Operation &op : range) 431 print(&op); 432 } 433 434 /// Print the given region. 435 void printRegion(Region ®ion, bool printEntryBlockArgs, 436 bool printBlockTerminators, 437 bool printEmptyBlock = false) override { 438 if (region.empty()) 439 return; 440 441 auto *entryBlock = ®ion.front(); 442 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 443 for (Block &b : llvm::drop_begin(region, 1)) 444 print(&b); 445 } 446 447 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, 448 bool omitType) override { 449 printType(arg.getType()); 450 // Visit the argument location. 451 if (printerFlags.shouldPrintDebugInfo()) 452 // TODO: Allow deferring argument locations. 453 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 454 } 455 456 /// Consider the given type to be printed for an alias. 457 void printType(Type type) override { initializer.visit(type); } 458 459 /// Consider the given attribute to be printed for an alias. 460 void printAttribute(Attribute attr) override { initializer.visit(attr); } 461 void printAttributeWithoutType(Attribute attr) override { 462 printAttribute(attr); 463 } 464 465 /// Print the given set of attributes with names not included within 466 /// 'elidedAttrs'. 467 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 468 ArrayRef<StringRef> elidedAttrs = {}) override { 469 if (attrs.empty()) 470 return; 471 if (elidedAttrs.empty()) { 472 for (const NamedAttribute &attr : attrs) 473 printAttribute(attr.second); 474 return; 475 } 476 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 477 elidedAttrs.end()); 478 for (const NamedAttribute &attr : attrs) 479 if (!elidedAttrsSet.contains(attr.first.strref())) 480 printAttribute(attr.second); 481 } 482 void printOptionalAttrDictWithKeyword( 483 ArrayRef<NamedAttribute> attrs, 484 ArrayRef<StringRef> elidedAttrs = {}) override { 485 printOptionalAttrDict(attrs, elidedAttrs); 486 } 487 488 /// Return a null stream as the output stream, this will ignore any data fed 489 /// to it. 490 raw_ostream &getStream() const override { return os; } 491 492 /// The following are hooks of `OpAsmPrinter` that are not necessary for 493 /// determining potential aliases. 494 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 495 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} 496 void printNewline() override {} 497 void printOperand(Value) override {} 498 void printOperand(Value, raw_ostream &os) override { 499 // Users expect the output string to have at least the prefixed % to signal 500 // a value name. To maintain this invariant, emit a name even if it is 501 // guaranteed to go unused. 502 os << "%"; 503 } 504 void printSymbolName(StringRef) override {} 505 void printSuccessor(Block *) override {} 506 void printSuccessorAndUseList(Block *, ValueRange) override {} 507 void shadowRegionArgs(Region &, ValueRange) override {} 508 509 /// The printer flags to use when determining potential aliases. 510 const OpPrintingFlags &printerFlags; 511 512 /// The initializer to use when identifying aliases. 513 AliasInitializer &initializer; 514 515 /// A dummy output stream. 516 mutable llvm::raw_null_ostream os; 517 }; 518 } // end anonymous namespace 519 520 /// Sanitize the given name such that it can be used as a valid identifier. If 521 /// the string needs to be modified in any way, the provided buffer is used to 522 /// store the new copy, 523 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 524 StringRef allowedPunctChars = "$._-", 525 bool allowTrailingDigit = true) { 526 assert(!name.empty() && "Shouldn't have an empty name here"); 527 528 auto copyNameToBuffer = [&] { 529 for (char ch : name) { 530 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) 531 buffer.push_back(ch); 532 else if (ch == ' ') 533 buffer.push_back('_'); 534 else 535 buffer.append(llvm::utohexstr((unsigned char)ch)); 536 } 537 }; 538 539 // Check to see if this name is valid. If it starts with a digit, then it 540 // could conflict with the autogenerated numeric ID's, so add an underscore 541 // prefix to avoid problems. 542 if (isdigit(name[0])) { 543 buffer.push_back('_'); 544 copyNameToBuffer(); 545 return buffer; 546 } 547 548 // If the name ends with a trailing digit, add a '_' to avoid potential 549 // conflicts with autogenerated ID's. 550 if (!allowTrailingDigit && isdigit(name.back())) { 551 copyNameToBuffer(); 552 buffer.push_back('_'); 553 return buffer; 554 } 555 556 // Check to see that the name consists of only valid identifier characters. 557 for (char ch : name) { 558 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { 559 copyNameToBuffer(); 560 return buffer; 561 } 562 } 563 564 // If there are no invalid characters, return the original name. 565 return name; 566 } 567 568 /// Given a collection of aliases and symbols, initialize a mapping from a 569 /// symbol to a given alias. 570 template <typename T> 571 static void 572 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol, 573 llvm::MapVector<T, SymbolAlias> &symbolToAlias, 574 DenseSet<T> *deferrableAliases = nullptr) { 575 std::vector<std::pair<StringRef, std::vector<T>>> aliases = 576 aliasToSymbol.takeVector(); 577 llvm::array_pod_sort(aliases.begin(), aliases.end(), 578 [](const auto *lhs, const auto *rhs) { 579 return lhs->first.compare(rhs->first); 580 }); 581 582 for (auto &it : aliases) { 583 // If there is only one instance for this alias, use the name directly. 584 if (it.second.size() == 1) { 585 T symbol = it.second.front(); 586 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 587 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)}); 588 continue; 589 } 590 // Otherwise, add the index to the name. 591 for (int i = 0, e = it.second.size(); i < e; ++i) { 592 T symbol = it.second[i]; 593 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 594 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)}); 595 } 596 } 597 } 598 599 void AliasInitializer::initialize( 600 Operation *op, const OpPrintingFlags &printerFlags, 601 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 602 llvm::MapVector<Type, SymbolAlias> &typeToAlias) { 603 // Use a dummy printer when walking the IR so that we can collect the 604 // attributes/types that will actually be used during printing when 605 // considering aliases. 606 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); 607 aliasPrinter.print(op); 608 609 // Initialize the aliases sorted by name. 610 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); 611 initializeAliases(aliasToType, typeToAlias); 612 } 613 614 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { 615 if (!visitedAttributes.insert(attr).second) { 616 // If this attribute already has an alias and this instance can't be 617 // deferred, make sure that the alias isn't deferred. 618 if (!canBeDeferred) 619 deferrableAttributes.erase(attr); 620 return; 621 } 622 623 // Try to generate an alias for this attribute. 624 if (succeeded(generateAlias(attr, aliasToAttr))) { 625 if (canBeDeferred) 626 deferrableAttributes.insert(attr); 627 return; 628 } 629 630 // Check for any sub elements. 631 if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) { 632 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 633 [&](Type type) { visit(type); }); 634 } 635 } 636 637 void AliasInitializer::visit(Type type) { 638 if (!visitedTypes.insert(type).second) 639 return; 640 641 // Try to generate an alias for this type. 642 if (succeeded(generateAlias(type, aliasToType))) 643 return; 644 645 // Check for any sub elements. 646 if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) { 647 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 648 [&](Type type) { visit(type); }); 649 } 650 } 651 652 template <typename T> 653 LogicalResult AliasInitializer::generateAlias( 654 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) { 655 SmallString<16> tempBuffer; 656 for (const auto &interface : interfaces) { 657 if (failed(interface.getAlias(symbol, aliasOS))) 658 continue; 659 StringRef name = aliasOS.str(); 660 assert(!name.empty() && "expected valid alias name"); 661 name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-", 662 /*allowTrailingDigit=*/false); 663 name = name.copy(aliasAllocator); 664 665 aliasToSymbol[name].push_back(symbol); 666 aliasBuffer.clear(); 667 return success(); 668 } 669 return failure(); 670 } 671 672 //===----------------------------------------------------------------------===// 673 // AliasState 674 //===----------------------------------------------------------------------===// 675 676 namespace { 677 /// This class manages the state for type and attribute aliases. 678 class AliasState { 679 public: 680 // Initialize the internal aliases. 681 void 682 initialize(Operation *op, const OpPrintingFlags &printerFlags, 683 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 684 685 /// Get an alias for the given attribute if it has one and print it in `os`. 686 /// Returns success if an alias was printed, failure otherwise. 687 LogicalResult getAlias(Attribute attr, raw_ostream &os) const; 688 689 /// Get an alias for the given type if it has one and print it in `os`. 690 /// Returns success if an alias was printed, failure otherwise. 691 LogicalResult getAlias(Type ty, raw_ostream &os) const; 692 693 /// Print all of the referenced aliases that can not be resolved in a deferred 694 /// manner. 695 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 696 printAliases(os, newLine, /*isDeferred=*/false); 697 } 698 699 /// Print all of the referenced aliases that support deferred resolution. 700 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 701 printAliases(os, newLine, /*isDeferred=*/true); 702 } 703 704 private: 705 /// Print all of the referenced aliases that support the provided resolution 706 /// behavior. 707 void printAliases(raw_ostream &os, NewLineCounter &newLine, 708 bool isDeferred) const; 709 710 /// Mapping between attribute and alias. 711 llvm::MapVector<Attribute, SymbolAlias> attrToAlias; 712 /// Mapping between type and alias. 713 llvm::MapVector<Type, SymbolAlias> typeToAlias; 714 715 /// An allocator used for alias names. 716 llvm::BumpPtrAllocator aliasAllocator; 717 }; 718 } // end anonymous namespace 719 720 void AliasState::initialize( 721 Operation *op, const OpPrintingFlags &printerFlags, 722 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 723 AliasInitializer initializer(interfaces, aliasAllocator); 724 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); 725 } 726 727 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { 728 auto it = attrToAlias.find(attr); 729 if (it == attrToAlias.end()) 730 return failure(); 731 it->second.print(os << '#'); 732 return success(); 733 } 734 735 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { 736 auto it = typeToAlias.find(ty); 737 if (it == typeToAlias.end()) 738 return failure(); 739 740 it->second.print(os << '!'); 741 return success(); 742 } 743 744 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine, 745 bool isDeferred) const { 746 auto filterFn = [=](const auto &aliasIt) { 747 return aliasIt.second.canBeDeferred() == isDeferred; 748 }; 749 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) { 750 it.second.print(os << '#'); 751 os << " = " << it.first << newLine; 752 } 753 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) { 754 it.second.print(os << '!'); 755 os << " = type " << it.first << newLine; 756 } 757 } 758 759 //===----------------------------------------------------------------------===// 760 // SSANameState 761 //===----------------------------------------------------------------------===// 762 763 namespace { 764 /// This class manages the state of SSA value names. 765 class SSANameState { 766 public: 767 /// A sentinel value used for values with names set. 768 enum : unsigned { NameSentinel = ~0U }; 769 770 SSANameState(Operation *op, 771 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 772 773 /// Print the SSA identifier for the given value to 'stream'. If 774 /// 'printResultNo' is true, it also presents the result number ('#' number) 775 /// of this value. 776 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; 777 778 /// Return the result indices for each of the result groups registered by this 779 /// operation, or empty if none exist. 780 ArrayRef<int> getOpResultGroups(Operation *op); 781 782 /// Get the ID for the given block. 783 unsigned getBlockID(Block *block); 784 785 /// Renumber the arguments for the specified region to the same names as the 786 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for 787 /// details. 788 void shadowRegionArgs(Region ®ion, ValueRange namesToUse); 789 790 private: 791 /// Number the SSA values within the given IR unit. 792 void numberValuesInRegion( 793 Region ®ion, 794 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 795 void numberValuesInBlock( 796 Block &block, 797 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 798 void numberValuesInOp( 799 Operation &op, 800 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 801 802 /// Given a result of an operation 'result', find the result group head 803 /// 'lookupValue' and the result of 'result' within that group in 804 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group 805 /// has more than 1 result. 806 void getResultIDAndNumber(OpResult result, Value &lookupValue, 807 Optional<int> &lookupResultNo) const; 808 809 /// Set a special value name for the given value. 810 void setValueName(Value value, StringRef name); 811 812 /// Uniques the given value name within the printer. If the given name 813 /// conflicts, it is automatically renamed. 814 StringRef uniqueValueName(StringRef name); 815 816 /// This is the value ID for each SSA value. If this returns NameSentinel, 817 /// then the valueID has an entry in valueNames. 818 DenseMap<Value, unsigned> valueIDs; 819 DenseMap<Value, StringRef> valueNames; 820 821 /// This is a map of operations that contain multiple named result groups, 822 /// i.e. there may be multiple names for the results of the operation. The 823 /// value of this map are the result numbers that start a result group. 824 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; 825 826 /// This is the block ID for each block in the current. 827 DenseMap<Block *, unsigned> blockIDs; 828 829 /// This keeps track of all of the non-numeric names that are in flight, 830 /// allowing us to check for duplicates. 831 /// Note: the value of the map is unused. 832 llvm::ScopedHashTable<StringRef, char> usedNames; 833 llvm::BumpPtrAllocator usedNameAllocator; 834 835 /// This is the next value ID to assign in numbering. 836 unsigned nextValueID = 0; 837 /// This is the next ID to assign to a region entry block argument. 838 unsigned nextArgumentID = 0; 839 /// This is the next ID to assign when a name conflict is detected. 840 unsigned nextConflictID = 0; 841 }; 842 } // end anonymous namespace 843 844 SSANameState::SSANameState( 845 Operation *op, 846 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 847 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); 848 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); 849 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); 850 851 // The naming context includes `nextValueID`, `nextArgumentID`, 852 // `nextConflictID` and `usedNames` scoped HashTable. This information is 853 // carried from the parent region. 854 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; 855 using NamingContext = 856 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; 857 858 // Allocator for UsedNamesScopeTy 859 llvm::BumpPtrAllocator allocator; 860 861 // Add a scope for the top level operation. 862 auto *topLevelNamesScope = 863 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); 864 865 SmallVector<NamingContext, 8> nameContext; 866 for (Region ®ion : op->getRegions()) 867 nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, 868 nextConflictID, topLevelNamesScope)); 869 870 numberValuesInOp(*op, interfaces); 871 872 while (!nameContext.empty()) { 873 Region *region; 874 UsedNamesScopeTy *parentScope; 875 std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) = 876 nameContext.pop_back_val(); 877 878 // When we switch from one subtree to another, pop the scopes(needless) 879 // until the parent scope. 880 while (usedNames.getCurScope() != parentScope) { 881 usedNames.getCurScope()->~UsedNamesScopeTy(); 882 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && 883 "top level parentScope must be a nullptr"); 884 } 885 886 // Add a scope for the current region. 887 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) 888 UsedNamesScopeTy(usedNames); 889 890 numberValuesInRegion(*region, interfaces); 891 892 for (Operation &op : region->getOps()) 893 for (Region ®ion : op.getRegions()) 894 nameContext.push_back(std::make_tuple(®ion, nextValueID, 895 nextArgumentID, nextConflictID, 896 curNamesScope)); 897 } 898 899 // Manually remove all the scopes. 900 while (usedNames.getCurScope() != nullptr) 901 usedNames.getCurScope()->~UsedNamesScopeTy(); 902 } 903 904 void SSANameState::printValueID(Value value, bool printResultNo, 905 raw_ostream &stream) const { 906 if (!value) { 907 stream << "<<NULL>>"; 908 return; 909 } 910 911 Optional<int> resultNo; 912 auto lookupValue = value; 913 914 // If this is an operation result, collect the head lookup value of the result 915 // group and the result number of 'result' within that group. 916 if (OpResult result = value.dyn_cast<OpResult>()) 917 getResultIDAndNumber(result, lookupValue, resultNo); 918 919 auto it = valueIDs.find(lookupValue); 920 if (it == valueIDs.end()) { 921 stream << "<<UNKNOWN SSA VALUE>>"; 922 return; 923 } 924 925 stream << '%'; 926 if (it->second != NameSentinel) { 927 stream << it->second; 928 } else { 929 auto nameIt = valueNames.find(lookupValue); 930 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 931 stream << nameIt->second; 932 } 933 934 if (resultNo.hasValue() && printResultNo) 935 stream << '#' << resultNo; 936 } 937 938 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 939 auto it = opResultGroups.find(op); 940 return it == opResultGroups.end() ? ArrayRef<int>() : it->second; 941 } 942 943 unsigned SSANameState::getBlockID(Block *block) { 944 auto it = blockIDs.find(block); 945 return it != blockIDs.end() ? it->second : NameSentinel; 946 } 947 948 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 949 assert(!region.empty() && "cannot shadow arguments of an empty region"); 950 assert(region.getNumArguments() == namesToUse.size() && 951 "incorrect number of names passed in"); 952 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 953 "only KnownIsolatedFromAbove ops can shadow names"); 954 955 SmallVector<char, 16> nameStr; 956 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 957 auto nameToUse = namesToUse[i]; 958 if (nameToUse == nullptr) 959 continue; 960 auto nameToReplace = region.getArgument(i); 961 962 nameStr.clear(); 963 llvm::raw_svector_ostream nameStream(nameStr); 964 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 965 966 // Entry block arguments should already have a pretty "arg" name. 967 assert(valueIDs[nameToReplace] == NameSentinel); 968 969 // Use the name without the leading %. 970 auto name = StringRef(nameStream.str()).drop_front(); 971 972 // Overwrite the name. 973 valueNames[nameToReplace] = name.copy(usedNameAllocator); 974 } 975 } 976 977 void SSANameState::numberValuesInRegion( 978 Region ®ion, 979 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 980 // Number the values within this region in a breadth-first order. 981 unsigned nextBlockID = 0; 982 for (auto &block : region) { 983 // Each block gets a unique ID, and all of the operations within it get 984 // numbered as well. 985 blockIDs[&block] = nextBlockID++; 986 numberValuesInBlock(block, interfaces); 987 } 988 } 989 990 void SSANameState::numberValuesInBlock( 991 Block &block, 992 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 993 auto setArgNameFn = [&](Value arg, StringRef name) { 994 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 995 assert(arg.cast<BlockArgument>().getOwner() == &block && 996 "arg not defined in 'block'"); 997 setValueName(arg, name); 998 }; 999 1000 bool isEntryBlock = block.isEntryBlock(); 1001 if (isEntryBlock) { 1002 if (auto *op = block.getParentOp()) { 1003 if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) 1004 asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); 1005 } 1006 } 1007 1008 // Number the block arguments. We give entry block arguments a special name 1009 // 'arg'. 1010 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); 1011 llvm::raw_svector_ostream specialName(specialNameBuffer); 1012 for (auto arg : block.getArguments()) { 1013 if (valueIDs.count(arg)) 1014 continue; 1015 if (isEntryBlock) { 1016 specialNameBuffer.resize(strlen("arg")); 1017 specialName << nextArgumentID++; 1018 } 1019 setValueName(arg, specialName.str()); 1020 } 1021 1022 // Number the operations in this block. 1023 for (auto &op : block) 1024 numberValuesInOp(op, interfaces); 1025 } 1026 1027 void SSANameState::numberValuesInOp( 1028 Operation &op, 1029 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 1030 unsigned numResults = op.getNumResults(); 1031 if (numResults == 0) 1032 return; 1033 Value resultBegin = op.getResult(0); 1034 1035 // Function used to set the special result names for the operation. 1036 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1037 auto setResultNameFn = [&](Value result, StringRef name) { 1038 assert(!valueIDs.count(result) && "result numbered multiple times"); 1039 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1040 setValueName(result, name); 1041 1042 // Record the result number for groups not anchored at 0. 1043 if (int resultNo = result.cast<OpResult>().getResultNumber()) 1044 resultGroups.push_back(resultNo); 1045 }; 1046 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) 1047 asmInterface.getAsmResultNames(setResultNameFn); 1048 else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) 1049 asmInterface->getAsmResultNames(&op, setResultNameFn); 1050 1051 // If the first result wasn't numbered, give it a default number. 1052 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1053 ++nextValueID; 1054 1055 // If this operation has multiple result groups, mark it. 1056 if (resultGroups.size() != 1) { 1057 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1058 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1059 } 1060 } 1061 1062 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1063 Optional<int> &lookupResultNo) const { 1064 Operation *owner = result.getOwner(); 1065 if (owner->getNumResults() == 1) 1066 return; 1067 int resultNo = result.getResultNumber(); 1068 1069 // If this operation has multiple result groups, we will need to find the 1070 // one corresponding to this result. 1071 auto resultGroupIt = opResultGroups.find(owner); 1072 if (resultGroupIt == opResultGroups.end()) { 1073 // If not, just use the first result. 1074 lookupResultNo = resultNo; 1075 lookupValue = owner->getResult(0); 1076 return; 1077 } 1078 1079 // Find the correct index using a binary search, as the groups are ordered. 1080 ArrayRef<int> resultGroups = resultGroupIt->second; 1081 auto it = llvm::upper_bound(resultGroups, resultNo); 1082 int groupResultNo = 0, groupSize = 0; 1083 1084 // If there are no smaller elements, the last result group is the lookup. 1085 if (it == resultGroups.end()) { 1086 groupResultNo = resultGroups.back(); 1087 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1088 } else { 1089 // Otherwise, the previous element is the lookup. 1090 groupResultNo = *std::prev(it); 1091 groupSize = *it - groupResultNo; 1092 } 1093 1094 // We only record the result number for a group of size greater than 1. 1095 if (groupSize != 1) 1096 lookupResultNo = resultNo - groupResultNo; 1097 lookupValue = owner->getResult(groupResultNo); 1098 } 1099 1100 void SSANameState::setValueName(Value value, StringRef name) { 1101 // If the name is empty, the value uses the default numbering. 1102 if (name.empty()) { 1103 valueIDs[value] = nextValueID++; 1104 return; 1105 } 1106 1107 valueIDs[value] = NameSentinel; 1108 valueNames[value] = uniqueValueName(name); 1109 } 1110 1111 StringRef SSANameState::uniqueValueName(StringRef name) { 1112 SmallString<16> tmpBuffer; 1113 name = sanitizeIdentifier(name, tmpBuffer); 1114 1115 // Check to see if this name is already unique. 1116 if (!usedNames.count(name)) { 1117 name = name.copy(usedNameAllocator); 1118 } else { 1119 // Otherwise, we had a conflict - probe until we find a unique name. This 1120 // is guaranteed to terminate (and usually in a single iteration) because it 1121 // generates new names by incrementing nextConflictID. 1122 SmallString<64> probeName(name); 1123 probeName.push_back('_'); 1124 while (true) { 1125 probeName += llvm::utostr(nextConflictID++); 1126 if (!usedNames.count(probeName)) { 1127 name = StringRef(probeName).copy(usedNameAllocator); 1128 break; 1129 } 1130 probeName.resize(name.size() + 1); 1131 } 1132 } 1133 1134 usedNames.insert(name, char()); 1135 return name; 1136 } 1137 1138 //===----------------------------------------------------------------------===// 1139 // AsmState 1140 //===----------------------------------------------------------------------===// 1141 1142 namespace mlir { 1143 namespace detail { 1144 class AsmStateImpl { 1145 public: 1146 explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap) 1147 : interfaces(op->getContext()), nameState(op, interfaces), 1148 locationMap(locationMap) {} 1149 1150 /// Initialize the alias state to enable the printing of aliases. 1151 void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) { 1152 aliasState.initialize(op, printerFlags, interfaces); 1153 } 1154 1155 /// Get an instance of the OpAsmDialectInterface for the given dialect, or 1156 /// null if one wasn't registered. 1157 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { 1158 return interfaces.getInterfaceFor(dialect); 1159 } 1160 1161 /// Get the state used for aliases. 1162 AliasState &getAliasState() { return aliasState; } 1163 1164 /// Get the state used for SSA names. 1165 SSANameState &getSSANameState() { return nameState; } 1166 1167 /// Register the location, line and column, within the buffer that the given 1168 /// operation was printed at. 1169 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1170 if (locationMap) 1171 (*locationMap)[op] = std::make_pair(line, col); 1172 } 1173 1174 private: 1175 /// Collection of OpAsm interfaces implemented in the context. 1176 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1177 1178 /// The state used for attribute and type aliases. 1179 AliasState aliasState; 1180 1181 /// The state used for SSA value names. 1182 SSANameState nameState; 1183 1184 /// An optional location map to be populated. 1185 AsmState::LocationMap *locationMap; 1186 }; 1187 } // end namespace detail 1188 } // end namespace mlir 1189 1190 AsmState::AsmState(Operation *op, LocationMap *locationMap) 1191 : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {} 1192 AsmState::~AsmState() {} 1193 1194 //===----------------------------------------------------------------------===// 1195 // ModulePrinter 1196 //===----------------------------------------------------------------------===// 1197 1198 namespace { 1199 class ModulePrinter { 1200 public: 1201 ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1202 AsmStateImpl *state = nullptr) 1203 : os(os), printerFlags(flags), state(state) {} 1204 explicit ModulePrinter(ModulePrinter &printer) 1205 : os(printer.os), printerFlags(printer.printerFlags), 1206 state(printer.state) {} 1207 1208 /// Returns the output stream of the printer. 1209 raw_ostream &getStream() { return os; } 1210 1211 template <typename Container, typename UnaryFunctor> 1212 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { 1213 llvm::interleaveComma(c, os, each_fn); 1214 } 1215 1216 /// This enum describes the different kinds of elision for the type of an 1217 /// attribute when printing it. 1218 enum class AttrTypeElision { 1219 /// The type must not be elided, 1220 Never, 1221 /// The type may be elided when it matches the default used in the parser 1222 /// (for example i64 is the default for integer attributes). 1223 May, 1224 /// The type must be elided. 1225 Must 1226 }; 1227 1228 /// Print the given attribute. 1229 void printAttribute(Attribute attr, 1230 AttrTypeElision typeElision = AttrTypeElision::Never); 1231 1232 void printType(Type type); 1233 1234 /// Print the given location to the stream. If `allowAlias` is true, this 1235 /// allows for the internal location to use an attribute alias. 1236 void printLocation(LocationAttr loc, bool allowAlias = false); 1237 1238 void printAffineMap(AffineMap map); 1239 void 1240 printAffineExpr(AffineExpr expr, 1241 function_ref<void(unsigned, bool)> printValueName = nullptr); 1242 void printAffineConstraint(AffineExpr expr, bool isEq); 1243 void printIntegerSet(IntegerSet set); 1244 1245 protected: 1246 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1247 ArrayRef<StringRef> elidedAttrs = {}, 1248 bool withKeyword = false); 1249 void printNamedAttribute(NamedAttribute attr); 1250 void printTrailingLocation(Location loc, bool allowAlias = true); 1251 void printLocationInternal(LocationAttr loc, bool pretty = false); 1252 1253 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1254 /// used instead of individual elements when the elements attr is large. 1255 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1256 1257 /// Print a dense string elements attribute. 1258 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1259 1260 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1261 /// used instead of individual elements when the elements attr is large. 1262 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1263 bool allowHex); 1264 1265 void printDialectAttribute(Attribute attr); 1266 void printDialectType(Type type); 1267 1268 /// This enum is used to represent the binding strength of the enclosing 1269 /// context that an AffineExprStorage is being printed in, so we can 1270 /// intelligently produce parens. 1271 enum class BindingStrength { 1272 Weak, // + and - 1273 Strong, // All other binary operators. 1274 }; 1275 void printAffineExprInternal( 1276 AffineExpr expr, BindingStrength enclosingTightness, 1277 function_ref<void(unsigned, bool)> printValueName = nullptr); 1278 1279 /// The output stream for the printer. 1280 raw_ostream &os; 1281 1282 /// A set of flags to control the printer's behavior. 1283 OpPrintingFlags printerFlags; 1284 1285 /// An optional printer state for the module. 1286 AsmStateImpl *state; 1287 1288 /// A tracker for the number of new lines emitted during printing. 1289 NewLineCounter newLine; 1290 }; 1291 } // end anonymous namespace 1292 1293 void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) { 1294 // Check to see if we are printing debug information. 1295 if (!printerFlags.shouldPrintDebugInfo()) 1296 return; 1297 1298 os << " "; 1299 printLocation(loc, /*allowAlias=*/allowAlias); 1300 } 1301 1302 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { 1303 TypeSwitch<LocationAttr>(loc) 1304 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1305 printLocationInternal(loc.getFallbackLocation(), pretty); 1306 }) 1307 .Case<UnknownLoc>([&](UnknownLoc loc) { 1308 if (pretty) 1309 os << "[unknown]"; 1310 else 1311 os << "unknown"; 1312 }) 1313 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1314 if (pretty) { 1315 os << loc.getFilename(); 1316 } else { 1317 os << "\""; 1318 printEscapedString(loc.getFilename(), os); 1319 os << "\""; 1320 } 1321 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1322 }) 1323 .Case<NameLoc>([&](NameLoc loc) { 1324 os << '\"'; 1325 printEscapedString(loc.getName(), os); 1326 os << '\"'; 1327 1328 // Print the child if it isn't unknown. 1329 auto childLoc = loc.getChildLoc(); 1330 if (!childLoc.isa<UnknownLoc>()) { 1331 os << '('; 1332 printLocationInternal(childLoc, pretty); 1333 os << ')'; 1334 } 1335 }) 1336 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1337 Location caller = loc.getCaller(); 1338 Location callee = loc.getCallee(); 1339 if (!pretty) 1340 os << "callsite("; 1341 printLocationInternal(callee, pretty); 1342 if (pretty) { 1343 if (callee.isa<NameLoc>()) { 1344 if (caller.isa<FileLineColLoc>()) { 1345 os << " at "; 1346 } else { 1347 os << newLine << " at "; 1348 } 1349 } else { 1350 os << newLine << " at "; 1351 } 1352 } else { 1353 os << " at "; 1354 } 1355 printLocationInternal(caller, pretty); 1356 if (!pretty) 1357 os << ")"; 1358 }) 1359 .Case<FusedLoc>([&](FusedLoc loc) { 1360 if (!pretty) 1361 os << "fused"; 1362 if (Attribute metadata = loc.getMetadata()) 1363 os << '<' << metadata << '>'; 1364 os << '['; 1365 interleave( 1366 loc.getLocations(), 1367 [&](Location loc) { printLocationInternal(loc, pretty); }, 1368 [&]() { os << ", "; }); 1369 os << ']'; 1370 }); 1371 } 1372 1373 /// Print a floating point value in a way that the parser will be able to 1374 /// round-trip losslessly. 1375 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1376 // We would like to output the FP constant value in exponential notation, 1377 // but we cannot do this if doing so will lose precision. Check here to 1378 // make sure that we only output it in exponential format if we can parse 1379 // the value back and get the same value. 1380 bool isInf = apValue.isInfinity(); 1381 bool isNaN = apValue.isNaN(); 1382 if (!isInf && !isNaN) { 1383 SmallString<128> strValue; 1384 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1385 /*TruncateZero=*/false); 1386 1387 // Check to make sure that the stringized number is not some string like 1388 // "Inf" or NaN, that atof will accept, but the lexer will not. Check 1389 // that the string matches the "[-+]?[0-9]" regex. 1390 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1391 ((strValue[0] == '-' || strValue[0] == '+') && 1392 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1393 "[-+]?[0-9] regex does not match!"); 1394 1395 // Parse back the stringized version and check that the value is equal 1396 // (i.e., there is no precision loss). 1397 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1398 os << strValue; 1399 return; 1400 } 1401 1402 // If it is not, use the default format of APFloat instead of the 1403 // exponential notation. 1404 strValue.clear(); 1405 apValue.toString(strValue); 1406 1407 // Make sure that we can parse the default form as a float. 1408 if (StringRef(strValue).contains('.')) { 1409 os << strValue; 1410 return; 1411 } 1412 } 1413 1414 // Print special values in hexadecimal format. The sign bit should be included 1415 // in the literal. 1416 SmallVector<char, 16> str; 1417 APInt apInt = apValue.bitcastToAPInt(); 1418 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1419 /*formatAsCLiteral=*/true); 1420 os << str; 1421 } 1422 1423 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) { 1424 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1425 return printLocationInternal(loc, /*pretty=*/true); 1426 1427 os << "loc("; 1428 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1429 printLocationInternal(loc); 1430 os << ')'; 1431 } 1432 1433 /// Returns true if the given dialect symbol data is simple enough to print in 1434 /// the pretty form, i.e. without the enclosing "". 1435 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1436 // The name must start with an identifier. 1437 if (symName.empty() || !isalpha(symName.front())) 1438 return false; 1439 1440 // Ignore all the characters that are valid in an identifier in the symbol 1441 // name. 1442 symName = symName.drop_while( 1443 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1444 if (symName.empty()) 1445 return true; 1446 1447 // If we got to an unexpected character, then it must be a <>. Check those 1448 // recursively. 1449 if (symName.front() != '<' || symName.back() != '>') 1450 return false; 1451 1452 SmallVector<char, 8> nestedPunctuation; 1453 do { 1454 // If we ran out of characters, then we had a punctuation mismatch. 1455 if (symName.empty()) 1456 return false; 1457 1458 auto c = symName.front(); 1459 symName = symName.drop_front(); 1460 1461 switch (c) { 1462 // We never allow null characters. This is an EOF indicator for the lexer 1463 // which we could handle, but isn't important for any known dialect. 1464 case '\0': 1465 return false; 1466 case '<': 1467 case '[': 1468 case '(': 1469 case '{': 1470 nestedPunctuation.push_back(c); 1471 continue; 1472 case '-': 1473 // Treat `->` as a special token. 1474 if (!symName.empty() && symName.front() == '>') { 1475 symName = symName.drop_front(); 1476 continue; 1477 } 1478 break; 1479 // Reject types with mismatched brackets. 1480 case '>': 1481 if (nestedPunctuation.pop_back_val() != '<') 1482 return false; 1483 break; 1484 case ']': 1485 if (nestedPunctuation.pop_back_val() != '[') 1486 return false; 1487 break; 1488 case ')': 1489 if (nestedPunctuation.pop_back_val() != '(') 1490 return false; 1491 break; 1492 case '}': 1493 if (nestedPunctuation.pop_back_val() != '{') 1494 return false; 1495 break; 1496 default: 1497 continue; 1498 } 1499 1500 // We're done when the punctuation is fully matched. 1501 } while (!nestedPunctuation.empty()); 1502 1503 // If there were extra characters, then we failed. 1504 return symName.empty(); 1505 } 1506 1507 /// Print the given dialect symbol to the stream. 1508 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 1509 StringRef dialectName, StringRef symString) { 1510 os << symPrefix << dialectName; 1511 1512 // If this symbol name is simple enough, print it directly in pretty form, 1513 // otherwise, we print it as an escaped string. 1514 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 1515 os << '.' << symString; 1516 return; 1517 } 1518 1519 // TODO: escape the symbol name, it could contain " characters. 1520 os << "<\"" << symString << "\">"; 1521 } 1522 1523 /// Returns true if the given string can be represented as a bare identifier. 1524 static bool isBareIdentifier(StringRef name) { 1525 assert(!name.empty() && "invalid name"); 1526 1527 // By making this unsigned, the value passed in to isalnum will always be 1528 // in the range 0-255. This is important when building with MSVC because 1529 // its implementation will assert. This situation can arise when dealing 1530 // with UTF-8 multibyte characters. 1531 unsigned char firstChar = static_cast<unsigned char>(name[0]); 1532 if (!isalpha(firstChar) && firstChar != '_') 1533 return false; 1534 return llvm::all_of(name.drop_front(), [](unsigned char c) { 1535 return isalnum(c) || c == '_' || c == '$' || c == '.'; 1536 }); 1537 } 1538 1539 /// Print the given string as a symbol reference. A symbol reference is 1540 /// represented as a string prefixed with '@'. The reference is surrounded with 1541 /// ""'s and escaped if it has any special or non-printable characters in it. 1542 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { 1543 assert(!symbolRef.empty() && "expected valid symbol reference"); 1544 1545 // If the symbol can be represented as a bare identifier, write it directly. 1546 if (isBareIdentifier(symbolRef)) { 1547 os << '@' << symbolRef; 1548 return; 1549 } 1550 1551 // Otherwise, output the reference wrapped in quotes with proper escaping. 1552 os << "@\""; 1553 printEscapedString(symbolRef, os); 1554 os << '"'; 1555 } 1556 1557 // Print out a valid ElementsAttr that is succinct and can represent any 1558 // potential shape/type, for use when eliding a large ElementsAttr. 1559 // 1560 // We choose to use an opaque ElementsAttr literal with conspicuous content to 1561 // hopefully alert readers to the fact that this has been elided. 1562 // 1563 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will 1564 // accept the string "elided". The first string must be a registered dialect 1565 // name and the latter must be a hex constant. 1566 static void printElidedElementsAttr(raw_ostream &os) { 1567 os << R"(opaque<"_", "0xDEADBEEF">)"; 1568 } 1569 1570 void ModulePrinter::printAttribute(Attribute attr, 1571 AttrTypeElision typeElision) { 1572 if (!attr) { 1573 os << "<<NULL ATTRIBUTE>>"; 1574 return; 1575 } 1576 1577 // Try to print an alias for this attribute. 1578 if (state && succeeded(state->getAliasState().getAlias(attr, os))) 1579 return; 1580 1581 auto attrType = attr.getType(); 1582 if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) { 1583 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), 1584 opaqueAttr.getAttrData()); 1585 } else if (attr.isa<UnitAttr>()) { 1586 os << "unit"; 1587 return; 1588 } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) { 1589 os << '{'; 1590 interleaveComma(dictAttr.getValue(), 1591 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 1592 os << '}'; 1593 1594 } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) { 1595 if (attrType.isSignlessInteger(1)) { 1596 os << (intAttr.getValue().getBoolValue() ? "true" : "false"); 1597 1598 // Boolean integer attributes always elides the type. 1599 return; 1600 } 1601 1602 // Only print attributes as unsigned if they are explicitly unsigned or are 1603 // signless 1-bit values. Indexes, signed values, and multi-bit signless 1604 // values print as signed. 1605 bool isUnsigned = 1606 attrType.isUnsignedInteger() || attrType.isSignlessInteger(1); 1607 intAttr.getValue().print(os, !isUnsigned); 1608 1609 // IntegerAttr elides the type if I64. 1610 if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64)) 1611 return; 1612 1613 } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) { 1614 printFloatValue(floatAttr.getValue(), os); 1615 1616 // FloatAttr elides the type if F64. 1617 if (typeElision == AttrTypeElision::May && attrType.isF64()) 1618 return; 1619 1620 } else if (auto strAttr = attr.dyn_cast<StringAttr>()) { 1621 os << '"'; 1622 printEscapedString(strAttr.getValue(), os); 1623 os << '"'; 1624 1625 } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 1626 os << '['; 1627 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { 1628 printAttribute(attr, AttrTypeElision::May); 1629 }); 1630 os << ']'; 1631 1632 } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) { 1633 os << "affine_map<"; 1634 affineMapAttr.getValue().print(os); 1635 os << '>'; 1636 1637 // AffineMap always elides the type. 1638 return; 1639 1640 } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) { 1641 os << "affine_set<"; 1642 integerSetAttr.getValue().print(os); 1643 os << '>'; 1644 1645 // IntegerSet always elides the type. 1646 return; 1647 1648 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) { 1649 printType(typeAttr.getValue()); 1650 1651 } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) { 1652 printSymbolReference(refAttr.getRootReference(), os); 1653 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { 1654 os << "::"; 1655 printSymbolReference(nestedRef.getValue(), os); 1656 } 1657 1658 } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) { 1659 if (printerFlags.shouldElideElementsAttr(opaqueAttr)) { 1660 printElidedElementsAttr(os); 1661 } else { 1662 os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x" 1663 << llvm::toHex(opaqueAttr.getValue()) << "\">"; 1664 } 1665 1666 } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) { 1667 if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { 1668 printElidedElementsAttr(os); 1669 } else { 1670 os << "dense<"; 1671 printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); 1672 os << '>'; 1673 } 1674 1675 } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) { 1676 if (printerFlags.shouldElideElementsAttr(strEltAttr)) { 1677 printElidedElementsAttr(os); 1678 } else { 1679 os << "dense<"; 1680 printDenseStringElementsAttr(strEltAttr); 1681 os << '>'; 1682 } 1683 1684 } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) { 1685 if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || 1686 printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { 1687 printElidedElementsAttr(os); 1688 } else { 1689 os << "sparse<"; 1690 DenseIntElementsAttr indices = sparseEltAttr.getIndices(); 1691 if (indices.getNumElements() != 0) { 1692 printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); 1693 os << ", "; 1694 printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true); 1695 } 1696 os << '>'; 1697 } 1698 1699 } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) { 1700 printLocation(locAttr); 1701 1702 } else { 1703 return printDialectAttribute(attr); 1704 } 1705 1706 // Don't print the type if we must elide it, or if it is a None type. 1707 if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) { 1708 os << " : "; 1709 printType(attrType); 1710 } 1711 } 1712 1713 /// Print the integer element of a DenseElementsAttr. 1714 static void printDenseIntElement(const APInt &value, raw_ostream &os, 1715 bool isSigned) { 1716 if (value.getBitWidth() == 1) 1717 os << (value.getBoolValue() ? "true" : "false"); 1718 else 1719 value.print(os, isSigned); 1720 } 1721 1722 static void 1723 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, 1724 function_ref<void(unsigned)> printEltFn) { 1725 // Special case for 0-d and splat tensors. 1726 if (isSplat) 1727 return printEltFn(0); 1728 1729 // Special case for degenerate tensors. 1730 auto numElements = type.getNumElements(); 1731 if (numElements == 0) 1732 return; 1733 1734 // We use a mixed-radix counter to iterate through the shape. When we bump a 1735 // non-least-significant digit, we emit a close bracket. When we next emit an 1736 // element we re-open all closed brackets. 1737 1738 // The mixed-radix counter, with radices in 'shape'. 1739 int64_t rank = type.getRank(); 1740 SmallVector<unsigned, 4> counter(rank, 0); 1741 // The number of brackets that have been opened and not closed. 1742 unsigned openBrackets = 0; 1743 1744 auto shape = type.getShape(); 1745 auto bumpCounter = [&] { 1746 // Bump the least significant digit. 1747 ++counter[rank - 1]; 1748 // Iterate backwards bubbling back the increment. 1749 for (unsigned i = rank - 1; i > 0; --i) 1750 if (counter[i] >= shape[i]) { 1751 // Index 'i' is rolled over. Bump (i-1) and close a bracket. 1752 counter[i] = 0; 1753 ++counter[i - 1]; 1754 --openBrackets; 1755 os << ']'; 1756 } 1757 }; 1758 1759 for (unsigned idx = 0, e = numElements; idx != e; ++idx) { 1760 if (idx != 0) 1761 os << ", "; 1762 while (openBrackets++ < rank) 1763 os << '['; 1764 openBrackets = rank; 1765 printEltFn(idx); 1766 bumpCounter(); 1767 } 1768 while (openBrackets-- > 0) 1769 os << ']'; 1770 } 1771 1772 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, 1773 bool allowHex) { 1774 if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) 1775 return printDenseStringElementsAttr(stringAttr); 1776 1777 printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(), 1778 allowHex); 1779 } 1780 1781 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1782 bool allowHex) { 1783 auto type = attr.getType(); 1784 auto elementType = type.getElementType(); 1785 1786 // Check to see if we should format this attribute as a hex string. 1787 auto numElements = type.getNumElements(); 1788 if (!attr.isSplat() && allowHex && 1789 shouldPrintElementsAttrWithHex(numElements)) { 1790 ArrayRef<char> rawData = attr.getRawData(); 1791 if (llvm::support::endian::system_endianness() == 1792 llvm::support::endianness::big) { 1793 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE 1794 // machines. It is converted here to print in LE format. 1795 SmallVector<char, 64> outDataVec(rawData.size()); 1796 MutableArrayRef<char> convRawData(outDataVec); 1797 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1798 rawData, convRawData, type); 1799 os << '"' << "0x" 1800 << llvm::toHex(StringRef(convRawData.data(), convRawData.size())) 1801 << "\""; 1802 } else { 1803 os << '"' << "0x" 1804 << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\""; 1805 } 1806 1807 return; 1808 } 1809 1810 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1811 Type complexElementType = complexTy.getElementType(); 1812 // Note: The if and else below had a common lambda function which invoked 1813 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 1814 // and hence was replaced. 1815 if (complexElementType.isa<IntegerType>()) { 1816 bool isSigned = !complexElementType.isUnsignedInteger(); 1817 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1818 auto complexValue = *(attr.getComplexIntValues().begin() + index); 1819 os << "("; 1820 printDenseIntElement(complexValue.real(), os, isSigned); 1821 os << ","; 1822 printDenseIntElement(complexValue.imag(), os, isSigned); 1823 os << ")"; 1824 }); 1825 } else { 1826 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1827 auto complexValue = *(attr.getComplexFloatValues().begin() + index); 1828 os << "("; 1829 printFloatValue(complexValue.real(), os); 1830 os << ","; 1831 printFloatValue(complexValue.imag(), os); 1832 os << ")"; 1833 }); 1834 } 1835 } else if (elementType.isIntOrIndex()) { 1836 bool isSigned = !elementType.isUnsignedInteger(); 1837 auto intValues = attr.getIntValues(); 1838 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1839 printDenseIntElement(*(intValues.begin() + index), os, isSigned); 1840 }); 1841 } else { 1842 assert(elementType.isa<FloatType>() && "unexpected element type"); 1843 auto floatValues = attr.getFloatValues(); 1844 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1845 printFloatValue(*(floatValues.begin() + index), os); 1846 }); 1847 } 1848 } 1849 1850 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { 1851 ArrayRef<StringRef> data = attr.getRawStringData(); 1852 auto printFn = [&](unsigned index) { 1853 os << "\""; 1854 printEscapedString(data[index], os); 1855 os << "\""; 1856 }; 1857 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); 1858 } 1859 1860 void ModulePrinter::printType(Type type) { 1861 if (!type) { 1862 os << "<<NULL TYPE>>"; 1863 return; 1864 } 1865 1866 // Try to print an alias for this type. 1867 if (state && succeeded(state->getAliasState().getAlias(type, os))) 1868 return; 1869 1870 TypeSwitch<Type>(type) 1871 .Case<OpaqueType>([&](OpaqueType opaqueTy) { 1872 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 1873 opaqueTy.getTypeData()); 1874 }) 1875 .Case<IndexType>([&](Type) { os << "index"; }) 1876 .Case<BFloat16Type>([&](Type) { os << "bf16"; }) 1877 .Case<Float16Type>([&](Type) { os << "f16"; }) 1878 .Case<Float32Type>([&](Type) { os << "f32"; }) 1879 .Case<Float64Type>([&](Type) { os << "f64"; }) 1880 .Case<Float80Type>([&](Type) { os << "f80"; }) 1881 .Case<Float128Type>([&](Type) { os << "f128"; }) 1882 .Case<IntegerType>([&](IntegerType integerTy) { 1883 if (integerTy.isSigned()) 1884 os << 's'; 1885 else if (integerTy.isUnsigned()) 1886 os << 'u'; 1887 os << 'i' << integerTy.getWidth(); 1888 }) 1889 .Case<FunctionType>([&](FunctionType funcTy) { 1890 os << '('; 1891 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); 1892 os << ") -> "; 1893 ArrayRef<Type> results = funcTy.getResults(); 1894 if (results.size() == 1 && !results[0].isa<FunctionType>()) { 1895 os << results[0]; 1896 } else { 1897 os << '('; 1898 interleaveComma(results, [&](Type ty) { printType(ty); }); 1899 os << ')'; 1900 } 1901 }) 1902 .Case<VectorType>([&](VectorType vectorTy) { 1903 os << "vector<"; 1904 for (int64_t dim : vectorTy.getShape()) 1905 os << dim << 'x'; 1906 os << vectorTy.getElementType() << '>'; 1907 }) 1908 .Case<RankedTensorType>([&](RankedTensorType tensorTy) { 1909 os << "tensor<"; 1910 for (int64_t dim : tensorTy.getShape()) { 1911 if (ShapedType::isDynamic(dim)) 1912 os << '?'; 1913 else 1914 os << dim; 1915 os << 'x'; 1916 } 1917 os << tensorTy.getElementType(); 1918 // Only print the encoding attribute value if set. 1919 if (tensorTy.getEncoding()) { 1920 os << ", "; 1921 printAttribute(tensorTy.getEncoding()); 1922 } 1923 os << '>'; 1924 }) 1925 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { 1926 os << "tensor<*x"; 1927 printType(tensorTy.getElementType()); 1928 os << '>'; 1929 }) 1930 .Case<MemRefType>([&](MemRefType memrefTy) { 1931 os << "memref<"; 1932 for (int64_t dim : memrefTy.getShape()) { 1933 if (ShapedType::isDynamic(dim)) 1934 os << '?'; 1935 else 1936 os << dim; 1937 os << 'x'; 1938 } 1939 printType(memrefTy.getElementType()); 1940 for (auto map : memrefTy.getAffineMaps()) { 1941 os << ", "; 1942 printAttribute(AffineMapAttr::get(map)); 1943 } 1944 // Only print the memory space if it is the non-default one. 1945 if (memrefTy.getMemorySpace()) { 1946 os << ", "; 1947 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 1948 } 1949 os << '>'; 1950 }) 1951 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { 1952 os << "memref<*x"; 1953 printType(memrefTy.getElementType()); 1954 // Only print the memory space if it is the non-default one. 1955 if (memrefTy.getMemorySpace()) { 1956 os << ", "; 1957 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 1958 } 1959 os << '>'; 1960 }) 1961 .Case<ComplexType>([&](ComplexType complexTy) { 1962 os << "complex<"; 1963 printType(complexTy.getElementType()); 1964 os << '>'; 1965 }) 1966 .Case<TupleType>([&](TupleType tupleTy) { 1967 os << "tuple<"; 1968 interleaveComma(tupleTy.getTypes(), 1969 [&](Type type) { printType(type); }); 1970 os << '>'; 1971 }) 1972 .Case<NoneType>([&](Type) { os << "none"; }) 1973 .Default([&](Type type) { return printDialectType(type); }); 1974 } 1975 1976 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1977 ArrayRef<StringRef> elidedAttrs, 1978 bool withKeyword) { 1979 // If there are no attributes, then there is nothing to be done. 1980 if (attrs.empty()) 1981 return; 1982 1983 // Functor used to print a filtered attribute list. 1984 auto printFilteredAttributesFn = [&](auto filteredAttrs) { 1985 // Print the 'attributes' keyword if necessary. 1986 if (withKeyword) 1987 os << " attributes"; 1988 1989 // Otherwise, print them all out in braces. 1990 os << " {"; 1991 interleaveComma(filteredAttrs, 1992 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 1993 os << '}'; 1994 }; 1995 1996 // If no attributes are elided, we can directly print with no filtering. 1997 if (elidedAttrs.empty()) 1998 return printFilteredAttributesFn(attrs); 1999 2000 // Otherwise, filter out any attributes that shouldn't be included. 2001 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 2002 elidedAttrs.end()); 2003 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 2004 return !elidedAttrsSet.contains(attr.first.strref()); 2005 }); 2006 if (!filteredAttrs.empty()) 2007 printFilteredAttributesFn(filteredAttrs); 2008 } 2009 2010 void ModulePrinter::printNamedAttribute(NamedAttribute attr) { 2011 if (isBareIdentifier(attr.first)) { 2012 os << attr.first; 2013 } else { 2014 os << '"'; 2015 printEscapedString(attr.first.strref(), os); 2016 os << '"'; 2017 } 2018 2019 // Pretty printing elides the attribute value for unit attributes. 2020 if (attr.second.isa<UnitAttr>()) 2021 return; 2022 2023 os << " = "; 2024 printAttribute(attr.second); 2025 } 2026 2027 //===----------------------------------------------------------------------===// 2028 // CustomDialectAsmPrinter 2029 //===----------------------------------------------------------------------===// 2030 2031 namespace { 2032 /// This class provides the main specialization of the DialectAsmPrinter that is 2033 /// used to provide support for print attributes and types. This hooks allows 2034 /// for dialects to hook into the main ModulePrinter. 2035 struct CustomDialectAsmPrinter : public DialectAsmPrinter { 2036 public: 2037 CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} 2038 ~CustomDialectAsmPrinter() override {} 2039 2040 raw_ostream &getStream() const override { return printer.getStream(); } 2041 2042 /// Print the given attribute to the stream. 2043 void printAttribute(Attribute attr) override { printer.printAttribute(attr); } 2044 2045 /// Print the given floating point value in a stablized form. 2046 void printFloat(const APFloat &value) override { 2047 printFloatValue(value, getStream()); 2048 } 2049 2050 /// Print the given type to the stream. 2051 void printType(Type type) override { printer.printType(type); } 2052 2053 /// The main module printer. 2054 ModulePrinter &printer; 2055 }; 2056 } // end anonymous namespace 2057 2058 void ModulePrinter::printDialectAttribute(Attribute attr) { 2059 auto &dialect = attr.getDialect(); 2060 2061 // Ask the dialect to serialize the attribute to a string. 2062 std::string attrName; 2063 { 2064 llvm::raw_string_ostream attrNameStr(attrName); 2065 ModulePrinter subPrinter(attrNameStr, printerFlags, state); 2066 CustomDialectAsmPrinter printer(subPrinter); 2067 dialect.printAttribute(attr, printer); 2068 } 2069 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 2070 } 2071 2072 void ModulePrinter::printDialectType(Type type) { 2073 auto &dialect = type.getDialect(); 2074 2075 // Ask the dialect to serialize the type to a string. 2076 std::string typeName; 2077 { 2078 llvm::raw_string_ostream typeNameStr(typeName); 2079 ModulePrinter subPrinter(typeNameStr, printerFlags, state); 2080 CustomDialectAsmPrinter printer(subPrinter); 2081 dialect.printType(type, printer); 2082 } 2083 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 2084 } 2085 2086 //===----------------------------------------------------------------------===// 2087 // Affine expressions and maps 2088 //===----------------------------------------------------------------------===// 2089 2090 void ModulePrinter::printAffineExpr( 2091 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2092 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2093 } 2094 2095 void ModulePrinter::printAffineExprInternal( 2096 AffineExpr expr, BindingStrength enclosingTightness, 2097 function_ref<void(unsigned, bool)> printValueName) { 2098 const char *binopSpelling = nullptr; 2099 switch (expr.getKind()) { 2100 case AffineExprKind::SymbolId: { 2101 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 2102 if (printValueName) 2103 printValueName(pos, /*isSymbol=*/true); 2104 else 2105 os << 's' << pos; 2106 return; 2107 } 2108 case AffineExprKind::DimId: { 2109 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 2110 if (printValueName) 2111 printValueName(pos, /*isSymbol=*/false); 2112 else 2113 os << 'd' << pos; 2114 return; 2115 } 2116 case AffineExprKind::Constant: 2117 os << expr.cast<AffineConstantExpr>().getValue(); 2118 return; 2119 case AffineExprKind::Add: 2120 binopSpelling = " + "; 2121 break; 2122 case AffineExprKind::Mul: 2123 binopSpelling = " * "; 2124 break; 2125 case AffineExprKind::FloorDiv: 2126 binopSpelling = " floordiv "; 2127 break; 2128 case AffineExprKind::CeilDiv: 2129 binopSpelling = " ceildiv "; 2130 break; 2131 case AffineExprKind::Mod: 2132 binopSpelling = " mod "; 2133 break; 2134 } 2135 2136 auto binOp = expr.cast<AffineBinaryOpExpr>(); 2137 AffineExpr lhsExpr = binOp.getLHS(); 2138 AffineExpr rhsExpr = binOp.getRHS(); 2139 2140 // Handle tightly binding binary operators. 2141 if (binOp.getKind() != AffineExprKind::Add) { 2142 if (enclosingTightness == BindingStrength::Strong) 2143 os << '('; 2144 2145 // Pretty print multiplication with -1. 2146 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2147 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2148 rhsConst.getValue() == -1) { 2149 os << "-"; 2150 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2151 if (enclosingTightness == BindingStrength::Strong) 2152 os << ')'; 2153 return; 2154 } 2155 2156 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2157 2158 os << binopSpelling; 2159 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2160 2161 if (enclosingTightness == BindingStrength::Strong) 2162 os << ')'; 2163 return; 2164 } 2165 2166 // Print out special "pretty" forms for add. 2167 if (enclosingTightness == BindingStrength::Strong) 2168 os << '('; 2169 2170 // Pretty print addition to a product that has a negative operand as a 2171 // subtraction. 2172 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2173 if (rhs.getKind() == AffineExprKind::Mul) { 2174 AffineExpr rrhsExpr = rhs.getRHS(); 2175 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2176 if (rrhs.getValue() == -1) { 2177 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2178 printValueName); 2179 os << " - "; 2180 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2181 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2182 printValueName); 2183 } else { 2184 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2185 printValueName); 2186 } 2187 2188 if (enclosingTightness == BindingStrength::Strong) 2189 os << ')'; 2190 return; 2191 } 2192 2193 if (rrhs.getValue() < -1) { 2194 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2195 printValueName); 2196 os << " - "; 2197 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2198 printValueName); 2199 os << " * " << -rrhs.getValue(); 2200 if (enclosingTightness == BindingStrength::Strong) 2201 os << ')'; 2202 return; 2203 } 2204 } 2205 } 2206 } 2207 2208 // Pretty print addition to a negative number as a subtraction. 2209 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2210 if (rhsConst.getValue() < 0) { 2211 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2212 os << " - " << -rhsConst.getValue(); 2213 if (enclosingTightness == BindingStrength::Strong) 2214 os << ')'; 2215 return; 2216 } 2217 } 2218 2219 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2220 2221 os << " + "; 2222 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2223 2224 if (enclosingTightness == BindingStrength::Strong) 2225 os << ')'; 2226 } 2227 2228 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { 2229 printAffineExprInternal(expr, BindingStrength::Weak); 2230 isEq ? os << " == 0" : os << " >= 0"; 2231 } 2232 2233 void ModulePrinter::printAffineMap(AffineMap map) { 2234 // Dimension identifiers. 2235 os << '('; 2236 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2237 os << 'd' << i << ", "; 2238 if (map.getNumDims() >= 1) 2239 os << 'd' << map.getNumDims() - 1; 2240 os << ')'; 2241 2242 // Symbolic identifiers. 2243 if (map.getNumSymbols() != 0) { 2244 os << '['; 2245 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2246 os << 's' << i << ", "; 2247 if (map.getNumSymbols() >= 1) 2248 os << 's' << map.getNumSymbols() - 1; 2249 os << ']'; 2250 } 2251 2252 // Result affine expressions. 2253 os << " -> ("; 2254 interleaveComma(map.getResults(), 2255 [&](AffineExpr expr) { printAffineExpr(expr); }); 2256 os << ')'; 2257 } 2258 2259 void ModulePrinter::printIntegerSet(IntegerSet set) { 2260 // Dimension identifiers. 2261 os << '('; 2262 for (unsigned i = 1; i < set.getNumDims(); ++i) 2263 os << 'd' << i - 1 << ", "; 2264 if (set.getNumDims() >= 1) 2265 os << 'd' << set.getNumDims() - 1; 2266 os << ')'; 2267 2268 // Symbolic identifiers. 2269 if (set.getNumSymbols() != 0) { 2270 os << '['; 2271 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2272 os << 's' << i << ", "; 2273 if (set.getNumSymbols() >= 1) 2274 os << 's' << set.getNumSymbols() - 1; 2275 os << ']'; 2276 } 2277 2278 // Print constraints. 2279 os << " : ("; 2280 int numConstraints = set.getNumConstraints(); 2281 for (int i = 1; i < numConstraints; ++i) { 2282 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2283 os << ", "; 2284 } 2285 if (numConstraints >= 1) 2286 printAffineConstraint(set.getConstraint(numConstraints - 1), 2287 set.isEq(numConstraints - 1)); 2288 os << ')'; 2289 } 2290 2291 //===----------------------------------------------------------------------===// 2292 // OperationPrinter 2293 //===----------------------------------------------------------------------===// 2294 2295 namespace { 2296 /// This class contains the logic for printing operations, regions, and blocks. 2297 class OperationPrinter : public ModulePrinter, private OpAsmPrinter { 2298 public: 2299 explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, 2300 AsmStateImpl &state) 2301 : ModulePrinter(os, flags, &state) {} 2302 2303 /// Print the given top-level operation. 2304 void printTopLevelOperation(Operation *op); 2305 2306 /// Print the given operation with its indent and location. 2307 void print(Operation *op); 2308 /// Print the bare location, not including indentation/location/etc. 2309 void printOperation(Operation *op); 2310 /// Print the given operation in the generic form. 2311 void printGenericOp(Operation *op) override; 2312 2313 /// Print the name of the given block. 2314 void printBlockName(Block *block); 2315 2316 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2317 /// block are not printed. If 'printBlockTerminator' is false, the terminator 2318 /// operation of the block is not printed. 2319 void print(Block *block, bool printBlockArgs = true, 2320 bool printBlockTerminator = true); 2321 2322 /// Print the ID of the given value, optionally with its result number. 2323 void printValueID(Value value, bool printResultNo = true, 2324 raw_ostream *streamOverride = nullptr) const; 2325 2326 //===--------------------------------------------------------------------===// 2327 // OpAsmPrinter methods 2328 //===--------------------------------------------------------------------===// 2329 2330 /// Return the current stream of the printer. 2331 raw_ostream &getStream() const override { return os; } 2332 2333 /// Print a newline and indent the printer to the start of the current 2334 /// operation. 2335 void printNewline() override { 2336 os << newLine; 2337 os.indent(currentIndent); 2338 } 2339 2340 /// Print the given type. 2341 void printType(Type type) override { ModulePrinter::printType(type); } 2342 2343 /// Print the given attribute. 2344 void printAttribute(Attribute attr) override { 2345 ModulePrinter::printAttribute(attr); 2346 } 2347 2348 /// Print the given attribute without its type. The corresponding parser must 2349 /// provide a valid type for the attribute. 2350 void printAttributeWithoutType(Attribute attr) override { 2351 ModulePrinter::printAttribute(attr, AttrTypeElision::Must); 2352 } 2353 2354 /// Print a block argument in the usual format of: 2355 /// %ssaName : type {attr1=42} loc("here") 2356 /// where location printing is controlled by the standard internal option. 2357 /// You may pass omitType=true to not print a type, and pass an empty 2358 /// attribute list if you don't care for attributes. 2359 void printRegionArgument(BlockArgument arg, 2360 ArrayRef<NamedAttribute> argAttrs = {}, 2361 bool omitType = false) override; 2362 2363 /// Print the ID for the given value. 2364 void printOperand(Value value) override { printValueID(value); } 2365 void printOperand(Value value, raw_ostream &os) override { 2366 printValueID(value, /*printResultNo=*/true, &os); 2367 } 2368 2369 /// Print an optional attribute dictionary with a given set of elided values. 2370 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2371 ArrayRef<StringRef> elidedAttrs = {}) override { 2372 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); 2373 } 2374 void printOptionalAttrDictWithKeyword( 2375 ArrayRef<NamedAttribute> attrs, 2376 ArrayRef<StringRef> elidedAttrs = {}) override { 2377 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, 2378 /*withKeyword=*/true); 2379 } 2380 2381 /// Print the given successor. 2382 void printSuccessor(Block *successor) override; 2383 2384 /// Print an operation successor with the operands used for the block 2385 /// arguments. 2386 void printSuccessorAndUseList(Block *successor, 2387 ValueRange succOperands) override; 2388 2389 /// Print the given region. 2390 void printRegion(Region ®ion, bool printEntryBlockArgs, 2391 bool printBlockTerminators, bool printEmptyBlock) override; 2392 2393 /// Renumber the arguments for the specified region to the same names as the 2394 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2395 /// operations. If any entry in namesToUse is null, the corresponding 2396 /// argument name is left alone. 2397 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2398 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2399 } 2400 2401 /// Print the given affine map with the symbol and dimension operands printed 2402 /// inline with the map. 2403 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2404 ValueRange operands) override; 2405 2406 /// Print the given affine expression with the symbol and dimension operands 2407 /// printed inline with the expression. 2408 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 2409 ValueRange symOperands) override; 2410 2411 /// Print the given string as a symbol reference. 2412 void printSymbolName(StringRef symbolRef) override { 2413 ::printSymbolReference(symbolRef, os); 2414 } 2415 2416 private: 2417 /// The number of spaces used for indenting nested operations. 2418 const static unsigned indentWidth = 2; 2419 2420 // This is the current indentation level for nested structures. 2421 unsigned currentIndent = 0; 2422 }; 2423 } // end anonymous namespace 2424 2425 void OperationPrinter::printTopLevelOperation(Operation *op) { 2426 // Output the aliases at the top level that can't be deferred. 2427 state->getAliasState().printNonDeferredAliases(os, newLine); 2428 2429 // Print the module. 2430 print(op); 2431 os << newLine; 2432 2433 // Output the aliases at the top level that can be deferred. 2434 state->getAliasState().printDeferredAliases(os, newLine); 2435 } 2436 2437 /// Print a block argument in the usual format of: 2438 /// %ssaName : type {attr1=42} loc("here") 2439 /// where location printing is controlled by the standard internal option. 2440 /// You may pass omitType=true to not print a type, and pass an empty 2441 /// attribute list if you don't care for attributes. 2442 void OperationPrinter::printRegionArgument(BlockArgument arg, 2443 ArrayRef<NamedAttribute> argAttrs, 2444 bool omitType) { 2445 printOperand(arg); 2446 if (!omitType) { 2447 os << ": "; 2448 printType(arg.getType()); 2449 } 2450 printOptionalAttrDict(argAttrs); 2451 // TODO: We should allow location aliases on block arguments. 2452 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2453 } 2454 2455 void OperationPrinter::print(Operation *op) { 2456 // Track the location of this operation. 2457 state->registerOperationLocation(op, newLine.curLine, currentIndent); 2458 2459 os.indent(currentIndent); 2460 printOperation(op); 2461 printTrailingLocation(op->getLoc()); 2462 } 2463 2464 void OperationPrinter::printOperation(Operation *op) { 2465 if (size_t numResults = op->getNumResults()) { 2466 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 2467 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 2468 if (resultCount > 1) 2469 os << ':' << resultCount; 2470 }; 2471 2472 // Check to see if this operation has multiple result groups. 2473 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); 2474 if (!resultGroups.empty()) { 2475 // Interleave the groups excluding the last one, this one will be handled 2476 // separately. 2477 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 2478 printResultGroup(resultGroups[i], 2479 resultGroups[i + 1] - resultGroups[i]); 2480 }); 2481 os << ", "; 2482 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 2483 2484 } else { 2485 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 2486 } 2487 2488 os << " = "; 2489 } 2490 2491 // If requested, always print the generic form. 2492 if (!printerFlags.shouldPrintGenericOpForm()) { 2493 // Check to see if this is a known operation. If so, use the registered 2494 // custom printer hook. 2495 if (auto *opInfo = op->getAbstractOperation()) { 2496 opInfo->printAssembly(op, *this); 2497 return; 2498 } 2499 // Otherwise try to dispatch to the dialect, if available. 2500 if (Dialect *dialect = op->getDialect()) { 2501 if (succeeded(dialect->printOperation(op, *this))) 2502 return; 2503 } 2504 } 2505 2506 // Otherwise print with the generic assembly form. 2507 printGenericOp(op); 2508 } 2509 2510 void OperationPrinter::printGenericOp(Operation *op) { 2511 os << '"'; 2512 printEscapedString(op->getName().getStringRef(), os); 2513 os << "\"("; 2514 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 2515 os << ')'; 2516 2517 // For terminators, print the list of successors and their operands. 2518 if (op->getNumSuccessors() != 0) { 2519 os << '['; 2520 interleaveComma(op->getSuccessors(), 2521 [&](Block *successor) { printBlockName(successor); }); 2522 os << ']'; 2523 } 2524 2525 // Print regions. 2526 if (op->getNumRegions() != 0) { 2527 os << " ("; 2528 interleaveComma(op->getRegions(), [&](Region ®ion) { 2529 printRegion(region, /*printEntryBlockArgs=*/true, 2530 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 2531 }); 2532 os << ')'; 2533 } 2534 2535 auto attrs = op->getAttrs(); 2536 printOptionalAttrDict(attrs); 2537 2538 // Print the type signature of the operation. 2539 os << " : "; 2540 printFunctionalType(op); 2541 } 2542 2543 void OperationPrinter::printBlockName(Block *block) { 2544 auto id = state->getSSANameState().getBlockID(block); 2545 if (id != SSANameState::NameSentinel) 2546 os << "^bb" << id; 2547 else 2548 os << "^INVALIDBLOCK"; 2549 } 2550 2551 void OperationPrinter::print(Block *block, bool printBlockArgs, 2552 bool printBlockTerminator) { 2553 // Print the block label and argument list if requested. 2554 if (printBlockArgs) { 2555 os.indent(currentIndent); 2556 printBlockName(block); 2557 2558 // Print the argument list if non-empty. 2559 if (!block->args_empty()) { 2560 os << '('; 2561 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 2562 printValueID(arg); 2563 os << ": "; 2564 printType(arg.getType()); 2565 // TODO: We should allow location aliases on block arguments. 2566 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2567 }); 2568 os << ')'; 2569 } 2570 os << ':'; 2571 2572 // Print out some context information about the predecessors of this block. 2573 if (!block->getParent()) { 2574 os << " // block is not in a region!"; 2575 } else if (block->hasNoPredecessors()) { 2576 os << " // no predecessors"; 2577 } else if (auto *pred = block->getSinglePredecessor()) { 2578 os << " // pred: "; 2579 printBlockName(pred); 2580 } else { 2581 // We want to print the predecessors in increasing numeric order, not in 2582 // whatever order the use-list is in, so gather and sort them. 2583 SmallVector<std::pair<unsigned, Block *>, 4> predIDs; 2584 for (auto *pred : block->getPredecessors()) 2585 predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); 2586 llvm::array_pod_sort(predIDs.begin(), predIDs.end()); 2587 2588 os << " // " << predIDs.size() << " preds: "; 2589 2590 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { 2591 printBlockName(pred.second); 2592 }); 2593 } 2594 os << newLine; 2595 } 2596 2597 currentIndent += indentWidth; 2598 bool hasTerminator = 2599 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 2600 auto range = llvm::make_range( 2601 block->begin(), 2602 std::prev(block->end(), 2603 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 2604 for (auto &op : range) { 2605 print(&op); 2606 os << newLine; 2607 } 2608 currentIndent -= indentWidth; 2609 } 2610 2611 void OperationPrinter::printValueID(Value value, bool printResultNo, 2612 raw_ostream *streamOverride) const { 2613 state->getSSANameState().printValueID(value, printResultNo, 2614 streamOverride ? *streamOverride : os); 2615 } 2616 2617 void OperationPrinter::printSuccessor(Block *successor) { 2618 printBlockName(successor); 2619 } 2620 2621 void OperationPrinter::printSuccessorAndUseList(Block *successor, 2622 ValueRange succOperands) { 2623 printBlockName(successor); 2624 if (succOperands.empty()) 2625 return; 2626 2627 os << '('; 2628 interleaveComma(succOperands, 2629 [this](Value operand) { printValueID(operand); }); 2630 os << " : "; 2631 interleaveComma(succOperands, 2632 [this](Value operand) { printType(operand.getType()); }); 2633 os << ')'; 2634 } 2635 2636 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 2637 bool printBlockTerminators, 2638 bool printEmptyBlock) { 2639 os << " {" << newLine; 2640 if (!region.empty()) { 2641 auto *entryBlock = ®ion.front(); 2642 // Force printing the block header if printEmptyBlock is set and the block 2643 // is empty or if printEntryBlockArgs is set and there are arguments to 2644 // print. 2645 bool shouldAlwaysPrintBlockHeader = 2646 (printEmptyBlock && entryBlock->empty()) || 2647 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 2648 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 2649 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 2650 print(&b); 2651 } 2652 os.indent(currentIndent) << "}"; 2653 } 2654 2655 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2656 ValueRange operands) { 2657 AffineMap map = mapAttr.getValue(); 2658 unsigned numDims = map.getNumDims(); 2659 auto printValueName = [&](unsigned pos, bool isSymbol) { 2660 unsigned index = isSymbol ? numDims + pos : pos; 2661 assert(index < operands.size()); 2662 if (isSymbol) 2663 os << "symbol("; 2664 printValueID(operands[index]); 2665 if (isSymbol) 2666 os << ')'; 2667 }; 2668 2669 interleaveComma(map.getResults(), [&](AffineExpr expr) { 2670 printAffineExpr(expr, printValueName); 2671 }); 2672 } 2673 2674 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 2675 ValueRange dimOperands, 2676 ValueRange symOperands) { 2677 auto printValueName = [&](unsigned pos, bool isSymbol) { 2678 if (!isSymbol) 2679 return printValueID(dimOperands[pos]); 2680 os << "symbol("; 2681 printValueID(symOperands[pos]); 2682 os << ')'; 2683 }; 2684 printAffineExpr(expr, printValueName); 2685 } 2686 2687 //===----------------------------------------------------------------------===// 2688 // print and dump methods 2689 //===----------------------------------------------------------------------===// 2690 2691 void Attribute::print(raw_ostream &os) const { 2692 ModulePrinter(os).printAttribute(*this); 2693 } 2694 2695 void Attribute::dump() const { 2696 print(llvm::errs()); 2697 llvm::errs() << "\n"; 2698 } 2699 2700 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); } 2701 2702 void Type::dump() { print(llvm::errs()); } 2703 2704 void AffineMap::dump() const { 2705 print(llvm::errs()); 2706 llvm::errs() << "\n"; 2707 } 2708 2709 void IntegerSet::dump() const { 2710 print(llvm::errs()); 2711 llvm::errs() << "\n"; 2712 } 2713 2714 void AffineExpr::print(raw_ostream &os) const { 2715 if (!expr) { 2716 os << "<<NULL AFFINE EXPR>>"; 2717 return; 2718 } 2719 ModulePrinter(os).printAffineExpr(*this); 2720 } 2721 2722 void AffineExpr::dump() const { 2723 print(llvm::errs()); 2724 llvm::errs() << "\n"; 2725 } 2726 2727 void AffineMap::print(raw_ostream &os) const { 2728 if (!map) { 2729 os << "<<NULL AFFINE MAP>>"; 2730 return; 2731 } 2732 ModulePrinter(os).printAffineMap(*this); 2733 } 2734 2735 void IntegerSet::print(raw_ostream &os) const { 2736 ModulePrinter(os).printIntegerSet(*this); 2737 } 2738 2739 void Value::print(raw_ostream &os) { 2740 if (auto *op = getDefiningOp()) 2741 return op->print(os); 2742 // TODO: Improve BlockArgument print'ing. 2743 BlockArgument arg = this->cast<BlockArgument>(); 2744 os << "<block argument> of type '" << arg.getType() 2745 << "' at index: " << arg.getArgNumber() << '\n'; 2746 } 2747 void Value::print(raw_ostream &os, AsmState &state) { 2748 if (auto *op = getDefiningOp()) 2749 return op->print(os, state); 2750 2751 // TODO: Improve BlockArgument print'ing. 2752 BlockArgument arg = this->cast<BlockArgument>(); 2753 os << "<block argument> of type '" << arg.getType() 2754 << "' at index: " << arg.getArgNumber() << '\n'; 2755 } 2756 2757 void Value::dump() { 2758 print(llvm::errs()); 2759 llvm::errs() << "\n"; 2760 } 2761 2762 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 2763 // TODO: This doesn't necessarily capture all potential cases. 2764 // Currently, region arguments can be shadowed when printing the main 2765 // operation. If the IR hasn't been printed, this will produce the old SSA 2766 // name and not the shadowed name. 2767 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 2768 os); 2769 } 2770 2771 void Operation::print(raw_ostream &os, OpPrintingFlags flags) { 2772 // If this is a top level operation, we also print aliases. 2773 if (!getParent() && !flags.shouldUseLocalScope()) { 2774 AsmState state(this); 2775 state.getImpl().initializeAliases(this, flags); 2776 print(os, state, flags); 2777 return; 2778 } 2779 2780 // Find the operation to number from based upon the provided flags. 2781 Operation *op = this; 2782 bool shouldUseLocalScope = flags.shouldUseLocalScope(); 2783 do { 2784 // If we are printing local scope, stop at the first operation that is 2785 // isolated from above. 2786 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 2787 break; 2788 2789 // Otherwise, traverse up to the next parent. 2790 Operation *parentOp = op->getParentOp(); 2791 if (!parentOp) 2792 break; 2793 op = parentOp; 2794 } while (true); 2795 2796 AsmState state(op); 2797 print(os, state, flags); 2798 } 2799 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { 2800 OperationPrinter printer(os, flags, state.getImpl()); 2801 if (!getParent() && !flags.shouldUseLocalScope()) 2802 printer.printTopLevelOperation(this); 2803 else 2804 printer.print(this); 2805 } 2806 2807 void Operation::dump() { 2808 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 2809 llvm::errs() << "\n"; 2810 } 2811 2812 void Block::print(raw_ostream &os) { 2813 Operation *parentOp = getParentOp(); 2814 if (!parentOp) { 2815 os << "<<UNLINKED BLOCK>>\n"; 2816 return; 2817 } 2818 // Get the top-level op. 2819 while (auto *nextOp = parentOp->getParentOp()) 2820 parentOp = nextOp; 2821 2822 AsmState state(parentOp); 2823 print(os, state); 2824 } 2825 void Block::print(raw_ostream &os, AsmState &state) { 2826 OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); 2827 } 2828 2829 void Block::dump() { print(llvm::errs()); } 2830 2831 /// Print out the name of the block without printing its body. 2832 void Block::printAsOperand(raw_ostream &os, bool printType) { 2833 Operation *parentOp = getParentOp(); 2834 if (!parentOp) { 2835 os << "<<UNLINKED BLOCK>>\n"; 2836 return; 2837 } 2838 AsmState state(parentOp); 2839 printAsOperand(os, state); 2840 } 2841 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 2842 OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); 2843 printer.printBlockName(this); 2844 } 2845