1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/IR/SubElementInterfaces.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/ADT/MapVector.h" 31 #include "llvm/ADT/STLExtras.h" 32 #include "llvm/ADT/ScopeExit.h" 33 #include "llvm/ADT/ScopedHashTable.h" 34 #include "llvm/ADT/SetVector.h" 35 #include "llvm/ADT/SmallString.h" 36 #include "llvm/ADT/StringExtras.h" 37 #include "llvm/ADT/StringSet.h" 38 #include "llvm/ADT/TypeSwitch.h" 39 #include "llvm/Support/CommandLine.h" 40 #include "llvm/Support/Endian.h" 41 #include "llvm/Support/Regex.h" 42 #include "llvm/Support/SaveAndRestore.h" 43 44 #include <tuple> 45 46 using namespace mlir; 47 using namespace mlir::detail; 48 49 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 50 51 void OperationName::dump() const { print(llvm::errs()); } 52 53 //===--------------------------------------------------------------------===// 
54 // AsmParser 55 //===--------------------------------------------------------------------===// 56 57 AsmParser::~AsmParser() {} 58 DialectAsmParser::~DialectAsmParser() {} 59 OpAsmParser::~OpAsmParser() {} 60 61 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } 62 63 //===----------------------------------------------------------------------===// 64 // DialectAsmPrinter 65 //===----------------------------------------------------------------------===// 66 67 DialectAsmPrinter::~DialectAsmPrinter() {} 68 69 //===----------------------------------------------------------------------===// 70 // OpAsmPrinter 71 //===----------------------------------------------------------------------===// 72 73 OpAsmPrinter::~OpAsmPrinter() {} 74 75 void OpAsmPrinter::printFunctionalType(Operation *op) { 76 auto &os = getStream(); 77 os << '('; 78 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 79 // Print the types of null values as <<NULL TYPE>>. 80 *this << (operand ? operand.getType() : Type()); 81 }); 82 os << ") -> "; 83 84 // Print the result list. We don't parenthesize single result types unless 85 // it is a function (avoiding a grammar ambiguity). 86 bool wrapped = op->getNumResults() != 1; 87 if (!wrapped && op->getResult(0).getType() && 88 op->getResult(0).getType().isa<FunctionType>()) 89 wrapped = true; 90 91 if (wrapped) 92 os << '('; 93 94 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 95 // Print the types of null values as <<NULL TYPE>>. 96 *this << (result ? result.getType() : Type()); 97 }); 98 99 if (wrapped) 100 os << ')'; 101 } 102 103 //===----------------------------------------------------------------------===// 104 // Operation OpAsm interface. 105 //===----------------------------------------------------------------------===// 106 107 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 
108 #include "mlir/IR/OpAsmInterface.cpp.inc" 109 110 //===----------------------------------------------------------------------===// 111 // OpPrintingFlags 112 //===----------------------------------------------------------------------===// 113 114 namespace { 115 /// This struct contains command line options that can be used to initialize 116 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 117 /// for global command line options. 118 struct AsmPrinterOptions { 119 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 120 "mlir-print-elementsattrs-with-hex-if-larger", 121 llvm::cl::desc( 122 "Print DenseElementsAttrs with a hex string that have " 123 "more elements than the given upper limit (use -1 to disable)")}; 124 125 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 126 "mlir-elide-elementsattrs-if-larger", 127 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 128 "more elements than the given upper limit")}; 129 130 llvm::cl::opt<bool> printDebugInfoOpt{ 131 "mlir-print-debuginfo", llvm::cl::init(false), 132 llvm::cl::desc("Print debug info in MLIR output")}; 133 134 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 135 "mlir-pretty-debuginfo", llvm::cl::init(false), 136 llvm::cl::desc("Print pretty debug info in MLIR output")}; 137 138 // Use the generic op output form in the operation printer even if the custom 139 // form is defined. 
140 llvm::cl::opt<bool> printGenericOpFormOpt{ 141 "mlir-print-op-generic", llvm::cl::init(false), 142 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 143 144 llvm::cl::opt<bool> printLocalScopeOpt{ 145 "mlir-print-local-scope", llvm::cl::init(false), 146 llvm::cl::desc("Print assuming in local scope by default"), 147 llvm::cl::Hidden}; 148 }; 149 } // end anonymous namespace 150 151 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 152 153 /// Register a set of useful command-line options that can be used to configure 154 /// various flags within the AsmPrinter. 155 void mlir::registerAsmPrinterCLOptions() { 156 // Make sure that the options struct has been initialized. 157 *clOptions; 158 } 159 160 /// Initialize the printing flags with default supplied by the cl::opts above. 161 OpPrintingFlags::OpPrintingFlags() 162 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 163 printGenericOpFormFlag(false), printLocalScope(false) { 164 // Initialize based upon command line options, if they are available. 165 if (!clOptions.isConstructed()) 166 return; 167 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 168 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 169 printDebugInfoFlag = clOptions->printDebugInfoOpt; 170 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 171 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 172 printLocalScope = clOptions->printLocalScopeOpt; 173 } 174 175 /// Enable the elision of large elements attributes, by printing a '...' 176 /// instead of the element data, when the number of elements is greater than 177 /// `largeElementLimit`. Note: The IR generated with this option is not 178 /// parsable. 179 OpPrintingFlags & 180 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 181 elementsAttrElementLimit = largeElementLimit; 182 return *this; 183 } 184 185 /// Enable printing of debug information. 
If 'prettyForm' is set to true, 186 /// debug information is printed in a more readable 'pretty' form. 187 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 188 printDebugInfoFlag = true; 189 printDebugInfoPrettyFormFlag = prettyForm; 190 return *this; 191 } 192 193 /// Always print operations in the generic form. 194 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 195 printGenericOpFormFlag = true; 196 return *this; 197 } 198 199 /// Use local scope when printing the operation. This allows for using the 200 /// printer in a more localized and thread-safe setting, but may not necessarily 201 /// be identical of what the IR will look like when dumping the full module. 202 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 203 printLocalScope = true; 204 return *this; 205 } 206 207 /// Return if the given ElementsAttr should be elided. 208 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 209 return elementsAttrElementLimit.hasValue() && 210 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 211 !attr.isa<SplatElementsAttr>(); 212 } 213 214 /// Return the size limit for printing large ElementsAttr. 215 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 216 return elementsAttrElementLimit; 217 } 218 219 /// Return if debug information should be printed. 220 bool OpPrintingFlags::shouldPrintDebugInfo() const { 221 return printDebugInfoFlag; 222 } 223 224 /// Return if debug information should be printed in the pretty form. 225 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 226 return printDebugInfoPrettyFormFlag; 227 } 228 229 /// Return if operations should be printed in the generic form. 230 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 231 return printGenericOpFormFlag; 232 } 233 234 /// Return if the printer should use local scope when dumping the IR. 
235 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 236 237 /// Returns true if an ElementsAttr with the given number of elements should be 238 /// printed with hex. 239 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 240 // Check to see if a command line option was provided for the limit. 241 if (clOptions.isConstructed()) { 242 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 243 // -1 is used to disable hex printing. 244 if (clOptions->printElementsAttrWithHexIfLarger == -1) 245 return false; 246 return numElements > clOptions->printElementsAttrWithHexIfLarger; 247 } 248 } 249 250 // Otherwise, default to printing with hex if the number of elements is >100. 251 return numElements > 100; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // NewLineCounter 256 //===----------------------------------------------------------------------===// 257 258 namespace { 259 /// This class is a simple formatter that emits a new line when inputted into a 260 /// stream, that enables counting the number of newlines emitted. This class 261 /// should be used whenever emitting newlines in the printer. 262 struct NewLineCounter { 263 unsigned curLine = 1; 264 }; 265 266 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 267 ++newLine.curLine; 268 return os << '\n'; 269 } 270 } // end anonymous namespace 271 272 //===----------------------------------------------------------------------===// 273 // AliasInitializer 274 //===----------------------------------------------------------------------===// 275 276 namespace { 277 /// This class represents a specific instance of a symbol Alias. 
class SymbolAlias {
public:
  /// Create an alias that prints as just `name` (no numeric suffix).
  SymbolAlias(StringRef name, bool isDeferrable)
      : name(name), suffixIndex(0), hasSuffixIndex(false),
        isDeferrable(isDeferrable) {}
  /// Create an alias that prints as `name` immediately followed by
  /// `suffixIndex`. NOTE: the index is stored in a 30-bit bitfield, so values
  /// that do not fit are truncated.
  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
      : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
        isDeferrable(isDeferrable) {}

  /// Print this alias to the given stream.
  void print(raw_ostream &os) const {
    os << name;
    if (hasSuffixIndex)
      os << suffixIndex;
  }

  /// Returns true if this alias supports deferred resolution when parsing.
  bool canBeDeferred() const { return isDeferrable; }

private:
  /// The main name of the alias.
  StringRef name;
  /// The optional suffix index of the alias, if multiple aliases had the same
  /// name.
  uint32_t suffixIndex : 30;
  /// A flag indicating whether this alias has a suffix or not.
  bool hasSuffixIndex : 1;
  /// A flag indicating whether this alias may be deferred or not.
  bool isDeferrable : 1;
};

/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// `interfaces` are queried for alias names; `aliasAllocator` owns the
  /// storage of the generated alias name strings.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Walk `op` (as the printer would, honoring `printerFlags`), collect the
  /// attributes and types it references, and fill `attrToAlias` and
  /// `typeToAlias` with the aliases generated for them.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
  /// NOTE: `aliasOS` is constructed over `aliasBuffer` (see the constructor),
  /// and members are initialized in declaration order, so `aliasBuffer` must
  /// stay declared before `aliasOS`.
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};

/// This class implements a dummy OpAsmPrinter that doesn't print any output,
/// and merely collects the attributes and types that *would* be printed in a
/// normal print invocation so that we can generate proper aliases. This allows
/// for us to generate aliases only for the attributes and types that would be
/// in the output, and trims down unnecessary output.
370 class DummyAliasOperationPrinter : private OpAsmPrinter { 371 public: 372 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, 373 AliasInitializer &initializer) 374 : printerFlags(printerFlags), initializer(initializer) {} 375 376 /// Print the given operation. 377 void print(Operation *op) { 378 // Visit the operation location. 379 if (printerFlags.shouldPrintDebugInfo()) 380 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 381 382 // If requested, always print the generic form. 383 if (!printerFlags.shouldPrintGenericOpForm()) { 384 // Check to see if this is a known operation. If so, use the registered 385 // custom printer hook. 386 if (auto opInfo = op->getRegisteredInfo()) { 387 opInfo->printAssembly(op, *this, /*defaultDialect=*/""); 388 return; 389 } 390 } 391 392 // Otherwise print with the generic assembly form. 393 printGenericOp(op); 394 } 395 396 private: 397 /// Print the given operation in the generic form. 398 void printGenericOp(Operation *op, bool printOpName = true) override { 399 // Consider nested operations for aliases. 400 if (op->getNumRegions() != 0) { 401 for (Region ®ion : op->getRegions()) 402 printRegion(region, /*printEntryBlockArgs=*/true, 403 /*printBlockTerminators=*/true); 404 } 405 406 // Visit all the types used in the operation. 407 for (Type type : op->getOperandTypes()) 408 printType(type); 409 for (Type type : op->getResultTypes()) 410 printType(type); 411 412 // Consider the attributes of the operation for aliases. 413 for (const NamedAttribute &attr : op->getAttrs()) 414 printAttribute(attr.getValue()); 415 } 416 417 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 418 /// block are not printed. If 'printBlockTerminator' is false, the terminator 419 /// operation of the block is not printed. 
420 void print(Block *block, bool printBlockArgs = true, 421 bool printBlockTerminator = true) { 422 // Consider the types of the block arguments for aliases if 'printBlockArgs' 423 // is set to true. 424 if (printBlockArgs) { 425 for (BlockArgument arg : block->getArguments()) { 426 printType(arg.getType()); 427 428 // Visit the argument location. 429 if (printerFlags.shouldPrintDebugInfo()) 430 // TODO: Allow deferring argument locations. 431 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 432 } 433 } 434 435 // Consider the operations within this block, ignoring the terminator if 436 // requested. 437 bool hasTerminator = 438 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 439 auto range = llvm::make_range( 440 block->begin(), 441 std::prev(block->end(), 442 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 443 for (Operation &op : range) 444 print(&op); 445 } 446 447 /// Print the given region. 448 void printRegion(Region ®ion, bool printEntryBlockArgs, 449 bool printBlockTerminators, 450 bool printEmptyBlock = false) override { 451 if (region.empty()) 452 return; 453 454 auto *entryBlock = ®ion.front(); 455 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 456 for (Block &b : llvm::drop_begin(region, 1)) 457 print(&b); 458 } 459 460 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, 461 bool omitType) override { 462 printType(arg.getType()); 463 // Visit the argument location. 464 if (printerFlags.shouldPrintDebugInfo()) 465 // TODO: Allow deferring argument locations. 466 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); 467 } 468 469 /// Consider the given type to be printed for an alias. 470 void printType(Type type) override { initializer.visit(type); } 471 472 /// Consider the given attribute to be printed for an alias. 
473 void printAttribute(Attribute attr) override { initializer.visit(attr); } 474 void printAttributeWithoutType(Attribute attr) override { 475 printAttribute(attr); 476 } 477 478 /// Print the given set of attributes with names not included within 479 /// 'elidedAttrs'. 480 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 481 ArrayRef<StringRef> elidedAttrs = {}) override { 482 if (attrs.empty()) 483 return; 484 if (elidedAttrs.empty()) { 485 for (const NamedAttribute &attr : attrs) 486 printAttribute(attr.getValue()); 487 return; 488 } 489 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 490 elidedAttrs.end()); 491 for (const NamedAttribute &attr : attrs) 492 if (!elidedAttrsSet.contains(attr.getName().strref())) 493 printAttribute(attr.getValue()); 494 } 495 void printOptionalAttrDictWithKeyword( 496 ArrayRef<NamedAttribute> attrs, 497 ArrayRef<StringRef> elidedAttrs = {}) override { 498 printOptionalAttrDict(attrs, elidedAttrs); 499 } 500 501 /// Return a null stream as the output stream, this will ignore any data fed 502 /// to it. 503 raw_ostream &getStream() const override { return os; } 504 505 /// The following are hooks of `OpAsmPrinter` that are not necessary for 506 /// determining potential aliases. 507 void printFloat(const APFloat &value) override {} 508 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 509 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} 510 void printNewline() override {} 511 void printOperand(Value) override {} 512 void printOperand(Value, raw_ostream &os) override { 513 // Users expect the output string to have at least the prefixed % to signal 514 // a value name. To maintain this invariant, emit a name even if it is 515 // guaranteed to go unused. 
516 os << "%"; 517 } 518 void printKeywordOrString(StringRef) override {} 519 void printSymbolName(StringRef) override {} 520 void printSuccessor(Block *) override {} 521 void printSuccessorAndUseList(Block *, ValueRange) override {} 522 void shadowRegionArgs(Region &, ValueRange) override {} 523 524 /// The printer flags to use when determining potential aliases. 525 const OpPrintingFlags &printerFlags; 526 527 /// The initializer to use when identifying aliases. 528 AliasInitializer &initializer; 529 530 /// A dummy output stream. 531 mutable llvm::raw_null_ostream os; 532 }; 533 } // end anonymous namespace 534 535 /// Sanitize the given name such that it can be used as a valid identifier. If 536 /// the string needs to be modified in any way, the provided buffer is used to 537 /// store the new copy, 538 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 539 StringRef allowedPunctChars = "$._-", 540 bool allowTrailingDigit = true) { 541 assert(!name.empty() && "Shouldn't have an empty name here"); 542 543 auto copyNameToBuffer = [&] { 544 for (char ch : name) { 545 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) 546 buffer.push_back(ch); 547 else if (ch == ' ') 548 buffer.push_back('_'); 549 else 550 buffer.append(llvm::utohexstr((unsigned char)ch)); 551 } 552 }; 553 554 // Check to see if this name is valid. If it starts with a digit, then it 555 // could conflict with the autogenerated numeric ID's, so add an underscore 556 // prefix to avoid problems. 557 if (isdigit(name[0])) { 558 buffer.push_back('_'); 559 copyNameToBuffer(); 560 return buffer; 561 } 562 563 // If the name ends with a trailing digit, add a '_' to avoid potential 564 // conflicts with autogenerated ID's. 565 if (!allowTrailingDigit && isdigit(name.back())) { 566 copyNameToBuffer(); 567 buffer.push_back('_'); 568 return buffer; 569 } 570 571 // Check to see that the name consists of only valid identifier characters. 
572 for (char ch : name) { 573 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { 574 copyNameToBuffer(); 575 return buffer; 576 } 577 } 578 579 // If there are no invalid characters, return the original name. 580 return name; 581 } 582 583 /// Given a collection of aliases and symbols, initialize a mapping from a 584 /// symbol to a given alias. 585 template <typename T> 586 static void 587 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol, 588 llvm::MapVector<T, SymbolAlias> &symbolToAlias, 589 DenseSet<T> *deferrableAliases = nullptr) { 590 std::vector<std::pair<StringRef, std::vector<T>>> aliases = 591 aliasToSymbol.takeVector(); 592 llvm::array_pod_sort(aliases.begin(), aliases.end(), 593 [](const auto *lhs, const auto *rhs) { 594 return lhs->first.compare(rhs->first); 595 }); 596 597 for (auto &it : aliases) { 598 // If there is only one instance for this alias, use the name directly. 599 if (it.second.size() == 1) { 600 T symbol = it.second.front(); 601 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 602 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)}); 603 continue; 604 } 605 // Otherwise, add the index to the name. 606 for (int i = 0, e = it.second.size(); i < e; ++i) { 607 T symbol = it.second[i]; 608 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 609 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)}); 610 } 611 } 612 } 613 614 void AliasInitializer::initialize( 615 Operation *op, const OpPrintingFlags &printerFlags, 616 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 617 llvm::MapVector<Type, SymbolAlias> &typeToAlias) { 618 // Use a dummy printer when walking the IR so that we can collect the 619 // attributes/types that will actually be used during printing when 620 // considering aliases. 
621 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); 622 aliasPrinter.print(op); 623 624 // Initialize the aliases sorted by name. 625 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); 626 initializeAliases(aliasToType, typeToAlias); 627 } 628 629 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { 630 if (!visitedAttributes.insert(attr).second) { 631 // If this attribute already has an alias and this instance can't be 632 // deferred, make sure that the alias isn't deferred. 633 if (!canBeDeferred) 634 deferrableAttributes.erase(attr); 635 return; 636 } 637 638 // Try to generate an alias for this attribute. 639 if (succeeded(generateAlias(attr, aliasToAttr))) { 640 if (canBeDeferred) 641 deferrableAttributes.insert(attr); 642 return; 643 } 644 645 // Check for any sub elements. 646 if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) { 647 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 648 [&](Type type) { visit(type); }); 649 } 650 } 651 652 void AliasInitializer::visit(Type type) { 653 if (!visitedTypes.insert(type).second) 654 return; 655 656 // Try to generate an alias for this type. 657 if (succeeded(generateAlias(type, aliasToType))) 658 return; 659 660 // Check for any sub elements. 
661 if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) { 662 subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, 663 [&](Type type) { visit(type); }); 664 } 665 } 666 667 template <typename T> 668 LogicalResult AliasInitializer::generateAlias( 669 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) { 670 SmallString<32> nameBuffer; 671 for (const auto &interface : interfaces) { 672 OpAsmDialectInterface::AliasResult result = 673 interface.getAlias(symbol, aliasOS); 674 if (result == OpAsmDialectInterface::AliasResult::NoAlias) 675 continue; 676 nameBuffer = std::move(aliasBuffer); 677 assert(!nameBuffer.empty() && "expected valid alias name"); 678 if (result == OpAsmDialectInterface::AliasResult::FinalAlias) 679 break; 680 } 681 682 if (nameBuffer.empty()) 683 return failure(); 684 685 SmallString<16> tempBuffer; 686 StringRef name = 687 sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", 688 /*allowTrailingDigit=*/false); 689 name = name.copy(aliasAllocator); 690 aliasToSymbol[name].push_back(symbol); 691 return success(); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // AliasState 696 //===----------------------------------------------------------------------===// 697 698 namespace { 699 /// This class manages the state for type and attribute aliases. 700 class AliasState { 701 public: 702 // Initialize the internal aliases. 703 void 704 initialize(Operation *op, const OpPrintingFlags &printerFlags, 705 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 706 707 /// Get an alias for the given attribute if it has one and print it in `os`. 708 /// Returns success if an alias was printed, failure otherwise. 709 LogicalResult getAlias(Attribute attr, raw_ostream &os) const; 710 711 /// Get an alias for the given type if it has one and print it in `os`. 712 /// Returns success if an alias was printed, failure otherwise. 
713 LogicalResult getAlias(Type ty, raw_ostream &os) const; 714 715 /// Print all of the referenced aliases that can not be resolved in a deferred 716 /// manner. 717 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 718 printAliases(os, newLine, /*isDeferred=*/false); 719 } 720 721 /// Print all of the referenced aliases that support deferred resolution. 722 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 723 printAliases(os, newLine, /*isDeferred=*/true); 724 } 725 726 private: 727 /// Print all of the referenced aliases that support the provided resolution 728 /// behavior. 729 void printAliases(raw_ostream &os, NewLineCounter &newLine, 730 bool isDeferred) const; 731 732 /// Mapping between attribute and alias. 733 llvm::MapVector<Attribute, SymbolAlias> attrToAlias; 734 /// Mapping between type and alias. 735 llvm::MapVector<Type, SymbolAlias> typeToAlias; 736 737 /// An allocator used for alias names. 738 llvm::BumpPtrAllocator aliasAllocator; 739 }; 740 } // end anonymous namespace 741 742 void AliasState::initialize( 743 Operation *op, const OpPrintingFlags &printerFlags, 744 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 745 AliasInitializer initializer(interfaces, aliasAllocator); 746 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); 747 } 748 749 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { 750 auto it = attrToAlias.find(attr); 751 if (it == attrToAlias.end()) 752 return failure(); 753 it->second.print(os << '#'); 754 return success(); 755 } 756 757 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { 758 auto it = typeToAlias.find(ty); 759 if (it == typeToAlias.end()) 760 return failure(); 761 762 it->second.print(os << '!'); 763 return success(); 764 } 765 766 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine, 767 bool isDeferred) const { 768 auto filterFn = [=](const auto &aliasIt) { 769 
return aliasIt.second.canBeDeferred() == isDeferred; 770 }; 771 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) { 772 it.second.print(os << '#'); 773 os << " = " << it.first << newLine; 774 } 775 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) { 776 it.second.print(os << '!'); 777 os << " = type " << it.first << newLine; 778 } 779 } 780 781 //===----------------------------------------------------------------------===// 782 // SSANameState 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 /// This class manages the state of SSA value names. 787 class SSANameState { 788 public: 789 /// A sentinel value used for values with names set. 790 enum : unsigned { NameSentinel = ~0U }; 791 792 SSANameState(Operation *op, const OpPrintingFlags &printerFlags, 793 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 794 795 /// Print the SSA identifier for the given value to 'stream'. If 796 /// 'printResultNo' is true, it also presents the result number ('#' number) 797 /// of this value. 798 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; 799 800 /// Return the result indices for each of the result groups registered by this 801 /// operation, or empty if none exist. 802 ArrayRef<int> getOpResultGroups(Operation *op); 803 804 /// Get the ID for the given block. 805 unsigned getBlockID(Block *block); 806 807 /// Renumber the arguments for the specified region to the same names as the 808 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for 809 /// details. 810 void shadowRegionArgs(Region ®ion, ValueRange namesToUse); 811 812 private: 813 /// Number the SSA values within the given IR unit. 
814 void numberValuesInRegion(Region ®ion); 815 void numberValuesInBlock(Block &block); 816 void numberValuesInOp(Operation &op); 817 818 /// Given a result of an operation 'result', find the result group head 819 /// 'lookupValue' and the result of 'result' within that group in 820 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group 821 /// has more than 1 result. 822 void getResultIDAndNumber(OpResult result, Value &lookupValue, 823 Optional<int> &lookupResultNo) const; 824 825 /// Set a special value name for the given value. 826 void setValueName(Value value, StringRef name); 827 828 /// Uniques the given value name within the printer. If the given name 829 /// conflicts, it is automatically renamed. 830 StringRef uniqueValueName(StringRef name); 831 832 /// This is the value ID for each SSA value. If this returns NameSentinel, 833 /// then the valueID has an entry in valueNames. 834 DenseMap<Value, unsigned> valueIDs; 835 DenseMap<Value, StringRef> valueNames; 836 837 /// This is a map of operations that contain multiple named result groups, 838 /// i.e. there may be multiple names for the results of the operation. The 839 /// value of this map are the result numbers that start a result group. 840 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; 841 842 /// This is the block ID for each block in the current. 843 DenseMap<Block *, unsigned> blockIDs; 844 845 /// This keeps track of all of the non-numeric names that are in flight, 846 /// allowing us to check for duplicates. 847 /// Note: the value of the map is unused. 848 llvm::ScopedHashTable<StringRef, char> usedNames; 849 llvm::BumpPtrAllocator usedNameAllocator; 850 851 /// This is the next value ID to assign in numbering. 852 unsigned nextValueID = 0; 853 /// This is the next ID to assign to a region entry block argument. 854 unsigned nextArgumentID = 0; 855 /// This is the next ID to assign when a name conflict is detected. 
856 unsigned nextConflictID = 0; 857 858 /// These are the printing flags. They control, eg., whether to print in 859 /// generic form. 860 OpPrintingFlags printerFlags; 861 862 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; 863 }; 864 } // end anonymous namespace 865 866 SSANameState::SSANameState( 867 Operation *op, const OpPrintingFlags &printerFlags, 868 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) 869 : printerFlags(printerFlags), interfaces(interfaces) { 870 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); 871 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); 872 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); 873 874 // The naming context includes `nextValueID`, `nextArgumentID`, 875 // `nextConflictID` and `usedNames` scoped HashTable. This information is 876 // carried from the parent region. 877 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; 878 using NamingContext = 879 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; 880 881 // Allocator for UsedNamesScopeTy 882 llvm::BumpPtrAllocator allocator; 883 884 // Add a scope for the top level operation. 885 auto *topLevelNamesScope = 886 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); 887 888 SmallVector<NamingContext, 8> nameContext; 889 for (Region ®ion : op->getRegions()) 890 nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, 891 nextConflictID, topLevelNamesScope)); 892 893 numberValuesInOp(*op); 894 895 while (!nameContext.empty()) { 896 Region *region; 897 UsedNamesScopeTy *parentScope; 898 std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) = 899 nameContext.pop_back_val(); 900 901 // When we switch from one subtree to another, pop the scopes(needless) 902 // until the parent scope. 
903 while (usedNames.getCurScope() != parentScope) { 904 usedNames.getCurScope()->~UsedNamesScopeTy(); 905 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && 906 "top level parentScope must be a nullptr"); 907 } 908 909 // Add a scope for the current region. 910 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) 911 UsedNamesScopeTy(usedNames); 912 913 numberValuesInRegion(*region); 914 915 for (Operation &op : region->getOps()) 916 for (Region ®ion : op.getRegions()) 917 nameContext.push_back(std::make_tuple(®ion, nextValueID, 918 nextArgumentID, nextConflictID, 919 curNamesScope)); 920 } 921 922 // Manually remove all the scopes. 923 while (usedNames.getCurScope() != nullptr) 924 usedNames.getCurScope()->~UsedNamesScopeTy(); 925 } 926 927 void SSANameState::printValueID(Value value, bool printResultNo, 928 raw_ostream &stream) const { 929 if (!value) { 930 stream << "<<NULL>>"; 931 return; 932 } 933 934 Optional<int> resultNo; 935 auto lookupValue = value; 936 937 // If this is an operation result, collect the head lookup value of the result 938 // group and the result number of 'result' within that group. 939 if (OpResult result = value.dyn_cast<OpResult>()) 940 getResultIDAndNumber(result, lookupValue, resultNo); 941 942 auto it = valueIDs.find(lookupValue); 943 if (it == valueIDs.end()) { 944 stream << "<<UNKNOWN SSA VALUE>>"; 945 return; 946 } 947 948 stream << '%'; 949 if (it->second != NameSentinel) { 950 stream << it->second; 951 } else { 952 auto nameIt = valueNames.find(lookupValue); 953 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 954 stream << nameIt->second; 955 } 956 957 if (resultNo.hasValue() && printResultNo) 958 stream << '#' << resultNo; 959 } 960 961 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 962 auto it = opResultGroups.find(op); 963 return it == opResultGroups.end() ? 
ArrayRef<int>() : it->second; 964 } 965 966 unsigned SSANameState::getBlockID(Block *block) { 967 auto it = blockIDs.find(block); 968 return it != blockIDs.end() ? it->second : NameSentinel; 969 } 970 971 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 972 assert(!region.empty() && "cannot shadow arguments of an empty region"); 973 assert(region.getNumArguments() == namesToUse.size() && 974 "incorrect number of names passed in"); 975 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 976 "only KnownIsolatedFromAbove ops can shadow names"); 977 978 SmallVector<char, 16> nameStr; 979 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 980 auto nameToUse = namesToUse[i]; 981 if (nameToUse == nullptr) 982 continue; 983 auto nameToReplace = region.getArgument(i); 984 985 nameStr.clear(); 986 llvm::raw_svector_ostream nameStream(nameStr); 987 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 988 989 // Entry block arguments should already have a pretty "arg" name. 990 assert(valueIDs[nameToReplace] == NameSentinel); 991 992 // Use the name without the leading %. 993 auto name = StringRef(nameStream.str()).drop_front(); 994 995 // Overwrite the name. 996 valueNames[nameToReplace] = name.copy(usedNameAllocator); 997 } 998 } 999 1000 void SSANameState::numberValuesInRegion(Region ®ion) { 1001 // Number the values within this region in a breadth-first order. 1002 unsigned nextBlockID = 0; 1003 for (auto &block : region) { 1004 // Each block gets a unique ID, and all of the operations within it get 1005 // numbered as well. 
1006 blockIDs[&block] = nextBlockID++; 1007 numberValuesInBlock(block); 1008 } 1009 } 1010 1011 void SSANameState::numberValuesInBlock(Block &block) { 1012 auto setArgNameFn = [&](Value arg, StringRef name) { 1013 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 1014 assert(arg.cast<BlockArgument>().getOwner() == &block && 1015 "arg not defined in 'block'"); 1016 setValueName(arg, name); 1017 }; 1018 1019 bool isEntryBlock = block.isEntryBlock(); 1020 if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) { 1021 if (auto *op = block.getParentOp()) { 1022 if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) 1023 asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); 1024 } 1025 } 1026 1027 // Number the block arguments. We give entry block arguments a special name 1028 // 'arg'. 1029 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); 1030 llvm::raw_svector_ostream specialName(specialNameBuffer); 1031 for (auto arg : block.getArguments()) { 1032 if (valueIDs.count(arg)) 1033 continue; 1034 if (isEntryBlock) { 1035 specialNameBuffer.resize(strlen("arg")); 1036 specialName << nextArgumentID++; 1037 } 1038 setValueName(arg, specialName.str()); 1039 } 1040 1041 // Number the operations in this block. 1042 for (auto &op : block) 1043 numberValuesInOp(op); 1044 } 1045 1046 void SSANameState::numberValuesInOp(Operation &op) { 1047 unsigned numResults = op.getNumResults(); 1048 if (numResults == 0) 1049 return; 1050 Value resultBegin = op.getResult(0); 1051 1052 // Function used to set the special result names for the operation. 1053 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1054 auto setResultNameFn = [&](Value result, StringRef name) { 1055 assert(!valueIDs.count(result) && "result numbered multiple times"); 1056 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1057 setValueName(result, name); 1058 1059 // Record the result number for groups not anchored at 0. 
1060 if (int resultNo = result.cast<OpResult>().getResultNumber()) 1061 resultGroups.push_back(resultNo); 1062 }; 1063 if (!printerFlags.shouldPrintGenericOpForm()) { 1064 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) 1065 asmInterface.getAsmResultNames(setResultNameFn); 1066 else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) 1067 asmInterface->getAsmResultNames(&op, setResultNameFn); 1068 } 1069 1070 // If the first result wasn't numbered, give it a default number. 1071 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1072 ++nextValueID; 1073 1074 // If this operation has multiple result groups, mark it. 1075 if (resultGroups.size() != 1) { 1076 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1077 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1078 } 1079 } 1080 1081 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1082 Optional<int> &lookupResultNo) const { 1083 Operation *owner = result.getOwner(); 1084 if (owner->getNumResults() == 1) 1085 return; 1086 int resultNo = result.getResultNumber(); 1087 1088 // If this operation has multiple result groups, we will need to find the 1089 // one corresponding to this result. 1090 auto resultGroupIt = opResultGroups.find(owner); 1091 if (resultGroupIt == opResultGroups.end()) { 1092 // If not, just use the first result. 1093 lookupResultNo = resultNo; 1094 lookupValue = owner->getResult(0); 1095 return; 1096 } 1097 1098 // Find the correct index using a binary search, as the groups are ordered. 1099 ArrayRef<int> resultGroups = resultGroupIt->second; 1100 auto it = llvm::upper_bound(resultGroups, resultNo); 1101 int groupResultNo = 0, groupSize = 0; 1102 1103 // If there are no smaller elements, the last result group is the lookup. 
1104 if (it == resultGroups.end()) { 1105 groupResultNo = resultGroups.back(); 1106 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1107 } else { 1108 // Otherwise, the previous element is the lookup. 1109 groupResultNo = *std::prev(it); 1110 groupSize = *it - groupResultNo; 1111 } 1112 1113 // We only record the result number for a group of size greater than 1. 1114 if (groupSize != 1) 1115 lookupResultNo = resultNo - groupResultNo; 1116 lookupValue = owner->getResult(groupResultNo); 1117 } 1118 1119 void SSANameState::setValueName(Value value, StringRef name) { 1120 // If the name is empty, the value uses the default numbering. 1121 if (name.empty()) { 1122 valueIDs[value] = nextValueID++; 1123 return; 1124 } 1125 1126 valueIDs[value] = NameSentinel; 1127 valueNames[value] = uniqueValueName(name); 1128 } 1129 1130 StringRef SSANameState::uniqueValueName(StringRef name) { 1131 SmallString<16> tmpBuffer; 1132 name = sanitizeIdentifier(name, tmpBuffer); 1133 1134 // Check to see if this name is already unique. 1135 if (!usedNames.count(name)) { 1136 name = name.copy(usedNameAllocator); 1137 } else { 1138 // Otherwise, we had a conflict - probe until we find a unique name. This 1139 // is guaranteed to terminate (and usually in a single iteration) because it 1140 // generates new names by incrementing nextConflictID. 
1141 SmallString<64> probeName(name); 1142 probeName.push_back('_'); 1143 while (true) { 1144 probeName += llvm::utostr(nextConflictID++); 1145 if (!usedNames.count(probeName)) { 1146 name = probeName.str().copy(usedNameAllocator); 1147 break; 1148 } 1149 probeName.resize(name.size() + 1); 1150 } 1151 } 1152 1153 usedNames.insert(name, char()); 1154 return name; 1155 } 1156 1157 //===----------------------------------------------------------------------===// 1158 // AsmState 1159 //===----------------------------------------------------------------------===// 1160 1161 namespace mlir { 1162 namespace detail { 1163 class AsmStateImpl { 1164 public: 1165 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, 1166 AsmState::LocationMap *locationMap) 1167 : interfaces(op->getContext()), nameState(op, printerFlags, interfaces), 1168 printerFlags(printerFlags), locationMap(locationMap) {} 1169 1170 /// Initialize the alias state to enable the printing of aliases. 1171 void initializeAliases(Operation *op) { 1172 aliasState.initialize(op, printerFlags, interfaces); 1173 } 1174 1175 /// Get an instance of the OpAsmDialectInterface for the given dialect, or 1176 /// null if one wasn't registered. 1177 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { 1178 return interfaces.getInterfaceFor(dialect); 1179 } 1180 1181 /// Get the state used for aliases. 1182 AliasState &getAliasState() { return aliasState; } 1183 1184 /// Get the state used for SSA names. 1185 SSANameState &getSSANameState() { return nameState; } 1186 1187 /// Register the location, line and column, within the buffer that the given 1188 /// operation was printed at. 1189 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1190 if (locationMap) 1191 (*locationMap)[op] = std::make_pair(line, col); 1192 } 1193 1194 private: 1195 /// Collection of OpAsm interfaces implemented in the context. 
1196 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1197 1198 /// The state used for attribute and type aliases. 1199 AliasState aliasState; 1200 1201 /// The state used for SSA value names. 1202 SSANameState nameState; 1203 1204 /// Flags that control op output. 1205 OpPrintingFlags printerFlags; 1206 1207 /// An optional location map to be populated. 1208 AsmState::LocationMap *locationMap; 1209 }; 1210 } // end namespace detail 1211 } // end namespace mlir 1212 1213 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, 1214 LocationMap *locationMap) 1215 : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {} 1216 AsmState::~AsmState() {} 1217 1218 //===----------------------------------------------------------------------===// 1219 // AsmPrinter::Impl 1220 //===----------------------------------------------------------------------===// 1221 1222 namespace mlir { 1223 class AsmPrinter::Impl { 1224 public: 1225 Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1226 AsmStateImpl *state = nullptr) 1227 : os(os), printerFlags(flags), state(state) {} 1228 explicit Impl(Impl &other) 1229 : Impl(other.os, other.printerFlags, other.state) {} 1230 1231 /// Returns the output stream of the printer. 1232 raw_ostream &getStream() { return os; } 1233 1234 template <typename Container, typename UnaryFunctor> 1235 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { 1236 llvm::interleaveComma(c, os, each_fn); 1237 } 1238 1239 /// This enum describes the different kinds of elision for the type of an 1240 /// attribute when printing it. 1241 enum class AttrTypeElision { 1242 /// The type must not be elided, 1243 Never, 1244 /// The type may be elided when it matches the default used in the parser 1245 /// (for example i64 is the default for integer attributes). 1246 May, 1247 /// The type must be elided. 1248 Must 1249 }; 1250 1251 /// Print the given attribute. 
1252 void printAttribute(Attribute attr, 1253 AttrTypeElision typeElision = AttrTypeElision::Never); 1254 1255 void printType(Type type); 1256 1257 /// Print the given location to the stream. If `allowAlias` is true, this 1258 /// allows for the internal location to use an attribute alias. 1259 void printLocation(LocationAttr loc, bool allowAlias = false); 1260 1261 void printAffineMap(AffineMap map); 1262 void 1263 printAffineExpr(AffineExpr expr, 1264 function_ref<void(unsigned, bool)> printValueName = nullptr); 1265 void printAffineConstraint(AffineExpr expr, bool isEq); 1266 void printIntegerSet(IntegerSet set); 1267 1268 protected: 1269 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1270 ArrayRef<StringRef> elidedAttrs = {}, 1271 bool withKeyword = false); 1272 void printNamedAttribute(NamedAttribute attr); 1273 void printTrailingLocation(Location loc, bool allowAlias = true); 1274 void printLocationInternal(LocationAttr loc, bool pretty = false); 1275 1276 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1277 /// used instead of individual elements when the elements attr is large. 1278 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1279 1280 /// Print a dense string elements attribute. 1281 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1282 1283 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1284 /// used instead of individual elements when the elements attr is large. 1285 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1286 bool allowHex); 1287 1288 void printDialectAttribute(Attribute attr); 1289 void printDialectType(Type type); 1290 1291 /// This enum is used to represent the binding strength of the enclosing 1292 /// context that an AffineExprStorage is being printed in, so we can 1293 /// intelligently produce parens. 1294 enum class BindingStrength { 1295 Weak, // + and - 1296 Strong, // All other binary operators. 
1297 }; 1298 void printAffineExprInternal( 1299 AffineExpr expr, BindingStrength enclosingTightness, 1300 function_ref<void(unsigned, bool)> printValueName = nullptr); 1301 1302 /// The output stream for the printer. 1303 raw_ostream &os; 1304 1305 /// A set of flags to control the printer's behavior. 1306 OpPrintingFlags printerFlags; 1307 1308 /// An optional printer state for the module. 1309 AsmStateImpl *state; 1310 1311 /// A tracker for the number of new lines emitted during printing. 1312 NewLineCounter newLine; 1313 }; 1314 } // namespace mlir 1315 1316 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { 1317 // Check to see if we are printing debug information. 1318 if (!printerFlags.shouldPrintDebugInfo()) 1319 return; 1320 1321 os << " "; 1322 printLocation(loc, /*allowAlias=*/allowAlias); 1323 } 1324 1325 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) { 1326 TypeSwitch<LocationAttr>(loc) 1327 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1328 printLocationInternal(loc.getFallbackLocation(), pretty); 1329 }) 1330 .Case<UnknownLoc>([&](UnknownLoc loc) { 1331 if (pretty) 1332 os << "[unknown]"; 1333 else 1334 os << "unknown"; 1335 }) 1336 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1337 if (pretty) { 1338 os << loc.getFilename().getValue(); 1339 } else { 1340 os << "\""; 1341 printEscapedString(loc.getFilename(), os); 1342 os << "\""; 1343 } 1344 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1345 }) 1346 .Case<NameLoc>([&](NameLoc loc) { 1347 os << '\"'; 1348 printEscapedString(loc.getName(), os); 1349 os << '\"'; 1350 1351 // Print the child if it isn't unknown. 
1352 auto childLoc = loc.getChildLoc(); 1353 if (!childLoc.isa<UnknownLoc>()) { 1354 os << '('; 1355 printLocationInternal(childLoc, pretty); 1356 os << ')'; 1357 } 1358 }) 1359 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1360 Location caller = loc.getCaller(); 1361 Location callee = loc.getCallee(); 1362 if (!pretty) 1363 os << "callsite("; 1364 printLocationInternal(callee, pretty); 1365 if (pretty) { 1366 if (callee.isa<NameLoc>()) { 1367 if (caller.isa<FileLineColLoc>()) { 1368 os << " at "; 1369 } else { 1370 os << newLine << " at "; 1371 } 1372 } else { 1373 os << newLine << " at "; 1374 } 1375 } else { 1376 os << " at "; 1377 } 1378 printLocationInternal(caller, pretty); 1379 if (!pretty) 1380 os << ")"; 1381 }) 1382 .Case<FusedLoc>([&](FusedLoc loc) { 1383 if (!pretty) 1384 os << "fused"; 1385 if (Attribute metadata = loc.getMetadata()) 1386 os << '<' << metadata << '>'; 1387 os << '['; 1388 interleave( 1389 loc.getLocations(), 1390 [&](Location loc) { printLocationInternal(loc, pretty); }, 1391 [&]() { os << ", "; }); 1392 os << ']'; 1393 }); 1394 } 1395 1396 /// Print a floating point value in a way that the parser will be able to 1397 /// round-trip losslessly. 1398 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1399 // We would like to output the FP constant value in exponential notation, 1400 // but we cannot do this if doing so will lose precision. Check here to 1401 // make sure that we only output it in exponential format if we can parse 1402 // the value back and get the same value. 1403 bool isInf = apValue.isInfinity(); 1404 bool isNaN = apValue.isNaN(); 1405 if (!isInf && !isNaN) { 1406 SmallString<128> strValue; 1407 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1408 /*TruncateZero=*/false); 1409 1410 // Check to make sure that the stringized number is not some string like 1411 // "Inf" or NaN, that atof will accept, but the lexer will not. 
Check 1412 // that the string matches the "[-+]?[0-9]" regex. 1413 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1414 ((strValue[0] == '-' || strValue[0] == '+') && 1415 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1416 "[-+]?[0-9] regex does not match!"); 1417 1418 // Parse back the stringized version and check that the value is equal 1419 // (i.e., there is no precision loss). 1420 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1421 os << strValue; 1422 return; 1423 } 1424 1425 // If it is not, use the default format of APFloat instead of the 1426 // exponential notation. 1427 strValue.clear(); 1428 apValue.toString(strValue); 1429 1430 // Make sure that we can parse the default form as a float. 1431 if (strValue.str().contains('.')) { 1432 os << strValue; 1433 return; 1434 } 1435 } 1436 1437 // Print special values in hexadecimal format. The sign bit should be included 1438 // in the literal. 1439 SmallVector<char, 16> str; 1440 APInt apInt = apValue.bitcastToAPInt(); 1441 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1442 /*formatAsCLiteral=*/true); 1443 os << str; 1444 } 1445 1446 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { 1447 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1448 return printLocationInternal(loc, /*pretty=*/true); 1449 1450 os << "loc("; 1451 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1452 printLocationInternal(loc); 1453 os << ')'; 1454 } 1455 1456 /// Returns true if the given dialect symbol data is simple enough to print in 1457 /// the pretty form, i.e. without the enclosing "". 1458 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1459 // The name must start with an identifier. 1460 if (symName.empty() || !isalpha(symName.front())) 1461 return false; 1462 1463 // Ignore all the characters that are valid in an identifier in the symbol 1464 // name. 
1465 symName = symName.drop_while( 1466 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1467 if (symName.empty()) 1468 return true; 1469 1470 // If we got to an unexpected character, then it must be a <>. Check those 1471 // recursively. 1472 if (symName.front() != '<' || symName.back() != '>') 1473 return false; 1474 1475 SmallVector<char, 8> nestedPunctuation; 1476 do { 1477 // If we ran out of characters, then we had a punctuation mismatch. 1478 if (symName.empty()) 1479 return false; 1480 1481 auto c = symName.front(); 1482 symName = symName.drop_front(); 1483 1484 switch (c) { 1485 // We never allow null characters. This is an EOF indicator for the lexer 1486 // which we could handle, but isn't important for any known dialect. 1487 case '\0': 1488 return false; 1489 case '<': 1490 case '[': 1491 case '(': 1492 case '{': 1493 nestedPunctuation.push_back(c); 1494 continue; 1495 case '-': 1496 // Treat `->` as a special token. 1497 if (!symName.empty() && symName.front() == '>') { 1498 symName = symName.drop_front(); 1499 continue; 1500 } 1501 break; 1502 // Reject types with mismatched brackets. 1503 case '>': 1504 if (nestedPunctuation.pop_back_val() != '<') 1505 return false; 1506 break; 1507 case ']': 1508 if (nestedPunctuation.pop_back_val() != '[') 1509 return false; 1510 break; 1511 case ')': 1512 if (nestedPunctuation.pop_back_val() != '(') 1513 return false; 1514 break; 1515 case '}': 1516 if (nestedPunctuation.pop_back_val() != '{') 1517 return false; 1518 break; 1519 default: 1520 continue; 1521 } 1522 1523 // We're done when the punctuation is fully matched. 1524 } while (!nestedPunctuation.empty()); 1525 1526 // If there were extra characters, then we failed. 1527 return symName.empty(); 1528 } 1529 1530 /// Print the given dialect symbol to the stream. 
1531 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 1532 StringRef dialectName, StringRef symString) { 1533 os << symPrefix << dialectName; 1534 1535 // If this symbol name is simple enough, print it directly in pretty form, 1536 // otherwise, we print it as an escaped string. 1537 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 1538 os << '.' << symString; 1539 return; 1540 } 1541 1542 os << "<\""; 1543 llvm::printEscapedString(symString, os); 1544 os << "\">"; 1545 } 1546 1547 /// Returns true if the given string can be represented as a bare identifier. 1548 static bool isBareIdentifier(StringRef name) { 1549 // By making this unsigned, the value passed in to isalnum will always be 1550 // in the range 0-255. This is important when building with MSVC because 1551 // its implementation will assert. This situation can arise when dealing 1552 // with UTF-8 multibyte characters. 1553 if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) 1554 return false; 1555 return llvm::all_of(name.drop_front(), [](unsigned char c) { 1556 return isalnum(c) || c == '_' || c == '$' || c == '.'; 1557 }); 1558 } 1559 1560 /// Print the given string as a keyword, or a quoted and escaped string if it 1561 /// has any special or non-printable characters in it. 1562 static void printKeywordOrString(StringRef keyword, raw_ostream &os) { 1563 // If it can be represented as a bare identifier, write it directly. 1564 if (isBareIdentifier(keyword)) { 1565 os << keyword; 1566 return; 1567 } 1568 1569 // Otherwise, output the keyword wrapped in quotes with proper escaping. 1570 os << "\""; 1571 printEscapedString(keyword, os); 1572 os << '"'; 1573 } 1574 1575 /// Print the given string as a symbol reference. A symbol reference is 1576 /// represented as a string prefixed with '@'. The reference is surrounded with 1577 /// ""'s and escaped if it has any special or non-printable characters in it. 
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  // An empty reference cannot be printed meaningfully; callers must avoid it.
  assert(!symbolRef.empty() && "expected valid symbol reference");
  os << '@';
  printKeywordOrString(symbolRef, os);
}

// Print out a valid ElementsAttr that is succinct and can represent any
// potential shape/type, for use when eliding a large ElementsAttr.
//
// We choose to use an opaque ElementsAttr literal with conspicuous content to
// hopefully alert readers to the fact that this has been elided.
//
// Unfortunately, neither of the strings of an opaque ElementsAttr literal will
// accept the string "elided". The first string must be a registered dialect
// name and the latter must be a hex constant.
static void printElidedElementsAttr(raw_ostream &os) {
  os << R"(opaque<"_", "0xDEADBEEF">)";
}

/// Print `attr` to the stream.
///
/// - A null attribute prints as "<<NULL ATTRIBUTE>>".
/// - If the alias state has an alias for the attribute, the alias is printed
///   instead.
/// - Attributes from non-builtin dialects are forwarded to
///   printDialectAttribute.
/// - Builtin attributes are printed in their custom syntax, followed by
///   " : <type>". The type is omitted when `typeElision` is Must, when the
///   type is NoneType, or where a branch below returns early (unit, i1
///   integers, affine maps/sets, and i64/f64 values under
///   AttrTypeElision::May).
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (state && succeeded(state->getAliasState().getAlias(attr, os)))
    return;

  if (!isa<BuiltinDialect>(attr.getDialect()))
    return printDialectAttribute(attr);

  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values. Indexes, signed values, and multi-bit signless
    // values print as signed. (The signless 1-bit case has already returned
    // above, so here only explicitly unsigned types print as unsigned.)
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    os << '[';
    // Elements may elide their default types (i64 / f64).
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Printed as @root::@nested1::@nested2...
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
         << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide the whole attribute if either its indices or its values would be
    // elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, nothing is printed between the angle brackets.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}

/// Print the integer element of a DenseElementsAttr. 1-bit values print as
/// `true`/`false`; other widths print as signed or unsigned per `isSigned`.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 bool isSigned) {
  if (value.getBitWidth() == 1)
    os << (value.getBoolValue() ? "true" : "false");
  else
    value.print(os, isSigned);
}

/// Print the elements of a dense attribute of shape `type` as nested,
/// comma-separated bracket lists (one nesting level per dimension).
/// `printEltFn(i)` prints the i-th element in row-major order. A splat prints
/// only element 0, with no brackets, and an attribute with zero elements
/// prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  // NOTE(review): counter[rank - 1] below assumes rank >= 1; rank-0 inputs are
  // presumably always splats and return above. Confirm.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open any brackets closed by the previous bump.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close every bracket that is still open.
  while (openBrackets-- > 0)
    os << ']';
}

/// Dispatch to the string or int/float printer based on the dynamic kind of
/// `attr`. `allowHex` is only used by the int/float path.
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
                                              bool allowHex) {
  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
    return printDenseStringElementsAttr(stringAttr);

  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
                                allowHex);
}

/// Print the contents of a dense int/float/complex elements attribute
/// (without the surrounding `dense<...>`). When `allowHex` is set, the
/// attribute is not a splat, and shouldPrintElementsAttrWithHex accepts the
/// element count, the raw data is printed as a quoted little-endian hex
/// string.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      os << '"' << "0x"
         << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
         << "\"";
    } else {
      os << '"' << "0x"
         << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    if (complexElementType.isa<IntegerType>()) {
      bool isSigned = !complexElementType.isUnsignedInteger();
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    bool isSigned = !elementType.isUnsignedInteger();
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected
element type"); 1871 auto valueIt = attr.value_begin<APFloat>(); 1872 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1873 printFloatValue(*(valueIt + index), os); 1874 }); 1875 } 1876 } 1877 1878 void AsmPrinter::Impl::printDenseStringElementsAttr( 1879 DenseStringElementsAttr attr) { 1880 ArrayRef<StringRef> data = attr.getRawStringData(); 1881 auto printFn = [&](unsigned index) { 1882 os << "\""; 1883 printEscapedString(data[index], os); 1884 os << "\""; 1885 }; 1886 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); 1887 } 1888 1889 void AsmPrinter::Impl::printType(Type type) { 1890 if (!type) { 1891 os << "<<NULL TYPE>>"; 1892 return; 1893 } 1894 1895 // Try to print an alias for this type. 1896 if (state && succeeded(state->getAliasState().getAlias(type, os))) 1897 return; 1898 1899 TypeSwitch<Type>(type) 1900 .Case<OpaqueType>([&](OpaqueType opaqueTy) { 1901 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 1902 opaqueTy.getTypeData()); 1903 }) 1904 .Case<IndexType>([&](Type) { os << "index"; }) 1905 .Case<BFloat16Type>([&](Type) { os << "bf16"; }) 1906 .Case<Float16Type>([&](Type) { os << "f16"; }) 1907 .Case<Float32Type>([&](Type) { os << "f32"; }) 1908 .Case<Float64Type>([&](Type) { os << "f64"; }) 1909 .Case<Float80Type>([&](Type) { os << "f80"; }) 1910 .Case<Float128Type>([&](Type) { os << "f128"; }) 1911 .Case<IntegerType>([&](IntegerType integerTy) { 1912 if (integerTy.isSigned()) 1913 os << 's'; 1914 else if (integerTy.isUnsigned()) 1915 os << 'u'; 1916 os << 'i' << integerTy.getWidth(); 1917 }) 1918 .Case<FunctionType>([&](FunctionType funcTy) { 1919 os << '('; 1920 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); 1921 os << ") -> "; 1922 ArrayRef<Type> results = funcTy.getResults(); 1923 if (results.size() == 1 && !results[0].isa<FunctionType>()) { 1924 printType(results[0]); 1925 } else { 1926 os << '('; 1927 interleaveComma(results, [&](Type ty) { printType(ty); }); 
1928 os << ')'; 1929 } 1930 }) 1931 .Case<VectorType>([&](VectorType vectorTy) { 1932 os << "vector<"; 1933 for (int64_t dim : vectorTy.getShape()) 1934 os << dim << 'x'; 1935 printType(vectorTy.getElementType()); 1936 os << '>'; 1937 }) 1938 .Case<RankedTensorType>([&](RankedTensorType tensorTy) { 1939 os << "tensor<"; 1940 for (int64_t dim : tensorTy.getShape()) { 1941 if (ShapedType::isDynamic(dim)) 1942 os << '?'; 1943 else 1944 os << dim; 1945 os << 'x'; 1946 } 1947 printType(tensorTy.getElementType()); 1948 // Only print the encoding attribute value if set. 1949 if (tensorTy.getEncoding()) { 1950 os << ", "; 1951 printAttribute(tensorTy.getEncoding()); 1952 } 1953 os << '>'; 1954 }) 1955 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { 1956 os << "tensor<*x"; 1957 printType(tensorTy.getElementType()); 1958 os << '>'; 1959 }) 1960 .Case<MemRefType>([&](MemRefType memrefTy) { 1961 os << "memref<"; 1962 for (int64_t dim : memrefTy.getShape()) { 1963 if (ShapedType::isDynamic(dim)) 1964 os << '?'; 1965 else 1966 os << dim; 1967 os << 'x'; 1968 } 1969 printType(memrefTy.getElementType()); 1970 if (!memrefTy.getLayout().isIdentity()) { 1971 os << ", "; 1972 printAttribute(memrefTy.getLayout(), AttrTypeElision::May); 1973 } 1974 // Only print the memory space if it is the non-default one. 1975 if (memrefTy.getMemorySpace()) { 1976 os << ", "; 1977 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 1978 } 1979 os << '>'; 1980 }) 1981 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { 1982 os << "memref<*x"; 1983 printType(memrefTy.getElementType()); 1984 // Only print the memory space if it is the non-default one. 
1985 if (memrefTy.getMemorySpace()) { 1986 os << ", "; 1987 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 1988 } 1989 os << '>'; 1990 }) 1991 .Case<ComplexType>([&](ComplexType complexTy) { 1992 os << "complex<"; 1993 printType(complexTy.getElementType()); 1994 os << '>'; 1995 }) 1996 .Case<TupleType>([&](TupleType tupleTy) { 1997 os << "tuple<"; 1998 interleaveComma(tupleTy.getTypes(), 1999 [&](Type type) { printType(type); }); 2000 os << '>'; 2001 }) 2002 .Case<NoneType>([&](Type) { os << "none"; }) 2003 .Default([&](Type type) { return printDialectType(type); }); 2004 } 2005 2006 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2007 ArrayRef<StringRef> elidedAttrs, 2008 bool withKeyword) { 2009 // If there are no attributes, then there is nothing to be done. 2010 if (attrs.empty()) 2011 return; 2012 2013 // Functor used to print a filtered attribute list. 2014 auto printFilteredAttributesFn = [&](auto filteredAttrs) { 2015 // Print the 'attributes' keyword if necessary. 2016 if (withKeyword) 2017 os << " attributes"; 2018 2019 // Otherwise, print them all out in braces. 2020 os << " {"; 2021 interleaveComma(filteredAttrs, 2022 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 2023 os << '}'; 2024 }; 2025 2026 // If no attributes are elided, we can directly print with no filtering. 2027 if (elidedAttrs.empty()) 2028 return printFilteredAttributesFn(attrs); 2029 2030 // Otherwise, filter out any attributes that shouldn't be included. 2031 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 2032 elidedAttrs.end()); 2033 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 2034 return !elidedAttrsSet.contains(attr.getName().strref()); 2035 }); 2036 if (!filteredAttrs.empty()) 2037 printFilteredAttributesFn(filteredAttrs); 2038 } 2039 2040 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { 2041 // Print the name without quotes if possible. 
2042 ::printKeywordOrString(attr.getName().strref(), os); 2043 2044 // Pretty printing elides the attribute value for unit attributes. 2045 if (attr.getValue().isa<UnitAttr>()) 2046 return; 2047 2048 os << " = "; 2049 printAttribute(attr.getValue()); 2050 } 2051 2052 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { 2053 auto &dialect = attr.getDialect(); 2054 2055 // Ask the dialect to serialize the attribute to a string. 2056 std::string attrName; 2057 { 2058 llvm::raw_string_ostream attrNameStr(attrName); 2059 Impl subPrinter(attrNameStr, printerFlags, state); 2060 DialectAsmPrinter printer(subPrinter); 2061 dialect.printAttribute(attr, printer); 2062 } 2063 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 2064 } 2065 2066 void AsmPrinter::Impl::printDialectType(Type type) { 2067 auto &dialect = type.getDialect(); 2068 2069 // Ask the dialect to serialize the type to a string. 2070 std::string typeName; 2071 { 2072 llvm::raw_string_ostream typeNameStr(typeName); 2073 Impl subPrinter(typeNameStr, printerFlags, state); 2074 DialectAsmPrinter printer(subPrinter); 2075 dialect.printType(type, printer); 2076 } 2077 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 2078 } 2079 2080 //===--------------------------------------------------------------------===// 2081 // AsmPrinter 2082 //===--------------------------------------------------------------------===// 2083 2084 AsmPrinter::~AsmPrinter() {} 2085 2086 raw_ostream &AsmPrinter::getStream() const { 2087 assert(impl && "expected AsmPrinter::getStream to be overriden"); 2088 return impl->getStream(); 2089 } 2090 2091 /// Print the given floating point value in a stablized form. 
2092 void AsmPrinter::printFloat(const APFloat &value) { 2093 assert(impl && "expected AsmPrinter::printFloat to be overriden"); 2094 printFloatValue(value, impl->getStream()); 2095 } 2096 2097 void AsmPrinter::printType(Type type) { 2098 assert(impl && "expected AsmPrinter::printType to be overriden"); 2099 impl->printType(type); 2100 } 2101 2102 void AsmPrinter::printAttribute(Attribute attr) { 2103 assert(impl && "expected AsmPrinter::printAttribute to be overriden"); 2104 impl->printAttribute(attr); 2105 } 2106 2107 void AsmPrinter::printAttributeWithoutType(Attribute attr) { 2108 assert(impl && 2109 "expected AsmPrinter::printAttributeWithoutType to be overriden"); 2110 impl->printAttribute(attr, Impl::AttrTypeElision::Must); 2111 } 2112 2113 void AsmPrinter::printKeywordOrString(StringRef keyword) { 2114 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); 2115 ::printKeywordOrString(keyword, impl->getStream()); 2116 } 2117 2118 void AsmPrinter::printSymbolName(StringRef symbolRef) { 2119 assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); 2120 ::printSymbolReference(symbolRef, impl->getStream()); 2121 } 2122 2123 //===----------------------------------------------------------------------===// 2124 // Affine expressions and maps 2125 //===----------------------------------------------------------------------===// 2126 2127 void AsmPrinter::Impl::printAffineExpr( 2128 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2129 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2130 } 2131 2132 void AsmPrinter::Impl::printAffineExprInternal( 2133 AffineExpr expr, BindingStrength enclosingTightness, 2134 function_ref<void(unsigned, bool)> printValueName) { 2135 const char *binopSpelling = nullptr; 2136 switch (expr.getKind()) { 2137 case AffineExprKind::SymbolId: { 2138 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 2139 if (printValueName) 2140 printValueName(pos, 
/*isSymbol=*/true); 2141 else 2142 os << 's' << pos; 2143 return; 2144 } 2145 case AffineExprKind::DimId: { 2146 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 2147 if (printValueName) 2148 printValueName(pos, /*isSymbol=*/false); 2149 else 2150 os << 'd' << pos; 2151 return; 2152 } 2153 case AffineExprKind::Constant: 2154 os << expr.cast<AffineConstantExpr>().getValue(); 2155 return; 2156 case AffineExprKind::Add: 2157 binopSpelling = " + "; 2158 break; 2159 case AffineExprKind::Mul: 2160 binopSpelling = " * "; 2161 break; 2162 case AffineExprKind::FloorDiv: 2163 binopSpelling = " floordiv "; 2164 break; 2165 case AffineExprKind::CeilDiv: 2166 binopSpelling = " ceildiv "; 2167 break; 2168 case AffineExprKind::Mod: 2169 binopSpelling = " mod "; 2170 break; 2171 } 2172 2173 auto binOp = expr.cast<AffineBinaryOpExpr>(); 2174 AffineExpr lhsExpr = binOp.getLHS(); 2175 AffineExpr rhsExpr = binOp.getRHS(); 2176 2177 // Handle tightly binding binary operators. 2178 if (binOp.getKind() != AffineExprKind::Add) { 2179 if (enclosingTightness == BindingStrength::Strong) 2180 os << '('; 2181 2182 // Pretty print multiplication with -1. 2183 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2184 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2185 rhsConst.getValue() == -1) { 2186 os << "-"; 2187 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2188 if (enclosingTightness == BindingStrength::Strong) 2189 os << ')'; 2190 return; 2191 } 2192 2193 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2194 2195 os << binopSpelling; 2196 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2197 2198 if (enclosingTightness == BindingStrength::Strong) 2199 os << ')'; 2200 return; 2201 } 2202 2203 // Print out special "pretty" forms for add. 
2204 if (enclosingTightness == BindingStrength::Strong) 2205 os << '('; 2206 2207 // Pretty print addition to a product that has a negative operand as a 2208 // subtraction. 2209 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2210 if (rhs.getKind() == AffineExprKind::Mul) { 2211 AffineExpr rrhsExpr = rhs.getRHS(); 2212 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2213 if (rrhs.getValue() == -1) { 2214 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2215 printValueName); 2216 os << " - "; 2217 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2218 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2219 printValueName); 2220 } else { 2221 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2222 printValueName); 2223 } 2224 2225 if (enclosingTightness == BindingStrength::Strong) 2226 os << ')'; 2227 return; 2228 } 2229 2230 if (rrhs.getValue() < -1) { 2231 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2232 printValueName); 2233 os << " - "; 2234 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2235 printValueName); 2236 os << " * " << -rrhs.getValue(); 2237 if (enclosingTightness == BindingStrength::Strong) 2238 os << ')'; 2239 return; 2240 } 2241 } 2242 } 2243 } 2244 2245 // Pretty print addition to a negative number as a subtraction. 
2246 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2247 if (rhsConst.getValue() < 0) { 2248 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2249 os << " - " << -rhsConst.getValue(); 2250 if (enclosingTightness == BindingStrength::Strong) 2251 os << ')'; 2252 return; 2253 } 2254 } 2255 2256 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2257 2258 os << " + "; 2259 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2260 2261 if (enclosingTightness == BindingStrength::Strong) 2262 os << ')'; 2263 } 2264 2265 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { 2266 printAffineExprInternal(expr, BindingStrength::Weak); 2267 isEq ? os << " == 0" : os << " >= 0"; 2268 } 2269 2270 void AsmPrinter::Impl::printAffineMap(AffineMap map) { 2271 // Dimension identifiers. 2272 os << '('; 2273 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2274 os << 'd' << i << ", "; 2275 if (map.getNumDims() >= 1) 2276 os << 'd' << map.getNumDims() - 1; 2277 os << ')'; 2278 2279 // Symbolic identifiers. 2280 if (map.getNumSymbols() != 0) { 2281 os << '['; 2282 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2283 os << 's' << i << ", "; 2284 if (map.getNumSymbols() >= 1) 2285 os << 's' << map.getNumSymbols() - 1; 2286 os << ']'; 2287 } 2288 2289 // Result affine expressions. 2290 os << " -> ("; 2291 interleaveComma(map.getResults(), 2292 [&](AffineExpr expr) { printAffineExpr(expr); }); 2293 os << ')'; 2294 } 2295 2296 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { 2297 // Dimension identifiers. 2298 os << '('; 2299 for (unsigned i = 1; i < set.getNumDims(); ++i) 2300 os << 'd' << i - 1 << ", "; 2301 if (set.getNumDims() >= 1) 2302 os << 'd' << set.getNumDims() - 1; 2303 os << ')'; 2304 2305 // Symbolic identifiers. 
2306 if (set.getNumSymbols() != 0) { 2307 os << '['; 2308 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2309 os << 's' << i << ", "; 2310 if (set.getNumSymbols() >= 1) 2311 os << 's' << set.getNumSymbols() - 1; 2312 os << ']'; 2313 } 2314 2315 // Print constraints. 2316 os << " : ("; 2317 int numConstraints = set.getNumConstraints(); 2318 for (int i = 1; i < numConstraints; ++i) { 2319 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2320 os << ", "; 2321 } 2322 if (numConstraints >= 1) 2323 printAffineConstraint(set.getConstraint(numConstraints - 1), 2324 set.isEq(numConstraints - 1)); 2325 os << ')'; 2326 } 2327 2328 //===----------------------------------------------------------------------===// 2329 // OperationPrinter 2330 //===----------------------------------------------------------------------===// 2331 2332 namespace { 2333 /// This class contains the logic for printing operations, regions, and blocks. 2334 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { 2335 public: 2336 using Impl = AsmPrinter::Impl; 2337 using Impl::printType; 2338 2339 explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, 2340 AsmStateImpl &state) 2341 : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {} 2342 2343 /// Print the given top-level operation. 2344 void printTopLevelOperation(Operation *op); 2345 2346 /// Print the given operation with its indent and location. 2347 void print(Operation *op); 2348 /// Print the bare location, not including indentation/location/etc. 2349 void printOperation(Operation *op); 2350 /// Print the given operation in the generic form. 2351 void printGenericOp(Operation *op, bool printOpName) override; 2352 2353 /// Print the name of the given block. 2354 void printBlockName(Block *block); 2355 2356 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2357 /// block are not printed. 
If 'printBlockTerminator' is false, the terminator 2358 /// operation of the block is not printed. 2359 void print(Block *block, bool printBlockArgs = true, 2360 bool printBlockTerminator = true); 2361 2362 /// Print the ID of the given value, optionally with its result number. 2363 void printValueID(Value value, bool printResultNo = true, 2364 raw_ostream *streamOverride = nullptr) const; 2365 2366 //===--------------------------------------------------------------------===// 2367 // OpAsmPrinter methods 2368 //===--------------------------------------------------------------------===// 2369 2370 /// Print a newline and indent the printer to the start of the current 2371 /// operation. 2372 void printNewline() override { 2373 os << newLine; 2374 os.indent(currentIndent); 2375 } 2376 2377 /// Print a block argument in the usual format of: 2378 /// %ssaName : type {attr1=42} loc("here") 2379 /// where location printing is controlled by the standard internal option. 2380 /// You may pass omitType=true to not print a type, and pass an empty 2381 /// attribute list if you don't care for attributes. 2382 void printRegionArgument(BlockArgument arg, 2383 ArrayRef<NamedAttribute> argAttrs = {}, 2384 bool omitType = false) override; 2385 2386 /// Print the ID for the given value. 2387 void printOperand(Value value) override { printValueID(value); } 2388 void printOperand(Value value, raw_ostream &os) override { 2389 printValueID(value, /*printResultNo=*/true, &os); 2390 } 2391 2392 /// Print an optional attribute dictionary with a given set of elided values. 
2393 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2394 ArrayRef<StringRef> elidedAttrs = {}) override { 2395 Impl::printOptionalAttrDict(attrs, elidedAttrs); 2396 } 2397 void printOptionalAttrDictWithKeyword( 2398 ArrayRef<NamedAttribute> attrs, 2399 ArrayRef<StringRef> elidedAttrs = {}) override { 2400 Impl::printOptionalAttrDict(attrs, elidedAttrs, 2401 /*withKeyword=*/true); 2402 } 2403 2404 /// Print the given successor. 2405 void printSuccessor(Block *successor) override; 2406 2407 /// Print an operation successor with the operands used for the block 2408 /// arguments. 2409 void printSuccessorAndUseList(Block *successor, 2410 ValueRange succOperands) override; 2411 2412 /// Print the given region. 2413 void printRegion(Region ®ion, bool printEntryBlockArgs, 2414 bool printBlockTerminators, bool printEmptyBlock) override; 2415 2416 /// Renumber the arguments for the specified region to the same names as the 2417 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2418 /// operations. If any entry in namesToUse is null, the corresponding 2419 /// argument name is left alone. 2420 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2421 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2422 } 2423 2424 /// Print the given affine map with the symbol and dimension operands printed 2425 /// inline with the map. 2426 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2427 ValueRange operands) override; 2428 2429 /// Print the given affine expression with the symbol and dimension operands 2430 /// printed inline with the expression. 2431 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 2432 ValueRange symOperands) override; 2433 2434 private: 2435 // Contains the stack of default dialects to use when printing regions. 2436 // A new dialect is pushed to the stack before parsing regions nested under an 2437 // operation implementing `OpAsmOpInterface`, and popped when done. 
At the 2438 // top-level we start with "builtin" as the default, so that the top-level 2439 // `module` operation prints as-is. 2440 SmallVector<StringRef> defaultDialectStack{"builtin"}; 2441 2442 /// The number of spaces used for indenting nested operations. 2443 const static unsigned indentWidth = 2; 2444 2445 // This is the current indentation level for nested structures. 2446 unsigned currentIndent = 0; 2447 }; 2448 } // end anonymous namespace 2449 2450 void OperationPrinter::printTopLevelOperation(Operation *op) { 2451 // Output the aliases at the top level that can't be deferred. 2452 state->getAliasState().printNonDeferredAliases(os, newLine); 2453 2454 // Print the module. 2455 print(op); 2456 os << newLine; 2457 2458 // Output the aliases at the top level that can be deferred. 2459 state->getAliasState().printDeferredAliases(os, newLine); 2460 } 2461 2462 /// Print a block argument in the usual format of: 2463 /// %ssaName : type {attr1=42} loc("here") 2464 /// where location printing is controlled by the standard internal option. 2465 /// You may pass omitType=true to not print a type, and pass an empty 2466 /// attribute list if you don't care for attributes. 2467 void OperationPrinter::printRegionArgument(BlockArgument arg, 2468 ArrayRef<NamedAttribute> argAttrs, 2469 bool omitType) { 2470 printOperand(arg); 2471 if (!omitType) { 2472 os << ": "; 2473 printType(arg.getType()); 2474 } 2475 printOptionalAttrDict(argAttrs); 2476 // TODO: We should allow location aliases on block arguments. 2477 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2478 } 2479 2480 void OperationPrinter::print(Operation *op) { 2481 // Track the location of this operation. 
2482 state->registerOperationLocation(op, newLine.curLine, currentIndent); 2483 2484 os.indent(currentIndent); 2485 printOperation(op); 2486 printTrailingLocation(op->getLoc()); 2487 } 2488 2489 void OperationPrinter::printOperation(Operation *op) { 2490 if (size_t numResults = op->getNumResults()) { 2491 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 2492 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 2493 if (resultCount > 1) 2494 os << ':' << resultCount; 2495 }; 2496 2497 // Check to see if this operation has multiple result groups. 2498 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); 2499 if (!resultGroups.empty()) { 2500 // Interleave the groups excluding the last one, this one will be handled 2501 // separately. 2502 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 2503 printResultGroup(resultGroups[i], 2504 resultGroups[i + 1] - resultGroups[i]); 2505 }); 2506 os << ", "; 2507 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 2508 2509 } else { 2510 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 2511 } 2512 2513 os << " = "; 2514 } 2515 2516 // If requested, always print the generic form. 2517 if (!printerFlags.shouldPrintGenericOpForm()) { 2518 // Check to see if this is a known operation. If so, use the registered 2519 // custom printer hook. 2520 if (auto opInfo = op->getRegisteredInfo()) { 2521 opInfo->printAssembly(op, *this, defaultDialectStack.back()); 2522 return; 2523 } 2524 // Otherwise try to dispatch to the dialect, if available. 2525 if (Dialect *dialect = op->getDialect()) { 2526 if (auto opPrinter = dialect->getOperationPrinter(op)) { 2527 // Print the op name first. 2528 StringRef name = op->getName().getStringRef(); 2529 name.consume_front((defaultDialectStack.back() + ".").str()); 2530 printEscapedString(name, os); 2531 // Print the rest of the op now. 
2532 opPrinter(op, *this); 2533 return; 2534 } 2535 } 2536 } 2537 2538 // Otherwise print with the generic assembly form. 2539 printGenericOp(op, /*printOpName=*/true); 2540 } 2541 2542 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { 2543 if (printOpName) { 2544 os << '"'; 2545 printEscapedString(op->getName().getStringRef(), os); 2546 os << '"'; 2547 } 2548 os << '('; 2549 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 2550 os << ')'; 2551 2552 // For terminators, print the list of successors and their operands. 2553 if (op->getNumSuccessors() != 0) { 2554 os << '['; 2555 interleaveComma(op->getSuccessors(), 2556 [&](Block *successor) { printBlockName(successor); }); 2557 os << ']'; 2558 } 2559 2560 // Print regions. 2561 if (op->getNumRegions() != 0) { 2562 os << " ("; 2563 interleaveComma(op->getRegions(), [&](Region ®ion) { 2564 printRegion(region, /*printEntryBlockArgs=*/true, 2565 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 2566 }); 2567 os << ')'; 2568 } 2569 2570 auto attrs = op->getAttrs(); 2571 printOptionalAttrDict(attrs); 2572 2573 // Print the type signature of the operation. 2574 os << " : "; 2575 printFunctionalType(op); 2576 } 2577 2578 void OperationPrinter::printBlockName(Block *block) { 2579 auto id = state->getSSANameState().getBlockID(block); 2580 if (id != SSANameState::NameSentinel) 2581 os << "^bb" << id; 2582 else 2583 os << "^INVALIDBLOCK"; 2584 } 2585 2586 void OperationPrinter::print(Block *block, bool printBlockArgs, 2587 bool printBlockTerminator) { 2588 // Print the block label and argument list if requested. 2589 if (printBlockArgs) { 2590 os.indent(currentIndent); 2591 printBlockName(block); 2592 2593 // Print the argument list if non-empty. 
2594 if (!block->args_empty()) { 2595 os << '('; 2596 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 2597 printValueID(arg); 2598 os << ": "; 2599 printType(arg.getType()); 2600 // TODO: We should allow location aliases on block arguments. 2601 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2602 }); 2603 os << ')'; 2604 } 2605 os << ':'; 2606 2607 // Print out some context information about the predecessors of this block. 2608 if (!block->getParent()) { 2609 os << " // block is not in a region!"; 2610 } else if (block->hasNoPredecessors()) { 2611 os << " // no predecessors"; 2612 } else if (auto *pred = block->getSinglePredecessor()) { 2613 os << " // pred: "; 2614 printBlockName(pred); 2615 } else { 2616 // We want to print the predecessors in increasing numeric order, not in 2617 // whatever order the use-list is in, so gather and sort them. 2618 SmallVector<std::pair<unsigned, Block *>, 4> predIDs; 2619 for (auto *pred : block->getPredecessors()) 2620 predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); 2621 llvm::array_pod_sort(predIDs.begin(), predIDs.end()); 2622 2623 os << " // " << predIDs.size() << " preds: "; 2624 2625 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { 2626 printBlockName(pred.second); 2627 }); 2628 } 2629 os << newLine; 2630 } 2631 2632 currentIndent += indentWidth; 2633 bool hasTerminator = 2634 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 2635 auto range = llvm::make_range( 2636 block->begin(), 2637 std::prev(block->end(), 2638 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 2639 for (auto &op : range) { 2640 print(&op); 2641 os << newLine; 2642 } 2643 currentIndent -= indentWidth; 2644 } 2645 2646 void OperationPrinter::printValueID(Value value, bool printResultNo, 2647 raw_ostream *streamOverride) const { 2648 state->getSSANameState().printValueID(value, printResultNo, 2649 streamOverride ? 
*streamOverride : os); 2650 } 2651 2652 void OperationPrinter::printSuccessor(Block *successor) { 2653 printBlockName(successor); 2654 } 2655 2656 void OperationPrinter::printSuccessorAndUseList(Block *successor, 2657 ValueRange succOperands) { 2658 printBlockName(successor); 2659 if (succOperands.empty()) 2660 return; 2661 2662 os << '('; 2663 interleaveComma(succOperands, 2664 [this](Value operand) { printValueID(operand); }); 2665 os << " : "; 2666 interleaveComma(succOperands, 2667 [this](Value operand) { printType(operand.getType()); }); 2668 os << ')'; 2669 } 2670 2671 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 2672 bool printBlockTerminators, 2673 bool printEmptyBlock) { 2674 os << " {" << newLine; 2675 if (!region.empty()) { 2676 auto restoreDefaultDialect = 2677 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); 2678 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) 2679 defaultDialectStack.push_back(iface.getDefaultDialect()); 2680 else 2681 defaultDialectStack.push_back(""); 2682 2683 auto *entryBlock = ®ion.front(); 2684 // Force printing the block header if printEmptyBlock is set and the block 2685 // is empty or if printEntryBlockArgs is set and there are arguments to 2686 // print. 2687 bool shouldAlwaysPrintBlockHeader = 2688 (printEmptyBlock && entryBlock->empty()) || 2689 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 2690 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 2691 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 2692 print(&b); 2693 } 2694 os.indent(currentIndent) << "}"; 2695 } 2696 2697 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2698 ValueRange operands) { 2699 AffineMap map = mapAttr.getValue(); 2700 unsigned numDims = map.getNumDims(); 2701 auto printValueName = [&](unsigned pos, bool isSymbol) { 2702 unsigned index = isSymbol ? 
numDims + pos : pos; 2703 assert(index < operands.size()); 2704 if (isSymbol) 2705 os << "symbol("; 2706 printValueID(operands[index]); 2707 if (isSymbol) 2708 os << ')'; 2709 }; 2710 2711 interleaveComma(map.getResults(), [&](AffineExpr expr) { 2712 printAffineExpr(expr, printValueName); 2713 }); 2714 } 2715 2716 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 2717 ValueRange dimOperands, 2718 ValueRange symOperands) { 2719 auto printValueName = [&](unsigned pos, bool isSymbol) { 2720 if (!isSymbol) 2721 return printValueID(dimOperands[pos]); 2722 os << "symbol("; 2723 printValueID(symOperands[pos]); 2724 os << ')'; 2725 }; 2726 printAffineExpr(expr, printValueName); 2727 } 2728 2729 //===----------------------------------------------------------------------===// 2730 // print and dump methods 2731 //===----------------------------------------------------------------------===// 2732 2733 void Attribute::print(raw_ostream &os) const { 2734 AsmPrinter::Impl(os).printAttribute(*this); 2735 } 2736 2737 void Attribute::dump() const { 2738 print(llvm::errs()); 2739 llvm::errs() << "\n"; 2740 } 2741 2742 void Type::print(raw_ostream &os) const { 2743 AsmPrinter::Impl(os).printType(*this); 2744 } 2745 2746 void Type::dump() const { print(llvm::errs()); } 2747 2748 void AffineMap::dump() const { 2749 print(llvm::errs()); 2750 llvm::errs() << "\n"; 2751 } 2752 2753 void IntegerSet::dump() const { 2754 print(llvm::errs()); 2755 llvm::errs() << "\n"; 2756 } 2757 2758 void AffineExpr::print(raw_ostream &os) const { 2759 if (!expr) { 2760 os << "<<NULL AFFINE EXPR>>"; 2761 return; 2762 } 2763 AsmPrinter::Impl(os).printAffineExpr(*this); 2764 } 2765 2766 void AffineExpr::dump() const { 2767 print(llvm::errs()); 2768 llvm::errs() << "\n"; 2769 } 2770 2771 void AffineMap::print(raw_ostream &os) const { 2772 if (!map) { 2773 os << "<<NULL AFFINE MAP>>"; 2774 return; 2775 } 2776 AsmPrinter::Impl(os).printAffineMap(*this); 2777 } 2778 2779 void 
IntegerSet::print(raw_ostream &os) const { 2780 AsmPrinter::Impl(os).printIntegerSet(*this); 2781 } 2782 2783 void Value::print(raw_ostream &os) { 2784 if (auto *op = getDefiningOp()) 2785 return op->print(os); 2786 // TODO: Improve BlockArgument print'ing. 2787 BlockArgument arg = this->cast<BlockArgument>(); 2788 os << "<block argument> of type '" << arg.getType() 2789 << "' at index: " << arg.getArgNumber(); 2790 } 2791 void Value::print(raw_ostream &os, AsmState &state) { 2792 if (auto *op = getDefiningOp()) 2793 return op->print(os, state); 2794 2795 // TODO: Improve BlockArgument print'ing. 2796 BlockArgument arg = this->cast<BlockArgument>(); 2797 os << "<block argument> of type '" << arg.getType() 2798 << "' at index: " << arg.getArgNumber(); 2799 } 2800 2801 void Value::dump() { 2802 print(llvm::errs()); 2803 llvm::errs() << "\n"; 2804 } 2805 2806 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 2807 // TODO: This doesn't necessarily capture all potential cases. 2808 // Currently, region arguments can be shadowed when printing the main 2809 // operation. If the IR hasn't been printed, this will produce the old SSA 2810 // name and not the shadowed name. 2811 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 2812 os); 2813 } 2814 2815 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { 2816 // If this is a top level operation, we also print aliases. 2817 if (!getParent() && !printerFlags.shouldUseLocalScope()) { 2818 AsmState state(this, printerFlags); 2819 state.getImpl().initializeAliases(this); 2820 print(os, state, printerFlags); 2821 return; 2822 } 2823 2824 // Find the operation to number from based upon the provided flags. 2825 Operation *op = this; 2826 bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); 2827 do { 2828 // If we are printing local scope, stop at the first operation that is 2829 // isolated from above. 
2830 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 2831 break; 2832 2833 // Otherwise, traverse up to the next parent. 2834 Operation *parentOp = op->getParentOp(); 2835 if (!parentOp) 2836 break; 2837 op = parentOp; 2838 } while (true); 2839 2840 AsmState state(op, printerFlags); 2841 print(os, state, printerFlags); 2842 } 2843 void Operation::print(raw_ostream &os, AsmState &state, 2844 const OpPrintingFlags &flags) { 2845 OperationPrinter printer(os, flags, state.getImpl()); 2846 if (!getParent() && !flags.shouldUseLocalScope()) 2847 printer.printTopLevelOperation(this); 2848 else 2849 printer.print(this); 2850 } 2851 2852 void Operation::dump() { 2853 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 2854 llvm::errs() << "\n"; 2855 } 2856 2857 void Block::print(raw_ostream &os) { 2858 Operation *parentOp = getParentOp(); 2859 if (!parentOp) { 2860 os << "<<UNLINKED BLOCK>>\n"; 2861 return; 2862 } 2863 // Get the top-level op. 2864 while (auto *nextOp = parentOp->getParentOp()) 2865 parentOp = nextOp; 2866 2867 AsmState state(parentOp); 2868 print(os, state); 2869 } 2870 void Block::print(raw_ostream &os, AsmState &state) { 2871 OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); 2872 } 2873 2874 void Block::dump() { print(llvm::errs()); } 2875 2876 /// Print out the name of the block without printing its body. 2877 void Block::printAsOperand(raw_ostream &os, bool printType) { 2878 Operation *parentOp = getParentOp(); 2879 if (!parentOp) { 2880 os << "<<UNLINKED BLOCK>>\n"; 2881 return; 2882 } 2883 AsmState state(parentOp); 2884 printAsOperand(os, state); 2885 } 2886 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 2887 OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); 2888 printer.printBlockName(this); 2889 } 2890