1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/IR/SubElementInterfaces.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/ADT/MapVector.h" 31 #include "llvm/ADT/STLExtras.h" 32 #include "llvm/ADT/ScopeExit.h" 33 #include "llvm/ADT/ScopedHashTable.h" 34 #include "llvm/ADT/SetVector.h" 35 #include "llvm/ADT/SmallString.h" 36 #include "llvm/ADT/StringExtras.h" 37 #include "llvm/ADT/StringSet.h" 38 #include "llvm/ADT/TypeSwitch.h" 39 #include "llvm/Support/CommandLine.h" 40 #include "llvm/Support/Endian.h" 41 #include "llvm/Support/Regex.h" 42 #include "llvm/Support/SaveAndRestore.h" 43 44 #include <tuple> 45 46 using namespace mlir; 47 using namespace mlir::detail; 48 49 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 50 51 void OperationName::dump() const { print(llvm::errs()); } 52 53 //===--------------------------------------------------------------------===// 
54 // AsmParser 55 //===--------------------------------------------------------------------===// 56 57 AsmParser::~AsmParser() = default; 58 DialectAsmParser::~DialectAsmParser() = default; 59 OpAsmParser::~OpAsmParser() = default; 60 61 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } 62 63 //===----------------------------------------------------------------------===// 64 // DialectAsmPrinter 65 //===----------------------------------------------------------------------===// 66 67 DialectAsmPrinter::~DialectAsmPrinter() = default; 68 69 //===----------------------------------------------------------------------===// 70 // OpAsmPrinter 71 //===----------------------------------------------------------------------===// 72 73 OpAsmPrinter::~OpAsmPrinter() = default; 74 75 void OpAsmPrinter::printFunctionalType(Operation *op) { 76 auto &os = getStream(); 77 os << '('; 78 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 79 // Print the types of null values as <<NULL TYPE>>. 80 *this << (operand ? operand.getType() : Type()); 81 }); 82 os << ") -> "; 83 84 // Print the result list. We don't parenthesize single result types unless 85 // it is a function (avoiding a grammar ambiguity). 86 bool wrapped = op->getNumResults() != 1; 87 if (!wrapped && op->getResult(0).getType() && 88 op->getResult(0).getType().isa<FunctionType>()) 89 wrapped = true; 90 91 if (wrapped) 92 os << '('; 93 94 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 95 // Print the types of null values as <<NULL TYPE>>. 96 *this << (result ? result.getType() : Type()); 97 }); 98 99 if (wrapped) 100 os << ')'; 101 } 102 103 //===----------------------------------------------------------------------===// 104 // Operation OpAsm interface. 105 //===----------------------------------------------------------------------===// 106 107 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 
108 #include "mlir/IR/OpAsmInterface.cpp.inc" 109 110 //===----------------------------------------------------------------------===// 111 // OpPrintingFlags 112 //===----------------------------------------------------------------------===// 113 114 namespace { 115 /// This struct contains command line options that can be used to initialize 116 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 117 /// for global command line options. 118 struct AsmPrinterOptions { 119 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 120 "mlir-print-elementsattrs-with-hex-if-larger", 121 llvm::cl::desc( 122 "Print DenseElementsAttrs with a hex string that have " 123 "more elements than the given upper limit (use -1 to disable)")}; 124 125 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 126 "mlir-elide-elementsattrs-if-larger", 127 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 128 "more elements than the given upper limit")}; 129 130 llvm::cl::opt<bool> printDebugInfoOpt{ 131 "mlir-print-debuginfo", llvm::cl::init(false), 132 llvm::cl::desc("Print debug info in MLIR output")}; 133 134 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 135 "mlir-pretty-debuginfo", llvm::cl::init(false), 136 llvm::cl::desc("Print pretty debug info in MLIR output")}; 137 138 // Use the generic op output form in the operation printer even if the custom 139 // form is defined. 
140 llvm::cl::opt<bool> printGenericOpFormOpt{ 141 "mlir-print-op-generic", llvm::cl::init(false), 142 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 143 144 llvm::cl::opt<bool> printLocalScopeOpt{ 145 "mlir-print-local-scope", llvm::cl::init(false), 146 llvm::cl::desc("Print with local scope and inline information (eliding " 147 "aliases for attributes, types, and locations")}; 148 }; 149 } // namespace 150 151 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 152 153 /// Register a set of useful command-line options that can be used to configure 154 /// various flags within the AsmPrinter. 155 void mlir::registerAsmPrinterCLOptions() { 156 // Make sure that the options struct has been initialized. 157 *clOptions; 158 } 159 160 /// Initialize the printing flags with default supplied by the cl::opts above. 161 OpPrintingFlags::OpPrintingFlags() 162 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 163 printGenericOpFormFlag(false), printLocalScope(false) { 164 // Initialize based upon command line options, if they are available. 165 if (!clOptions.isConstructed()) 166 return; 167 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 168 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 169 printDebugInfoFlag = clOptions->printDebugInfoOpt; 170 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 171 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 172 printLocalScope = clOptions->printLocalScopeOpt; 173 } 174 175 /// Enable the elision of large elements attributes, by printing a '...' 176 /// instead of the element data, when the number of elements is greater than 177 /// `largeElementLimit`. Note: The IR generated with this option is not 178 /// parsable. 179 OpPrintingFlags & 180 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 181 elementsAttrElementLimit = largeElementLimit; 182 return *this; 183 } 184 185 /// Enable printing of debug information. 
If 'prettyForm' is set to true, 186 /// debug information is printed in a more readable 'pretty' form. 187 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 188 printDebugInfoFlag = true; 189 printDebugInfoPrettyFormFlag = prettyForm; 190 return *this; 191 } 192 193 /// Always print operations in the generic form. 194 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 195 printGenericOpFormFlag = true; 196 return *this; 197 } 198 199 /// Use local scope when printing the operation. This allows for using the 200 /// printer in a more localized and thread-safe setting, but may not necessarily 201 /// be identical of what the IR will look like when dumping the full module. 202 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 203 printLocalScope = true; 204 return *this; 205 } 206 207 /// Return if the given ElementsAttr should be elided. 208 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 209 return elementsAttrElementLimit.hasValue() && 210 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 211 !attr.isa<SplatElementsAttr>(); 212 } 213 214 /// Return the size limit for printing large ElementsAttr. 215 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 216 return elementsAttrElementLimit; 217 } 218 219 /// Return if debug information should be printed. 220 bool OpPrintingFlags::shouldPrintDebugInfo() const { 221 return printDebugInfoFlag; 222 } 223 224 /// Return if debug information should be printed in the pretty form. 225 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 226 return printDebugInfoPrettyFormFlag; 227 } 228 229 /// Return if operations should be printed in the generic form. 230 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 231 return printGenericOpFormFlag; 232 } 233 234 /// Return if the printer should use local scope when dumping the IR. 
235 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 236 237 /// Returns true if an ElementsAttr with the given number of elements should be 238 /// printed with hex. 239 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 240 // Check to see if a command line option was provided for the limit. 241 if (clOptions.isConstructed()) { 242 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 243 // -1 is used to disable hex printing. 244 if (clOptions->printElementsAttrWithHexIfLarger == -1) 245 return false; 246 return numElements > clOptions->printElementsAttrWithHexIfLarger; 247 } 248 } 249 250 // Otherwise, default to printing with hex if the number of elements is >100. 251 return numElements > 100; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // NewLineCounter 256 //===----------------------------------------------------------------------===// 257 258 namespace { 259 /// This class is a simple formatter that emits a new line when inputted into a 260 /// stream, that enables counting the number of newlines emitted. This class 261 /// should be used whenever emitting newlines in the printer. 262 struct NewLineCounter { 263 unsigned curLine = 1; 264 }; 265 266 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 267 ++newLine.curLine; 268 return os << '\n'; 269 } 270 } // namespace 271 272 //===----------------------------------------------------------------------===// 273 // AliasInitializer 274 //===----------------------------------------------------------------------===// 275 276 namespace { 277 /// This class represents a specific instance of a symbol Alias. 
class SymbolAlias {
public:
  /// Construct an alias printed as just `name` (no numeric suffix).
  SymbolAlias(StringRef name, bool isDeferrable)
      : name(name), suffixIndex(0), hasSuffixIndex(false),
        isDeferrable(isDeferrable) {}
  /// Construct an alias printed as `name` followed by `suffixIndex`. This is
  /// used when several symbols share the same alias name.
  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
      : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
        isDeferrable(isDeferrable) {}

  /// Print this alias to the given stream.
  void print(raw_ostream &os) const {
    os << name;
    if (hasSuffixIndex)
      os << suffixIndex;
  }

  /// Returns true if this alias supports deferred resolution when parsing.
  bool canBeDeferred() const { return isDeferrable; }

private:
  /// The main name of the alias.
  StringRef name;
  /// The optional suffix index of the alias, if multiple aliases had the same
  /// name.
  /// NOTE(review): this is a 30-bit field, so an index >= 2^30 would be
  /// silently truncated.
  uint32_t suffixIndex : 30;
  /// A flag indicating whether this alias has a suffix or not.
  bool hasSuffixIndex : 1;
  /// A flag indicating whether this alias may be deferred or not.
  bool isDeferrable : 1;
};

/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
class AliasInitializer {
public:
  /// `interfaces` are queried for alias names. Alias name strings are copied
  /// into `aliasAllocator`, so the returned StringRefs live as long as it does.
  AliasInitializer(
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
      llvm::BumpPtrAllocator &aliasAllocator)
      : interfaces(interfaces), aliasAllocator(aliasAllocator),
        aliasOS(aliasBuffer) {}

  /// Collect the attributes/types that printing `op` would emit and fill
  /// `attrToAlias` / `typeToAlias` with the aliases chosen for them.
  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);

  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
  /// set to true if the originator of this attribute can resolve the alias
  /// after parsing has completed (e.g. in the case of operation locations).
  void visit(Attribute attr, bool canBeDeferred = false);

  /// Visit the given type to see if it has an alias.
  void visit(Type type);

private:
  /// Try to generate an alias for the provided symbol. If an alias is
  /// generated, the provided alias mapping and reverse mapping are updated.
  /// Returns success if an alias was generated, failure otherwise.
  template <typename T>
  LogicalResult
  generateAlias(T symbol,
                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

  /// The set of asm interfaces within the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;

  /// Mapping between an alias and the set of symbols mapped to it.
  llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
  llvm::MapVector<StringRef, std::vector<Type>> aliasToType;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator &aliasAllocator;

  /// The set of visited attributes.
  DenseSet<Attribute> visitedAttributes;

  /// The set of attributes that have aliases *and* can be deferred.
  DenseSet<Attribute> deferrableAttributes;

  /// The set of visited types.
  DenseSet<Type> visitedTypes;

  /// Storage and stream used when generating an alias.
  /// `aliasOS` writes into `aliasBuffer`. The buffer must stay declared before
  /// the stream so it is constructed first (the constructor binds
  /// aliasOS(aliasBuffer)).
  SmallString<32> aliasBuffer;
  llvm::raw_svector_ostream aliasOS;
};

/// This class implements a dummy OpAsmPrinter that doesn't print any output,
/// and merely collects the attributes and types that *would* be printed in a
/// normal print invocation so that we can generate proper aliases. This allows
/// for us to generate aliases only for the attributes and types that would be
/// in the output, and trims down unnecessary output.
class DummyAliasOperationPrinter : private OpAsmPrinter {
public:
  /// `printerFlags` decides which form (custom/generic) is walked and whether
  /// locations are visited. Every attribute/type encountered is forwarded to
  /// `initializer`.
  explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
                                      AliasInitializer &initializer)
      : printerFlags(printerFlags), initializer(initializer) {}

  /// Print the given operation.
  void print(Operation *op) {
    // Visit the operation location.
    if (printerFlags.shouldPrintDebugInfo())
      initializer.visit(op->getLoc(), /*canBeDeferred=*/true);

    // If requested, always print the generic form.
    if (!printerFlags.shouldPrintGenericOpForm()) {
      // Check to see if this is a known operation. If so, use the registered
      // custom printer hook.
      if (auto opInfo = op->getRegisteredInfo()) {
        opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
        return;
      }
    }

    // Otherwise print with the generic assembly form.
    printGenericOp(op);
  }

private:
  /// Print the given operation in the generic form.
  /// (`printOpName` is unused here; no text is produced.)
  void printGenericOp(Operation *op, bool printOpName = true) override {
    // Consider nested operations for aliases.
    if (op->getNumRegions() != 0) {
      for (Region &region : op->getRegions())
        printRegion(region, /*printEntryBlockArgs=*/true,
                    /*printBlockTerminators=*/true);
    }

    // Visit all the types used in the operation.
    for (Type type : op->getOperandTypes())
      printType(type);
    for (Type type : op->getResultTypes())
      printType(type);

    // Consider the attributes of the operation for aliases.
    for (const NamedAttribute &attr : op->getAttrs())
      printAttribute(attr.getValue());
  }

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true) {
    // Consider the types of the block arguments for aliases if 'printBlockArgs'
    // is set to true.
    if (printBlockArgs) {
      for (BlockArgument arg : block->getArguments()) {
        printType(arg.getType());

        // Visit the argument location.
        if (printerFlags.shouldPrintDebugInfo())
          // TODO: Allow deferring argument locations.
          initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
      }
    }

    // Consider the operations within this block, ignoring the terminator if
    // requested.
    // The end iterator is pulled back by one only when the block really ends in
    // a terminator and terminators are being skipped.
    bool hasTerminator =
        !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
    auto range = llvm::make_range(
        block->begin(),
        std::prev(block->end(),
                  (!hasTerminator || printBlockTerminator) ? 0 : 1));
    for (Operation &op : range)
      print(&op);
  }

  /// Print the given region.
  /// Only the entry block honours the entry-args / terminator flags; later
  /// blocks are always walked with their arguments and terminators.
  /// `printEmptyBlock` is ignored.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators,
                   bool printEmptyBlock = false) override {
    if (region.empty())
      return;

    auto *entryBlock = &region.front();
    print(entryBlock, printEntryBlockArgs, printBlockTerminators);
    for (Block &b : llvm::drop_begin(region, 1))
      print(&b);
  }

  /// Visit the argument's type and, when debug info is printed, its location.
  /// NOTE(review): `argAttrs` is not visited here, so attributes attached to
  /// region arguments are not considered for aliases. Confirm this is intended.
  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
                           bool omitType) override {
    printType(arg.getType());
    // Visit the argument location.
    if (printerFlags.shouldPrintDebugInfo())
      // TODO: Allow deferring argument locations.
      initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
  }

  /// Consider the given type to be printed for an alias.
  void printType(Type type) override { initializer.visit(type); }

  /// Consider the given attribute to be printed for an alias.
  void printAttribute(Attribute attr) override { initializer.visit(attr); }
  void printAttributeWithoutType(Attribute attr) override {
    printAttribute(attr);
  }
  LogicalResult printAlias(Attribute attr) override {
    initializer.visit(attr);
    return success();
  }
  LogicalResult printAlias(Type type) override {
    initializer.visit(type);
    return success();
  }

  /// Print the given set of attributes with names not included within
  /// 'elidedAttrs'.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    if (attrs.empty())
      return;
    if (elidedAttrs.empty()) {
      for (const NamedAttribute &attr : attrs)
        printAttribute(attr.getValue());
      return;
    }
    llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
                                                  elidedAttrs.end());
    for (const NamedAttribute &attr : attrs)
      if (!elidedAttrsSet.contains(attr.getName().strref()))
        printAttribute(attr.getValue());
  }
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    printOptionalAttrDict(attrs, elidedAttrs);
  }

  /// Return a null stream as the output stream, this will ignore any data fed
  /// to it.
  raw_ostream &getStream() const override { return os; }

  /// The following are hooks of `OpAsmPrinter` that are not necessary for
  /// determining potential aliases.
  void printFloat(const APFloat &value) override {}
  void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
  void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
  void printNewline() override {}
  void printOperand(Value) override {}
  void printOperand(Value, raw_ostream &os) override {
    // Users expect the output string to have at least the prefixed % to signal
    // a value name. To maintain this invariant, emit a name even if it is
    // guaranteed to go unused.
    os << "%";
  }
  void printKeywordOrString(StringRef) override {}
  void printSymbolName(StringRef) override {}
  void printSuccessor(Block *) override {}
  void printSuccessorAndUseList(Block *, ValueRange) override {}
  void shadowRegionArgs(Region &, ValueRange) override {}

  /// The printer flags to use when determining potential aliases.
  const OpPrintingFlags &printerFlags;

  /// The initializer to use when identifying aliases.
  AliasInitializer &initializer;

  /// A dummy output stream.
  mutable llvm::raw_null_ostream os;
};
} // namespace

/// Sanitize the given name such that it can be used as a valid identifier. If
/// the string needs to be modified in any way, the provided buffer is used to
/// store the new copy.
///
/// Characters that are neither alphanumeric nor in `allowedPunctChars` are
/// rewritten: a space becomes '_', anything else becomes its hex code. When
/// the name is modified, the returned StringRef points into `buffer`, so it
/// is only valid while `buffer` is alive and unmodified.
/// NOTE(review): `isdigit` is called with a plain `char`; for negative char
/// values (non-ASCII bytes on signed-char targets) that is undefined behavior.
/// `llvm::isDigit` would be the safe choice.
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
                                    StringRef allowedPunctChars = "$._-",
                                    bool allowTrailingDigit = true) {
  assert(!name.empty() && "Shouldn't have an empty name here");

  auto copyNameToBuffer = [&] {
    for (char ch : name) {
      if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
        buffer.push_back(ch);
      else if (ch == ' ')
        buffer.push_back('_');
      else
        buffer.append(llvm::utohexstr((unsigned char)ch));
    }
  };

  // Check to see if this name is valid. If it starts with a digit, then it
  // could conflict with the autogenerated numeric ID's, so add an underscore
  // prefix to avoid problems.
  if (isdigit(name[0])) {
    buffer.push_back('_');
    copyNameToBuffer();
    return buffer;
  }

  // If the name ends with a trailing digit, add a '_' to avoid potential
  // conflicts with autogenerated ID's.
  if (!allowTrailingDigit && isdigit(name.back())) {
    copyNameToBuffer();
    buffer.push_back('_');
    return buffer;
  }

  // Check to see that the name consists of only valid identifier characters.
  for (char ch : name) {
    if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
      copyNameToBuffer();
      return buffer;
    }
  }

  // If there are no invalid characters, return the original name.
  return name;
}

/// Given a collection of aliases and symbols, initialize a mapping from a
/// symbol to a given alias.
///
/// Aliases are processed in name order. An alias name shared by several
/// symbols gets a numeric suffix (0, 1, ...) per symbol, in the order the
/// symbols were recorded. `aliasToSymbol` is emptied (via takeVector()). A
/// symbol is marked deferrable only if it is in `deferrableAliases`.
/// NOTE(review): llvm::array_pod_sort is documented for POD-like element
/// types that can be moved with memcpy. These elements contain std::vector,
/// so llvm::sort may be the safer choice.
template <typename T>
static void
initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
                  llvm::MapVector<T, SymbolAlias> &symbolToAlias,
                  DenseSet<T> *deferrableAliases = nullptr) {
  std::vector<std::pair<StringRef, std::vector<T>>> aliases =
      aliasToSymbol.takeVector();
  llvm::array_pod_sort(aliases.begin(), aliases.end(),
                       [](const auto *lhs, const auto *rhs) {
                         return lhs->first.compare(rhs->first);
                       });

  for (auto &it : aliases) {
    // If there is only one instance for this alias, use the name directly.
    if (it.second.size() == 1) {
      T symbol = it.second.front();
      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
      symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
      continue;
    }
    // Otherwise, add the index to the name.
    for (int i = 0, e = it.second.size(); i < e; ++i) {
      T symbol = it.second[i];
      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
      symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
    }
  }
}

void AliasInitializer::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
    llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
  // Use a dummy printer when walking the IR so that we can collect the
  // attributes/types that will actually be used during printing when
  // considering aliases.
  DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
  aliasPrinter.print(op);

  // Initialize the aliases sorted by name.
  // Type aliases are never marked deferrable (no deferrable set is passed).
  initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
  initializeAliases(aliasToType, typeToAlias);
}

/// Record `attr` (once) and try to alias it. If no alias is produced, the
/// attribute's sub-elements are visited instead. An aliased attribute stays
/// deferrable only if every visit of it allowed deferral.
void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
  if (!visitedAttributes.insert(attr).second) {
    // If this attribute already has an alias and this instance can't be
    // deferred, make sure that the alias isn't deferred.
    if (!canBeDeferred)
      deferrableAttributes.erase(attr);
    return;
  }

  // Try to generate an alias for this attribute.
  if (succeeded(generateAlias(attr, aliasToAttr))) {
    if (canBeDeferred)
      deferrableAttributes.insert(attr);
    return;
  }

  // Check for any sub elements.
  if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
                                        [&](Type type) { visit(type); });
  }
}

/// Record `type` (once) and try to alias it. If no alias is produced, its
/// sub-elements are visited instead.
void AliasInitializer::visit(Type type) {
  if (!visitedTypes.insert(type).second)
    return;

  // Try to generate an alias for this type.
  if (succeeded(generateAlias(type, aliasToType)))
    return;

  // Check for any sub elements.
  if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
    subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
                                        [&](Type type) { visit(type); });
  }
}

/// Ask each dialect interface, in order, for an alias of `symbol`.
/// - Every interface that does not return NoAlias overwrites the candidate
///   name.
/// - A FinalAlias result stops the search.
///
/// The chosen name is sanitized (punctuation limited to "$_-", no trailing
/// digit), copied into `aliasAllocator`, and recorded in `aliasToSymbol`.
/// NOTE(review): text an interface writes to `aliasOS` before returning
/// NoAlias is not cleared here, so it would remain in `aliasBuffer` for the
/// next interface or symbol.
template <typename T>
LogicalResult AliasInitializer::generateAlias(
    T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
  SmallString<32> nameBuffer;
  for (const auto &interface : interfaces) {
    OpAsmDialectInterface::AliasResult result =
        interface.getAlias(symbol, aliasOS);
    if (result == OpAsmDialectInterface::AliasResult::NoAlias)
      continue;
    nameBuffer = std::move(aliasBuffer);
    assert(!nameBuffer.empty() && "expected valid alias name");
    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
      break;
  }

  if (nameBuffer.empty())
    return failure();

  SmallString<16> tempBuffer;
  StringRef name =
      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
                         /*allowTrailingDigit=*/false);
  name = name.copy(aliasAllocator);
  aliasToSymbol[name].push_back(symbol);
  return success();
}

//===----------------------------------------------------------------------===//
// AliasState
//===----------------------------------------------------------------------===//

namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
  // Initialize the internal aliases.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias.
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // namespace

void AliasState::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  AliasInitializer initializer(interfaces, aliasAllocator);
  initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
}

LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
  auto it = attrToAlias.find(attr);
  if (it == attrToAlias.end())
    return failure();
  it->second.print(os << '#');
  return success();
}

LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
  auto it = typeToAlias.find(ty);
  if (it == typeToAlias.end())
    return failure();

  it->second.print(os << '!');
  return success();
}

/// Emit alias definitions whose deferral flag matches `isDeferred`, one per
/// line:
/// - attributes as `#name = <attr>`;
/// - types as `!name = type <type>`.
void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
                              bool isDeferred) const {
  auto filterFn = [=](const auto &aliasIt) {
    return aliasIt.second.canBeDeferred() == isDeferred;
  };
  for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
    it.second.print(os << '#');
    os << " = " << it.first << newLine;
  }
  for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
    it.second.print(os << '!');
    os << " = type " << it.first << newLine;
  }
}

//===----------------------------------------------------------------------===//
// SSANameState
//===----------------------------------------------------------------------===//

namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
struct BlockInfo {
  int ordering;
  StringRef name;
};

/// This class manages the state of SSA value names.
class SSANameState {
public:
  /// A sentinel value used for values with names set.
  enum : unsigned { NameSentinel = ~0U };

  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, eg., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
} // namespace

/// Number every SSA value reachable from `op`.
///
/// Regions are processed with an explicit work-list rather than recursion.
/// Each entry carries the counters and the `usedNames` scope of its parent, so
/// sibling regions restart numbering from the same point. The SaveAndRestore
/// guards put the three counters back to their entry values when the
/// constructor returns.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy
  // Scopes are placement-new'd here and destroyed by explicit destructor calls
  // below. Destroying the current scope pops it from `usedNames`, which is
  // what advances the `getCurScope()` loops.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes(needless)
    // until the parent scope.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}

/// Write `%<id>` or `%<name>` for `value`, plus `#<n>` when `printResultNo` is
/// set and the value is a non-head member of a result group. Null and unknown
/// values print as placeholder markers instead.
void SSANameState::printValueID(Value value, bool printResultNo,
                                raw_ostream &stream) const {
  if (!value) {
    stream << "<<NULL VALUE>>";
    return;
  }

  Optional<int> resultNo;
  auto lookupValue = value;

  // If this is an operation result, collect the head lookup value of the result
  // group and the result number of 'result' within that group.
  if (OpResult result = value.dyn_cast<OpResult>())
    getResultIDAndNumber(result, lookupValue, resultNo);

  auto it = valueIDs.find(lookupValue);
  if (it == valueIDs.end()) {
    stream << "<<UNKNOWN SSA VALUE>>";
    return;
  }

  stream << '%';
  if (it->second != NameSentinel) {
    stream << it->second;
  } else {
    auto nameIt = valueNames.find(lookupValue);
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
    stream << nameIt->second;
  }

  if (resultNo.hasValue() && printResultNo)
    stream << '#' << resultNo;
}

ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
  auto it = opResultGroups.find(op);
  return it == opResultGroups.end() ?
ArrayRef<int>() : it->second; 976 } 977 978 BlockInfo SSANameState::getBlockInfo(Block *block) { 979 auto it = blockNames.find(block); 980 BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; 981 return it != blockNames.end() ? it->second : invalidBlock; 982 } 983 984 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 985 assert(!region.empty() && "cannot shadow arguments of an empty region"); 986 assert(region.getNumArguments() == namesToUse.size() && 987 "incorrect number of names passed in"); 988 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 989 "only KnownIsolatedFromAbove ops can shadow names"); 990 991 SmallVector<char, 16> nameStr; 992 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 993 auto nameToUse = namesToUse[i]; 994 if (nameToUse == nullptr) 995 continue; 996 auto nameToReplace = region.getArgument(i); 997 998 nameStr.clear(); 999 llvm::raw_svector_ostream nameStream(nameStr); 1000 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 1001 1002 // Entry block arguments should already have a pretty "arg" name. 1003 assert(valueIDs[nameToReplace] == NameSentinel); 1004 1005 // Use the name without the leading %. 1006 auto name = StringRef(nameStream.str()).drop_front(); 1007 1008 // Overwrite the name. 
1009 valueNames[nameToReplace] = name.copy(usedNameAllocator); 1010 } 1011 } 1012 1013 void SSANameState::numberValuesInRegion(Region ®ion) { 1014 auto setBlockArgNameFn = [&](Value arg, StringRef name) { 1015 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 1016 assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion && 1017 "arg not defined in current region"); 1018 setValueName(arg, name); 1019 }; 1020 1021 if (!printerFlags.shouldPrintGenericOpForm()) { 1022 if (Operation *op = region.getParentOp()) { 1023 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) 1024 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); 1025 } 1026 } 1027 1028 // Number the values within this region in a breadth-first order. 1029 unsigned nextBlockID = 0; 1030 for (auto &block : region) { 1031 // Each block gets a unique ID, and all of the operations within it get 1032 // numbered as well. 1033 auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); 1034 if (blockInfoIt.second) { 1035 // This block hasn't been named through `getAsmBlockArgumentNames`, use 1036 // default `^bbNNN` format. 1037 std::string name; 1038 llvm::raw_string_ostream(name) << "^bb" << nextBlockID; 1039 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); 1040 } 1041 blockInfoIt.first->second.ordering = nextBlockID++; 1042 1043 numberValuesInBlock(block); 1044 } 1045 } 1046 1047 void SSANameState::numberValuesInBlock(Block &block) { 1048 // Number the block arguments. We give entry block arguments a special name 1049 // 'arg'. 1050 bool isEntryBlock = block.isEntryBlock(); 1051 SmallString<32> specialNameBuffer(isEntryBlock ? 
"arg" : ""); 1052 llvm::raw_svector_ostream specialName(specialNameBuffer); 1053 for (auto arg : block.getArguments()) { 1054 if (valueIDs.count(arg)) 1055 continue; 1056 if (isEntryBlock) { 1057 specialNameBuffer.resize(strlen("arg")); 1058 specialName << nextArgumentID++; 1059 } 1060 setValueName(arg, specialName.str()); 1061 } 1062 1063 // Number the operations in this block. 1064 for (auto &op : block) 1065 numberValuesInOp(op); 1066 } 1067 1068 void SSANameState::numberValuesInOp(Operation &op) { 1069 // Function used to set the special result names for the operation. 1070 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 1071 auto setResultNameFn = [&](Value result, StringRef name) { 1072 assert(!valueIDs.count(result) && "result numbered multiple times"); 1073 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 1074 setValueName(result, name); 1075 1076 // Record the result number for groups not anchored at 0. 1077 if (int resultNo = result.cast<OpResult>().getResultNumber()) 1078 resultGroups.push_back(resultNo); 1079 }; 1080 // Operations can customize the printing of block names in OpAsmOpInterface. 
1081 auto setBlockNameFn = [&](Block *block, StringRef name) { 1082 assert(block->getParentOp() == &op && 1083 "getAsmBlockArgumentNames callback invoked on a block not directly " 1084 "nested under the current operation"); 1085 assert(!blockNames.count(block) && "block numbered multiple times"); 1086 SmallString<16> tmpBuffer{"^"}; 1087 name = sanitizeIdentifier(name, tmpBuffer); 1088 if (name.data() != tmpBuffer.data()) { 1089 tmpBuffer.append(name); 1090 name = tmpBuffer.str(); 1091 } 1092 name = name.copy(usedNameAllocator); 1093 blockNames[block] = {-1, name}; 1094 }; 1095 1096 if (!printerFlags.shouldPrintGenericOpForm()) { 1097 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { 1098 asmInterface.getAsmBlockNames(setBlockNameFn); 1099 asmInterface.getAsmResultNames(setResultNameFn); 1100 } 1101 } 1102 1103 unsigned numResults = op.getNumResults(); 1104 if (numResults == 0) 1105 return; 1106 Value resultBegin = op.getResult(0); 1107 1108 // If the first result wasn't numbered, give it a default number. 1109 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1110 ++nextValueID; 1111 1112 // If this operation has multiple result groups, mark it. 1113 if (resultGroups.size() != 1) { 1114 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1115 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1116 } 1117 } 1118 1119 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1120 Optional<int> &lookupResultNo) const { 1121 Operation *owner = result.getOwner(); 1122 if (owner->getNumResults() == 1) 1123 return; 1124 int resultNo = result.getResultNumber(); 1125 1126 // If this operation has multiple result groups, we will need to find the 1127 // one corresponding to this result. 1128 auto resultGroupIt = opResultGroups.find(owner); 1129 if (resultGroupIt == opResultGroups.end()) { 1130 // If not, just use the first result. 
1131 lookupResultNo = resultNo; 1132 lookupValue = owner->getResult(0); 1133 return; 1134 } 1135 1136 // Find the correct index using a binary search, as the groups are ordered. 1137 ArrayRef<int> resultGroups = resultGroupIt->second; 1138 const auto *it = llvm::upper_bound(resultGroups, resultNo); 1139 int groupResultNo = 0, groupSize = 0; 1140 1141 // If there are no smaller elements, the last result group is the lookup. 1142 if (it == resultGroups.end()) { 1143 groupResultNo = resultGroups.back(); 1144 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1145 } else { 1146 // Otherwise, the previous element is the lookup. 1147 groupResultNo = *std::prev(it); 1148 groupSize = *it - groupResultNo; 1149 } 1150 1151 // We only record the result number for a group of size greater than 1. 1152 if (groupSize != 1) 1153 lookupResultNo = resultNo - groupResultNo; 1154 lookupValue = owner->getResult(groupResultNo); 1155 } 1156 1157 void SSANameState::setValueName(Value value, StringRef name) { 1158 // If the name is empty, the value uses the default numbering. 1159 if (name.empty()) { 1160 valueIDs[value] = nextValueID++; 1161 return; 1162 } 1163 1164 valueIDs[value] = NameSentinel; 1165 valueNames[value] = uniqueValueName(name); 1166 } 1167 1168 StringRef SSANameState::uniqueValueName(StringRef name) { 1169 SmallString<16> tmpBuffer; 1170 name = sanitizeIdentifier(name, tmpBuffer); 1171 1172 // Check to see if this name is already unique. 1173 if (!usedNames.count(name)) { 1174 name = name.copy(usedNameAllocator); 1175 } else { 1176 // Otherwise, we had a conflict - probe until we find a unique name. This 1177 // is guaranteed to terminate (and usually in a single iteration) because it 1178 // generates new names by incrementing nextConflictID. 
1179 SmallString<64> probeName(name); 1180 probeName.push_back('_'); 1181 while (true) { 1182 probeName += llvm::utostr(nextConflictID++); 1183 if (!usedNames.count(probeName)) { 1184 name = probeName.str().copy(usedNameAllocator); 1185 break; 1186 } 1187 probeName.resize(name.size() + 1); 1188 } 1189 } 1190 1191 usedNames.insert(name, char()); 1192 return name; 1193 } 1194 1195 //===----------------------------------------------------------------------===// 1196 // AsmState 1197 //===----------------------------------------------------------------------===// 1198 1199 namespace mlir { 1200 namespace detail { 1201 class AsmStateImpl { 1202 public: 1203 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, 1204 AsmState::LocationMap *locationMap) 1205 : interfaces(op->getContext()), nameState(op, printerFlags), 1206 printerFlags(printerFlags), locationMap(locationMap) {} 1207 1208 /// Initialize the alias state to enable the printing of aliases. 1209 void initializeAliases(Operation *op) { 1210 aliasState.initialize(op, printerFlags, interfaces); 1211 } 1212 1213 /// Get the state used for aliases. 1214 AliasState &getAliasState() { return aliasState; } 1215 1216 /// Get the state used for SSA names. 1217 SSANameState &getSSANameState() { return nameState; } 1218 1219 /// Get the printer flags. 1220 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } 1221 1222 /// Register the location, line and column, within the buffer that the given 1223 /// operation was printed at. 1224 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1225 if (locationMap) 1226 (*locationMap)[op] = std::make_pair(line, col); 1227 } 1228 1229 private: 1230 /// Collection of OpAsm interfaces implemented in the context. 1231 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1232 1233 /// The state used for attribute and type aliases. 1234 AliasState aliasState; 1235 1236 /// The state used for SSA value names. 
1237 SSANameState nameState; 1238 1239 /// Flags that control op output. 1240 OpPrintingFlags printerFlags; 1241 1242 /// An optional location map to be populated. 1243 AsmState::LocationMap *locationMap; 1244 }; 1245 } // namespace detail 1246 } // namespace mlir 1247 1248 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, 1249 LocationMap *locationMap) 1250 : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {} 1251 AsmState::~AsmState() = default; 1252 1253 const OpPrintingFlags &AsmState::getPrinterFlags() const { 1254 return impl->getPrinterFlags(); 1255 } 1256 1257 //===----------------------------------------------------------------------===// 1258 // AsmPrinter::Impl 1259 //===----------------------------------------------------------------------===// 1260 1261 namespace mlir { 1262 class AsmPrinter::Impl { 1263 public: 1264 Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1265 AsmStateImpl *state = nullptr) 1266 : os(os), printerFlags(flags), state(state) {} 1267 explicit Impl(Impl &other) 1268 : Impl(other.os, other.printerFlags, other.state) {} 1269 1270 /// Returns the output stream of the printer. 1271 raw_ostream &getStream() { return os; } 1272 1273 template <typename Container, typename UnaryFunctor> 1274 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { 1275 llvm::interleaveComma(c, os, eachFn); 1276 } 1277 1278 /// This enum describes the different kinds of elision for the type of an 1279 /// attribute when printing it. 1280 enum class AttrTypeElision { 1281 /// The type must not be elided, 1282 Never, 1283 /// The type may be elided when it matches the default used in the parser 1284 /// (for example i64 is the default for integer attributes). 1285 May, 1286 /// The type must be elided. 1287 Must 1288 }; 1289 1290 /// Print the given attribute. 
1291 void printAttribute(Attribute attr, 1292 AttrTypeElision typeElision = AttrTypeElision::Never); 1293 1294 /// Print the alias for the given attribute, return failure if no alias could 1295 /// be printed. 1296 LogicalResult printAlias(Attribute attr); 1297 1298 void printType(Type type); 1299 1300 /// Print the alias for the given type, return failure if no alias could 1301 /// be printed. 1302 LogicalResult printAlias(Type type); 1303 1304 /// Print the given location to the stream. If `allowAlias` is true, this 1305 /// allows for the internal location to use an attribute alias. 1306 void printLocation(LocationAttr loc, bool allowAlias = false); 1307 1308 void printAffineMap(AffineMap map); 1309 void 1310 printAffineExpr(AffineExpr expr, 1311 function_ref<void(unsigned, bool)> printValueName = nullptr); 1312 void printAffineConstraint(AffineExpr expr, bool isEq); 1313 void printIntegerSet(IntegerSet set); 1314 1315 protected: 1316 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1317 ArrayRef<StringRef> elidedAttrs = {}, 1318 bool withKeyword = false); 1319 void printNamedAttribute(NamedAttribute attr); 1320 void printTrailingLocation(Location loc, bool allowAlias = true); 1321 void printLocationInternal(LocationAttr loc, bool pretty = false); 1322 1323 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1324 /// used instead of individual elements when the elements attr is large. 1325 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1326 1327 /// Print a dense string elements attribute. 1328 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1329 1330 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1331 /// used instead of individual elements when the elements attr is large. 
1332 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1333 bool allowHex); 1334 1335 void printDialectAttribute(Attribute attr); 1336 void printDialectType(Type type); 1337 1338 /// This enum is used to represent the binding strength of the enclosing 1339 /// context that an AffineExprStorage is being printed in, so we can 1340 /// intelligently produce parens. 1341 enum class BindingStrength { 1342 Weak, // + and - 1343 Strong, // All other binary operators. 1344 }; 1345 void printAffineExprInternal( 1346 AffineExpr expr, BindingStrength enclosingTightness, 1347 function_ref<void(unsigned, bool)> printValueName = nullptr); 1348 1349 /// The output stream for the printer. 1350 raw_ostream &os; 1351 1352 /// A set of flags to control the printer's behavior. 1353 OpPrintingFlags printerFlags; 1354 1355 /// An optional printer state for the module. 1356 AsmStateImpl *state; 1357 1358 /// A tracker for the number of new lines emitted during printing. 1359 NewLineCounter newLine; 1360 }; 1361 } // namespace mlir 1362 1363 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { 1364 // Check to see if we are printing debug information. 
1365 if (!printerFlags.shouldPrintDebugInfo()) 1366 return; 1367 1368 os << " "; 1369 printLocation(loc, /*allowAlias=*/allowAlias); 1370 } 1371 1372 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) { 1373 TypeSwitch<LocationAttr>(loc) 1374 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1375 printLocationInternal(loc.getFallbackLocation(), pretty); 1376 }) 1377 .Case<UnknownLoc>([&](UnknownLoc loc) { 1378 if (pretty) 1379 os << "[unknown]"; 1380 else 1381 os << "unknown"; 1382 }) 1383 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1384 if (pretty) { 1385 os << loc.getFilename().getValue(); 1386 } else { 1387 os << "\""; 1388 printEscapedString(loc.getFilename(), os); 1389 os << "\""; 1390 } 1391 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1392 }) 1393 .Case<NameLoc>([&](NameLoc loc) { 1394 os << '\"'; 1395 printEscapedString(loc.getName(), os); 1396 os << '\"'; 1397 1398 // Print the child if it isn't unknown. 1399 auto childLoc = loc.getChildLoc(); 1400 if (!childLoc.isa<UnknownLoc>()) { 1401 os << '('; 1402 printLocationInternal(childLoc, pretty); 1403 os << ')'; 1404 } 1405 }) 1406 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1407 Location caller = loc.getCaller(); 1408 Location callee = loc.getCallee(); 1409 if (!pretty) 1410 os << "callsite("; 1411 printLocationInternal(callee, pretty); 1412 if (pretty) { 1413 if (callee.isa<NameLoc>()) { 1414 if (caller.isa<FileLineColLoc>()) { 1415 os << " at "; 1416 } else { 1417 os << newLine << " at "; 1418 } 1419 } else { 1420 os << newLine << " at "; 1421 } 1422 } else { 1423 os << " at "; 1424 } 1425 printLocationInternal(caller, pretty); 1426 if (!pretty) 1427 os << ")"; 1428 }) 1429 .Case<FusedLoc>([&](FusedLoc loc) { 1430 if (!pretty) 1431 os << "fused"; 1432 if (Attribute metadata = loc.getMetadata()) 1433 os << '<' << metadata << '>'; 1434 os << '['; 1435 interleave( 1436 loc.getLocations(), 1437 [&](Location loc) { printLocationInternal(loc, pretty); }, 1438 [&]() { os << ", "; }); 1439 
os << ']'; 1440 }); 1441 } 1442 1443 /// Print a floating point value in a way that the parser will be able to 1444 /// round-trip losslessly. 1445 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1446 // We would like to output the FP constant value in exponential notation, 1447 // but we cannot do this if doing so will lose precision. Check here to 1448 // make sure that we only output it in exponential format if we can parse 1449 // the value back and get the same value. 1450 bool isInf = apValue.isInfinity(); 1451 bool isNaN = apValue.isNaN(); 1452 if (!isInf && !isNaN) { 1453 SmallString<128> strValue; 1454 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1455 /*TruncateZero=*/false); 1456 1457 // Check to make sure that the stringized number is not some string like 1458 // "Inf" or NaN, that atof will accept, but the lexer will not. Check 1459 // that the string matches the "[-+]?[0-9]" regex. 1460 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1461 ((strValue[0] == '-' || strValue[0] == '+') && 1462 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1463 "[-+]?[0-9] regex does not match!"); 1464 1465 // Parse back the stringized version and check that the value is equal 1466 // (i.e., there is no precision loss). 1467 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1468 os << strValue; 1469 return; 1470 } 1471 1472 // If it is not, use the default format of APFloat instead of the 1473 // exponential notation. 1474 strValue.clear(); 1475 apValue.toString(strValue); 1476 1477 // Make sure that we can parse the default form as a float. 1478 if (strValue.str().contains('.')) { 1479 os << strValue; 1480 return; 1481 } 1482 } 1483 1484 // Print special values in hexadecimal format. The sign bit should be included 1485 // in the literal. 
1486 SmallVector<char, 16> str; 1487 APInt apInt = apValue.bitcastToAPInt(); 1488 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1489 /*formatAsCLiteral=*/true); 1490 os << str; 1491 } 1492 1493 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { 1494 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1495 return printLocationInternal(loc, /*pretty=*/true); 1496 1497 os << "loc("; 1498 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1499 printLocationInternal(loc); 1500 os << ')'; 1501 } 1502 1503 /// Returns true if the given dialect symbol data is simple enough to print in 1504 /// the pretty form, i.e. without the enclosing "". 1505 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1506 // The name must start with an identifier. 1507 if (symName.empty() || !isalpha(symName.front())) 1508 return false; 1509 1510 // Ignore all the characters that are valid in an identifier in the symbol 1511 // name. 1512 symName = symName.drop_while( 1513 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1514 if (symName.empty()) 1515 return true; 1516 1517 // If we got to an unexpected character, then it must be a <>. Check those 1518 // recursively. 1519 if (symName.front() != '<' || symName.back() != '>') 1520 return false; 1521 1522 SmallVector<char, 8> nestedPunctuation; 1523 do { 1524 // If we ran out of characters, then we had a punctuation mismatch. 1525 if (symName.empty()) 1526 return false; 1527 1528 auto c = symName.front(); 1529 symName = symName.drop_front(); 1530 1531 switch (c) { 1532 // We never allow null characters. This is an EOF indicator for the lexer 1533 // which we could handle, but isn't important for any known dialect. 1534 case '\0': 1535 return false; 1536 case '<': 1537 case '[': 1538 case '(': 1539 case '{': 1540 nestedPunctuation.push_back(c); 1541 continue; 1542 case '-': 1543 // Treat `->` as a special token. 
1544 if (!symName.empty() && symName.front() == '>') { 1545 symName = symName.drop_front(); 1546 continue; 1547 } 1548 break; 1549 // Reject types with mismatched brackets. 1550 case '>': 1551 if (nestedPunctuation.pop_back_val() != '<') 1552 return false; 1553 break; 1554 case ']': 1555 if (nestedPunctuation.pop_back_val() != '[') 1556 return false; 1557 break; 1558 case ')': 1559 if (nestedPunctuation.pop_back_val() != '(') 1560 return false; 1561 break; 1562 case '}': 1563 if (nestedPunctuation.pop_back_val() != '{') 1564 return false; 1565 break; 1566 default: 1567 continue; 1568 } 1569 1570 // We're done when the punctuation is fully matched. 1571 } while (!nestedPunctuation.empty()); 1572 1573 // If there were extra characters, then we failed. 1574 return symName.empty(); 1575 } 1576 1577 /// Print the given dialect symbol to the stream. 1578 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 1579 StringRef dialectName, StringRef symString) { 1580 os << symPrefix << dialectName; 1581 1582 // If this symbol name is simple enough, print it directly in pretty form, 1583 // otherwise, we print it as an escaped string. 1584 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 1585 os << '.' << symString; 1586 return; 1587 } 1588 1589 os << "<\""; 1590 llvm::printEscapedString(symString, os); 1591 os << "\">"; 1592 } 1593 1594 /// Returns true if the given string can be represented as a bare identifier. 1595 static bool isBareIdentifier(StringRef name) { 1596 // By making this unsigned, the value passed in to isalnum will always be 1597 // in the range 0-255. This is important when building with MSVC because 1598 // its implementation will assert. This situation can arise when dealing 1599 // with UTF-8 multibyte characters. 
1600 if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) 1601 return false; 1602 return llvm::all_of(name.drop_front(), [](unsigned char c) { 1603 return isalnum(c) || c == '_' || c == '$' || c == '.'; 1604 }); 1605 } 1606 1607 /// Print the given string as a keyword, or a quoted and escaped string if it 1608 /// has any special or non-printable characters in it. 1609 static void printKeywordOrString(StringRef keyword, raw_ostream &os) { 1610 // If it can be represented as a bare identifier, write it directly. 1611 if (isBareIdentifier(keyword)) { 1612 os << keyword; 1613 return; 1614 } 1615 1616 // Otherwise, output the keyword wrapped in quotes with proper escaping. 1617 os << "\""; 1618 printEscapedString(keyword, os); 1619 os << '"'; 1620 } 1621 1622 /// Print the given string as a symbol reference. A symbol reference is 1623 /// represented as a string prefixed with '@'. The reference is surrounded with 1624 /// ""'s and escaped if it has any special or non-printable characters in it. 1625 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { 1626 assert(!symbolRef.empty() && "expected valid symbol reference"); 1627 os << '@'; 1628 printKeywordOrString(symbolRef, os); 1629 } 1630 1631 // Print out a valid ElementsAttr that is succinct and can represent any 1632 // potential shape/type, for use when eliding a large ElementsAttr. 1633 // 1634 // We choose to use an opaque ElementsAttr literal with conspicuous content to 1635 // hopefully alert readers to the fact that this has been elided. 1636 // 1637 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will 1638 // accept the string "elided". The first string must be a registered dialect 1639 // name and the latter must be a hex constant. 
1640 static void printElidedElementsAttr(raw_ostream &os) { 1641 os << R"(opaque<"elided_large_const", "0xDEADBEEF">)"; 1642 } 1643 1644 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { 1645 return success(state && succeeded(state->getAliasState().getAlias(attr, os))); 1646 } 1647 1648 LogicalResult AsmPrinter::Impl::printAlias(Type type) { 1649 return success(state && succeeded(state->getAliasState().getAlias(type, os))); 1650 } 1651 1652 void AsmPrinter::Impl::printAttribute(Attribute attr, 1653 AttrTypeElision typeElision) { 1654 if (!attr) { 1655 os << "<<NULL ATTRIBUTE>>"; 1656 return; 1657 } 1658 1659 // Try to print an alias for this attribute. 1660 if (succeeded(printAlias(attr))) 1661 return; 1662 1663 if (!isa<BuiltinDialect>(attr.getDialect())) 1664 return printDialectAttribute(attr); 1665 1666 auto attrType = attr.getType(); 1667 if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) { 1668 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), 1669 opaqueAttr.getAttrData()); 1670 } else if (attr.isa<UnitAttr>()) { 1671 os << "unit"; 1672 return; 1673 } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) { 1674 os << '{'; 1675 interleaveComma(dictAttr.getValue(), 1676 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 1677 os << '}'; 1678 1679 } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) { 1680 if (attrType.isSignlessInteger(1)) { 1681 os << (intAttr.getValue().getBoolValue() ? "true" : "false"); 1682 1683 // Boolean integer attributes always elides the type. 1684 return; 1685 } 1686 1687 // Only print attributes as unsigned if they are explicitly unsigned or are 1688 // signless 1-bit values. Indexes, signed values, and multi-bit signless 1689 // values print as signed. 1690 bool isUnsigned = 1691 attrType.isUnsignedInteger() || attrType.isSignlessInteger(1); 1692 intAttr.getValue().print(os, !isUnsigned); 1693 1694 // IntegerAttr elides the type if I64. 
1695 if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64)) 1696 return; 1697 1698 } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) { 1699 printFloatValue(floatAttr.getValue(), os); 1700 1701 // FloatAttr elides the type if F64. 1702 if (typeElision == AttrTypeElision::May && attrType.isF64()) 1703 return; 1704 1705 } else if (auto strAttr = attr.dyn_cast<StringAttr>()) { 1706 os << '"'; 1707 printEscapedString(strAttr.getValue(), os); 1708 os << '"'; 1709 1710 } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 1711 os << '['; 1712 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { 1713 printAttribute(attr, AttrTypeElision::May); 1714 }); 1715 os << ']'; 1716 1717 } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) { 1718 os << "affine_map<"; 1719 affineMapAttr.getValue().print(os); 1720 os << '>'; 1721 1722 // AffineMap always elides the type. 1723 return; 1724 1725 } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) { 1726 os << "affine_set<"; 1727 integerSetAttr.getValue().print(os); 1728 os << '>'; 1729 1730 // IntegerSet always elides the type. 
1731 return; 1732 1733 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) { 1734 printType(typeAttr.getValue()); 1735 1736 } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) { 1737 printSymbolReference(refAttr.getRootReference().getValue(), os); 1738 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { 1739 os << "::"; 1740 printSymbolReference(nestedRef.getValue(), os); 1741 } 1742 1743 } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) { 1744 if (printerFlags.shouldElideElementsAttr(opaqueAttr)) { 1745 printElidedElementsAttr(os); 1746 } else { 1747 os << "opaque<" << opaqueAttr.getDialect() << ", \"0x" 1748 << llvm::toHex(opaqueAttr.getValue()) << "\">"; 1749 } 1750 1751 } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) { 1752 if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { 1753 printElidedElementsAttr(os); 1754 } else { 1755 os << "dense<"; 1756 printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); 1757 os << '>'; 1758 } 1759 1760 } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) { 1761 if (printerFlags.shouldElideElementsAttr(strEltAttr)) { 1762 printElidedElementsAttr(os); 1763 } else { 1764 os << "dense<"; 1765 printDenseStringElementsAttr(strEltAttr); 1766 os << '>'; 1767 } 1768 1769 } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) { 1770 if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || 1771 printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { 1772 printElidedElementsAttr(os); 1773 } else { 1774 os << "sparse<"; 1775 DenseIntElementsAttr indices = sparseEltAttr.getIndices(); 1776 if (indices.getNumElements() != 0) { 1777 printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); 1778 os << ", "; 1779 printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true); 1780 } 1781 os << '>'; 1782 } 1783 1784 } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) { 1785 
printLocation(locAttr); 1786 } 1787 // Don't print the type if we must elide it, or if it is a None type. 1788 if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) { 1789 os << " : "; 1790 printType(attrType); 1791 } 1792 } 1793 1794 /// Print the integer element of a DenseElementsAttr. 1795 static void printDenseIntElement(const APInt &value, raw_ostream &os, 1796 bool isSigned) { 1797 if (value.getBitWidth() == 1) 1798 os << (value.getBoolValue() ? "true" : "false"); 1799 else 1800 value.print(os, isSigned); 1801 } 1802 1803 static void 1804 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, 1805 function_ref<void(unsigned)> printEltFn) { 1806 // Special case for 0-d and splat tensors. 1807 if (isSplat) 1808 return printEltFn(0); 1809 1810 // Special case for degenerate tensors. 1811 auto numElements = type.getNumElements(); 1812 if (numElements == 0) 1813 return; 1814 1815 // We use a mixed-radix counter to iterate through the shape. When we bump a 1816 // non-least-significant digit, we emit a close bracket. When we next emit an 1817 // element we re-open all closed brackets. 1818 1819 // The mixed-radix counter, with radices in 'shape'. 1820 int64_t rank = type.getRank(); 1821 SmallVector<unsigned, 4> counter(rank, 0); 1822 // The number of brackets that have been opened and not closed. 1823 unsigned openBrackets = 0; 1824 1825 auto shape = type.getShape(); 1826 auto bumpCounter = [&] { 1827 // Bump the least significant digit. 1828 ++counter[rank - 1]; 1829 // Iterate backwards bubbling back the increment. 1830 for (unsigned i = rank - 1; i > 0; --i) 1831 if (counter[i] >= shape[i]) { 1832 // Index 'i' is rolled over. Bump (i-1) and close a bracket. 
1833 counter[i] = 0; 1834 ++counter[i - 1]; 1835 --openBrackets; 1836 os << ']'; 1837 } 1838 }; 1839 1840 for (unsigned idx = 0, e = numElements; idx != e; ++idx) { 1841 if (idx != 0) 1842 os << ", "; 1843 while (openBrackets++ < rank) 1844 os << '['; 1845 openBrackets = rank; 1846 printEltFn(idx); 1847 bumpCounter(); 1848 } 1849 while (openBrackets-- > 0) 1850 os << ']'; 1851 } 1852 1853 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, 1854 bool allowHex) { 1855 if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) 1856 return printDenseStringElementsAttr(stringAttr); 1857 1858 printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(), 1859 allowHex); 1860 } 1861 1862 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( 1863 DenseIntOrFPElementsAttr attr, bool allowHex) { 1864 auto type = attr.getType(); 1865 auto elementType = type.getElementType(); 1866 1867 // Check to see if we should format this attribute as a hex string. 1868 auto numElements = type.getNumElements(); 1869 if (!attr.isSplat() && allowHex && 1870 shouldPrintElementsAttrWithHex(numElements)) { 1871 ArrayRef<char> rawData = attr.getRawData(); 1872 if (llvm::support::endian::system_endianness() == 1873 llvm::support::endianness::big) { 1874 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE 1875 // machines. It is converted here to print in LE format. 
1876 SmallVector<char, 64> outDataVec(rawData.size()); 1877 MutableArrayRef<char> convRawData(outDataVec); 1878 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1879 rawData, convRawData, type); 1880 os << '"' << "0x" 1881 << llvm::toHex(StringRef(convRawData.data(), convRawData.size())) 1882 << "\""; 1883 } else { 1884 os << '"' << "0x" 1885 << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\""; 1886 } 1887 1888 return; 1889 } 1890 1891 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { 1892 Type complexElementType = complexTy.getElementType(); 1893 // Note: The if and else below had a common lambda function which invoked 1894 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 1895 // and hence was replaced. 1896 if (complexElementType.isa<IntegerType>()) { 1897 bool isSigned = !complexElementType.isUnsignedInteger(); 1898 auto valueIt = attr.value_begin<std::complex<APInt>>(); 1899 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1900 auto complexValue = *(valueIt + index); 1901 os << "("; 1902 printDenseIntElement(complexValue.real(), os, isSigned); 1903 os << ","; 1904 printDenseIntElement(complexValue.imag(), os, isSigned); 1905 os << ")"; 1906 }); 1907 } else { 1908 auto valueIt = attr.value_begin<std::complex<APFloat>>(); 1909 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1910 auto complexValue = *(valueIt + index); 1911 os << "("; 1912 printFloatValue(complexValue.real(), os); 1913 os << ","; 1914 printFloatValue(complexValue.imag(), os); 1915 os << ")"; 1916 }); 1917 } 1918 } else if (elementType.isIntOrIndex()) { 1919 bool isSigned = !elementType.isUnsignedInteger(); 1920 auto valueIt = attr.value_begin<APInt>(); 1921 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1922 printDenseIntElement(*(valueIt + index), os, isSigned); 1923 }); 1924 } else { 1925 assert(elementType.isa<FloatType>() && "unexpected 
element type"); 1926 auto valueIt = attr.value_begin<APFloat>(); 1927 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { 1928 printFloatValue(*(valueIt + index), os); 1929 }); 1930 } 1931 } 1932 1933 void AsmPrinter::Impl::printDenseStringElementsAttr( 1934 DenseStringElementsAttr attr) { 1935 ArrayRef<StringRef> data = attr.getRawStringData(); 1936 auto printFn = [&](unsigned index) { 1937 os << "\""; 1938 printEscapedString(data[index], os); 1939 os << "\""; 1940 }; 1941 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); 1942 } 1943 1944 void AsmPrinter::Impl::printType(Type type) { 1945 if (!type) { 1946 os << "<<NULL TYPE>>"; 1947 return; 1948 } 1949 1950 // Try to print an alias for this type. 1951 if (state && succeeded(state->getAliasState().getAlias(type, os))) 1952 return; 1953 1954 TypeSwitch<Type>(type) 1955 .Case<OpaqueType>([&](OpaqueType opaqueTy) { 1956 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 1957 opaqueTy.getTypeData()); 1958 }) 1959 .Case<IndexType>([&](Type) { os << "index"; }) 1960 .Case<BFloat16Type>([&](Type) { os << "bf16"; }) 1961 .Case<Float16Type>([&](Type) { os << "f16"; }) 1962 .Case<Float32Type>([&](Type) { os << "f32"; }) 1963 .Case<Float64Type>([&](Type) { os << "f64"; }) 1964 .Case<Float80Type>([&](Type) { os << "f80"; }) 1965 .Case<Float128Type>([&](Type) { os << "f128"; }) 1966 .Case<IntegerType>([&](IntegerType integerTy) { 1967 if (integerTy.isSigned()) 1968 os << 's'; 1969 else if (integerTy.isUnsigned()) 1970 os << 'u'; 1971 os << 'i' << integerTy.getWidth(); 1972 }) 1973 .Case<FunctionType>([&](FunctionType funcTy) { 1974 os << '('; 1975 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); 1976 os << ") -> "; 1977 ArrayRef<Type> results = funcTy.getResults(); 1978 if (results.size() == 1 && !results[0].isa<FunctionType>()) { 1979 printType(results[0]); 1980 } else { 1981 os << '('; 1982 interleaveComma(results, [&](Type ty) { printType(ty); }); 
1983 os << ')'; 1984 } 1985 }) 1986 .Case<VectorType>([&](VectorType vectorTy) { 1987 os << "vector<"; 1988 auto vShape = vectorTy.getShape(); 1989 unsigned lastDim = vShape.size(); 1990 unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims(); 1991 unsigned dimIdx = 0; 1992 for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++) 1993 os << vShape[dimIdx] << 'x'; 1994 if (vectorTy.isScalable()) { 1995 os << '['; 1996 unsigned secondToLastDim = lastDim - 1; 1997 for (; dimIdx < secondToLastDim; dimIdx++) 1998 os << vShape[dimIdx] << 'x'; 1999 os << vShape[dimIdx] << "]x"; 2000 } 2001 printType(vectorTy.getElementType()); 2002 os << '>'; 2003 }) 2004 .Case<RankedTensorType>([&](RankedTensorType tensorTy) { 2005 os << "tensor<"; 2006 for (int64_t dim : tensorTy.getShape()) { 2007 if (ShapedType::isDynamic(dim)) 2008 os << '?'; 2009 else 2010 os << dim; 2011 os << 'x'; 2012 } 2013 printType(tensorTy.getElementType()); 2014 // Only print the encoding attribute value if set. 2015 if (tensorTy.getEncoding()) { 2016 os << ", "; 2017 printAttribute(tensorTy.getEncoding()); 2018 } 2019 os << '>'; 2020 }) 2021 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { 2022 os << "tensor<*x"; 2023 printType(tensorTy.getElementType()); 2024 os << '>'; 2025 }) 2026 .Case<MemRefType>([&](MemRefType memrefTy) { 2027 os << "memref<"; 2028 for (int64_t dim : memrefTy.getShape()) { 2029 if (ShapedType::isDynamic(dim)) 2030 os << '?'; 2031 else 2032 os << dim; 2033 os << 'x'; 2034 } 2035 printType(memrefTy.getElementType()); 2036 if (!memrefTy.getLayout().isIdentity()) { 2037 os << ", "; 2038 printAttribute(memrefTy.getLayout(), AttrTypeElision::May); 2039 } 2040 // Only print the memory space if it is the non-default one. 
2041 if (memrefTy.getMemorySpace()) { 2042 os << ", "; 2043 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 2044 } 2045 os << '>'; 2046 }) 2047 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { 2048 os << "memref<*x"; 2049 printType(memrefTy.getElementType()); 2050 // Only print the memory space if it is the non-default one. 2051 if (memrefTy.getMemorySpace()) { 2052 os << ", "; 2053 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); 2054 } 2055 os << '>'; 2056 }) 2057 .Case<ComplexType>([&](ComplexType complexTy) { 2058 os << "complex<"; 2059 printType(complexTy.getElementType()); 2060 os << '>'; 2061 }) 2062 .Case<TupleType>([&](TupleType tupleTy) { 2063 os << "tuple<"; 2064 interleaveComma(tupleTy.getTypes(), 2065 [&](Type type) { printType(type); }); 2066 os << '>'; 2067 }) 2068 .Case<NoneType>([&](Type) { os << "none"; }) 2069 .Default([&](Type type) { return printDialectType(type); }); 2070 } 2071 2072 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2073 ArrayRef<StringRef> elidedAttrs, 2074 bool withKeyword) { 2075 // If there are no attributes, then there is nothing to be done. 2076 if (attrs.empty()) 2077 return; 2078 2079 // Functor used to print a filtered attribute list. 2080 auto printFilteredAttributesFn = [&](auto filteredAttrs) { 2081 // Print the 'attributes' keyword if necessary. 2082 if (withKeyword) 2083 os << " attributes"; 2084 2085 // Otherwise, print them all out in braces. 2086 os << " {"; 2087 interleaveComma(filteredAttrs, 2088 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 2089 os << '}'; 2090 }; 2091 2092 // If no attributes are elided, we can directly print with no filtering. 2093 if (elidedAttrs.empty()) 2094 return printFilteredAttributesFn(attrs); 2095 2096 // Otherwise, filter out any attributes that shouldn't be included. 
2097 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 2098 elidedAttrs.end()); 2099 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 2100 return !elidedAttrsSet.contains(attr.getName().strref()); 2101 }); 2102 if (!filteredAttrs.empty()) 2103 printFilteredAttributesFn(filteredAttrs); 2104 } 2105 2106 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { 2107 // Print the name without quotes if possible. 2108 ::printKeywordOrString(attr.getName().strref(), os); 2109 2110 // Pretty printing elides the attribute value for unit attributes. 2111 if (attr.getValue().isa<UnitAttr>()) 2112 return; 2113 2114 os << " = "; 2115 printAttribute(attr.getValue()); 2116 } 2117 2118 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { 2119 auto &dialect = attr.getDialect(); 2120 2121 // Ask the dialect to serialize the attribute to a string. 2122 std::string attrName; 2123 { 2124 llvm::raw_string_ostream attrNameStr(attrName); 2125 Impl subPrinter(attrNameStr, printerFlags, state); 2126 DialectAsmPrinter printer(subPrinter); 2127 dialect.printAttribute(attr, printer); 2128 } 2129 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 2130 } 2131 2132 void AsmPrinter::Impl::printDialectType(Type type) { 2133 auto &dialect = type.getDialect(); 2134 2135 // Ask the dialect to serialize the type to a string. 
2136 std::string typeName; 2137 { 2138 llvm::raw_string_ostream typeNameStr(typeName); 2139 Impl subPrinter(typeNameStr, printerFlags, state); 2140 DialectAsmPrinter printer(subPrinter); 2141 dialect.printType(type, printer); 2142 } 2143 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 2144 } 2145 2146 //===--------------------------------------------------------------------===// 2147 // AsmPrinter 2148 //===--------------------------------------------------------------------===// 2149 2150 AsmPrinter::~AsmPrinter() = default; 2151 2152 raw_ostream &AsmPrinter::getStream() const { 2153 assert(impl && "expected AsmPrinter::getStream to be overriden"); 2154 return impl->getStream(); 2155 } 2156 2157 /// Print the given floating point value in a stablized form. 2158 void AsmPrinter::printFloat(const APFloat &value) { 2159 assert(impl && "expected AsmPrinter::printFloat to be overriden"); 2160 printFloatValue(value, impl->getStream()); 2161 } 2162 2163 void AsmPrinter::printType(Type type) { 2164 assert(impl && "expected AsmPrinter::printType to be overriden"); 2165 impl->printType(type); 2166 } 2167 2168 void AsmPrinter::printAttribute(Attribute attr) { 2169 assert(impl && "expected AsmPrinter::printAttribute to be overriden"); 2170 impl->printAttribute(attr); 2171 } 2172 2173 LogicalResult AsmPrinter::printAlias(Attribute attr) { 2174 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2175 return impl->printAlias(attr); 2176 } 2177 2178 LogicalResult AsmPrinter::printAlias(Type type) { 2179 assert(impl && "expected AsmPrinter::printAlias to be overriden"); 2180 return impl->printAlias(type); 2181 } 2182 2183 void AsmPrinter::printAttributeWithoutType(Attribute attr) { 2184 assert(impl && 2185 "expected AsmPrinter::printAttributeWithoutType to be overriden"); 2186 impl->printAttribute(attr, Impl::AttrTypeElision::Must); 2187 } 2188 2189 void AsmPrinter::printKeywordOrString(StringRef keyword) { 2190 assert(impl && "expected 
AsmPrinter::printKeywordOrString to be overriden"); 2191 ::printKeywordOrString(keyword, impl->getStream()); 2192 } 2193 2194 void AsmPrinter::printSymbolName(StringRef symbolRef) { 2195 assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); 2196 ::printSymbolReference(symbolRef, impl->getStream()); 2197 } 2198 2199 //===----------------------------------------------------------------------===// 2200 // Affine expressions and maps 2201 //===----------------------------------------------------------------------===// 2202 2203 void AsmPrinter::Impl::printAffineExpr( 2204 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2205 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2206 } 2207 2208 void AsmPrinter::Impl::printAffineExprInternal( 2209 AffineExpr expr, BindingStrength enclosingTightness, 2210 function_ref<void(unsigned, bool)> printValueName) { 2211 const char *binopSpelling = nullptr; 2212 switch (expr.getKind()) { 2213 case AffineExprKind::SymbolId: { 2214 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 2215 if (printValueName) 2216 printValueName(pos, /*isSymbol=*/true); 2217 else 2218 os << 's' << pos; 2219 return; 2220 } 2221 case AffineExprKind::DimId: { 2222 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 2223 if (printValueName) 2224 printValueName(pos, /*isSymbol=*/false); 2225 else 2226 os << 'd' << pos; 2227 return; 2228 } 2229 case AffineExprKind::Constant: 2230 os << expr.cast<AffineConstantExpr>().getValue(); 2231 return; 2232 case AffineExprKind::Add: 2233 binopSpelling = " + "; 2234 break; 2235 case AffineExprKind::Mul: 2236 binopSpelling = " * "; 2237 break; 2238 case AffineExprKind::FloorDiv: 2239 binopSpelling = " floordiv "; 2240 break; 2241 case AffineExprKind::CeilDiv: 2242 binopSpelling = " ceildiv "; 2243 break; 2244 case AffineExprKind::Mod: 2245 binopSpelling = " mod "; 2246 break; 2247 } 2248 2249 auto binOp = expr.cast<AffineBinaryOpExpr>(); 2250 
AffineExpr lhsExpr = binOp.getLHS(); 2251 AffineExpr rhsExpr = binOp.getRHS(); 2252 2253 // Handle tightly binding binary operators. 2254 if (binOp.getKind() != AffineExprKind::Add) { 2255 if (enclosingTightness == BindingStrength::Strong) 2256 os << '('; 2257 2258 // Pretty print multiplication with -1. 2259 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2260 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2261 rhsConst.getValue() == -1) { 2262 os << "-"; 2263 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2264 if (enclosingTightness == BindingStrength::Strong) 2265 os << ')'; 2266 return; 2267 } 2268 2269 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2270 2271 os << binopSpelling; 2272 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2273 2274 if (enclosingTightness == BindingStrength::Strong) 2275 os << ')'; 2276 return; 2277 } 2278 2279 // Print out special "pretty" forms for add. 2280 if (enclosingTightness == BindingStrength::Strong) 2281 os << '('; 2282 2283 // Pretty print addition to a product that has a negative operand as a 2284 // subtraction. 
2285 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2286 if (rhs.getKind() == AffineExprKind::Mul) { 2287 AffineExpr rrhsExpr = rhs.getRHS(); 2288 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2289 if (rrhs.getValue() == -1) { 2290 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2291 printValueName); 2292 os << " - "; 2293 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2294 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2295 printValueName); 2296 } else { 2297 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2298 printValueName); 2299 } 2300 2301 if (enclosingTightness == BindingStrength::Strong) 2302 os << ')'; 2303 return; 2304 } 2305 2306 if (rrhs.getValue() < -1) { 2307 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2308 printValueName); 2309 os << " - "; 2310 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2311 printValueName); 2312 os << " * " << -rrhs.getValue(); 2313 if (enclosingTightness == BindingStrength::Strong) 2314 os << ')'; 2315 return; 2316 } 2317 } 2318 } 2319 } 2320 2321 // Pretty print addition to a negative number as a subtraction. 2322 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2323 if (rhsConst.getValue() < 0) { 2324 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2325 os << " - " << -rhsConst.getValue(); 2326 if (enclosingTightness == BindingStrength::Strong) 2327 os << ')'; 2328 return; 2329 } 2330 } 2331 2332 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2333 2334 os << " + "; 2335 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2336 2337 if (enclosingTightness == BindingStrength::Strong) 2338 os << ')'; 2339 } 2340 2341 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { 2342 printAffineExprInternal(expr, BindingStrength::Weak); 2343 isEq ? 
os << " == 0" : os << " >= 0"; 2344 } 2345 2346 void AsmPrinter::Impl::printAffineMap(AffineMap map) { 2347 // Dimension identifiers. 2348 os << '('; 2349 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2350 os << 'd' << i << ", "; 2351 if (map.getNumDims() >= 1) 2352 os << 'd' << map.getNumDims() - 1; 2353 os << ')'; 2354 2355 // Symbolic identifiers. 2356 if (map.getNumSymbols() != 0) { 2357 os << '['; 2358 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2359 os << 's' << i << ", "; 2360 if (map.getNumSymbols() >= 1) 2361 os << 's' << map.getNumSymbols() - 1; 2362 os << ']'; 2363 } 2364 2365 // Result affine expressions. 2366 os << " -> ("; 2367 interleaveComma(map.getResults(), 2368 [&](AffineExpr expr) { printAffineExpr(expr); }); 2369 os << ')'; 2370 } 2371 2372 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { 2373 // Dimension identifiers. 2374 os << '('; 2375 for (unsigned i = 1; i < set.getNumDims(); ++i) 2376 os << 'd' << i - 1 << ", "; 2377 if (set.getNumDims() >= 1) 2378 os << 'd' << set.getNumDims() - 1; 2379 os << ')'; 2380 2381 // Symbolic identifiers. 2382 if (set.getNumSymbols() != 0) { 2383 os << '['; 2384 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2385 os << 's' << i << ", "; 2386 if (set.getNumSymbols() >= 1) 2387 os << 's' << set.getNumSymbols() - 1; 2388 os << ']'; 2389 } 2390 2391 // Print constraints. 
2392 os << " : ("; 2393 int numConstraints = set.getNumConstraints(); 2394 for (int i = 1; i < numConstraints; ++i) { 2395 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2396 os << ", "; 2397 } 2398 if (numConstraints >= 1) 2399 printAffineConstraint(set.getConstraint(numConstraints - 1), 2400 set.isEq(numConstraints - 1)); 2401 os << ')'; 2402 } 2403 2404 //===----------------------------------------------------------------------===// 2405 // OperationPrinter 2406 //===----------------------------------------------------------------------===// 2407 2408 namespace { 2409 /// This class contains the logic for printing operations, regions, and blocks. 2410 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { 2411 public: 2412 using Impl = AsmPrinter::Impl; 2413 using Impl::printType; 2414 2415 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) 2416 : Impl(os, state.getPrinterFlags(), &state), 2417 OpAsmPrinter(static_cast<Impl &>(*this)) {} 2418 2419 /// Print the given top-level operation. 2420 void printTopLevelOperation(Operation *op); 2421 2422 /// Print the given operation with its indent and location. 2423 void print(Operation *op); 2424 /// Print the bare location, not including indentation/location/etc. 2425 void printOperation(Operation *op); 2426 /// Print the given operation in the generic form. 2427 void printGenericOp(Operation *op, bool printOpName) override; 2428 2429 /// Print the name of the given block. 2430 void printBlockName(Block *block); 2431 2432 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2433 /// block are not printed. If 'printBlockTerminator' is false, the terminator 2434 /// operation of the block is not printed. 2435 void print(Block *block, bool printBlockArgs = true, 2436 bool printBlockTerminator = true); 2437 2438 /// Print the ID of the given value, optionally with its result number. 
2439 void printValueID(Value value, bool printResultNo = true, 2440 raw_ostream *streamOverride = nullptr) const; 2441 2442 //===--------------------------------------------------------------------===// 2443 // OpAsmPrinter methods 2444 //===--------------------------------------------------------------------===// 2445 2446 /// Print a newline and indent the printer to the start of the current 2447 /// operation. 2448 void printNewline() override { 2449 os << newLine; 2450 os.indent(currentIndent); 2451 } 2452 2453 /// Print a block argument in the usual format of: 2454 /// %ssaName : type {attr1=42} loc("here") 2455 /// where location printing is controlled by the standard internal option. 2456 /// You may pass omitType=true to not print a type, and pass an empty 2457 /// attribute list if you don't care for attributes. 2458 void printRegionArgument(BlockArgument arg, 2459 ArrayRef<NamedAttribute> argAttrs = {}, 2460 bool omitType = false) override; 2461 2462 /// Print the ID for the given value. 2463 void printOperand(Value value) override { printValueID(value); } 2464 void printOperand(Value value, raw_ostream &os) override { 2465 printValueID(value, /*printResultNo=*/true, &os); 2466 } 2467 2468 /// Print an optional attribute dictionary with a given set of elided values. 2469 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2470 ArrayRef<StringRef> elidedAttrs = {}) override { 2471 Impl::printOptionalAttrDict(attrs, elidedAttrs); 2472 } 2473 void printOptionalAttrDictWithKeyword( 2474 ArrayRef<NamedAttribute> attrs, 2475 ArrayRef<StringRef> elidedAttrs = {}) override { 2476 Impl::printOptionalAttrDict(attrs, elidedAttrs, 2477 /*withKeyword=*/true); 2478 } 2479 2480 /// Print the given successor. 2481 void printSuccessor(Block *successor) override; 2482 2483 /// Print an operation successor with the operands used for the block 2484 /// arguments. 
2485 void printSuccessorAndUseList(Block *successor, 2486 ValueRange succOperands) override; 2487 2488 /// Print the given region. 2489 void printRegion(Region ®ion, bool printEntryBlockArgs, 2490 bool printBlockTerminators, bool printEmptyBlock) override; 2491 2492 /// Renumber the arguments for the specified region to the same names as the 2493 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2494 /// operations. If any entry in namesToUse is null, the corresponding 2495 /// argument name is left alone. 2496 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2497 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2498 } 2499 2500 /// Print the given affine map with the symbol and dimension operands printed 2501 /// inline with the map. 2502 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2503 ValueRange operands) override; 2504 2505 /// Print the given affine expression with the symbol and dimension operands 2506 /// printed inline with the expression. 2507 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, 2508 ValueRange symOperands) override; 2509 2510 private: 2511 // Contains the stack of default dialects to use when printing regions. 2512 // A new dialect is pushed to the stack before parsing regions nested under an 2513 // operation implementing `OpAsmOpInterface`, and popped when done. At the 2514 // top-level we start with "builtin" as the default, so that the top-level 2515 // `module` operation prints as-is. 2516 SmallVector<StringRef> defaultDialectStack{"builtin"}; 2517 2518 /// The number of spaces used for indenting nested operations. 2519 const static unsigned indentWidth = 2; 2520 2521 // This is the current indentation level for nested structures. 2522 unsigned currentIndent = 0; 2523 }; 2524 } // namespace 2525 2526 void OperationPrinter::printTopLevelOperation(Operation *op) { 2527 // Output the aliases at the top level that can't be deferred. 
2528 state->getAliasState().printNonDeferredAliases(os, newLine); 2529 2530 // Print the module. 2531 print(op); 2532 os << newLine; 2533 2534 // Output the aliases at the top level that can be deferred. 2535 state->getAliasState().printDeferredAliases(os, newLine); 2536 } 2537 2538 /// Print a block argument in the usual format of: 2539 /// %ssaName : type {attr1=42} loc("here") 2540 /// where location printing is controlled by the standard internal option. 2541 /// You may pass omitType=true to not print a type, and pass an empty 2542 /// attribute list if you don't care for attributes. 2543 void OperationPrinter::printRegionArgument(BlockArgument arg, 2544 ArrayRef<NamedAttribute> argAttrs, 2545 bool omitType) { 2546 printOperand(arg); 2547 if (!omitType) { 2548 os << ": "; 2549 printType(arg.getType()); 2550 } 2551 printOptionalAttrDict(argAttrs); 2552 // TODO: We should allow location aliases on block arguments. 2553 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2554 } 2555 2556 void OperationPrinter::print(Operation *op) { 2557 // Track the location of this operation. 2558 state->registerOperationLocation(op, newLine.curLine, currentIndent); 2559 2560 os.indent(currentIndent); 2561 printOperation(op); 2562 printTrailingLocation(op->getLoc()); 2563 } 2564 2565 void OperationPrinter::printOperation(Operation *op) { 2566 if (size_t numResults = op->getNumResults()) { 2567 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 2568 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 2569 if (resultCount > 1) 2570 os << ':' << resultCount; 2571 }; 2572 2573 // Check to see if this operation has multiple result groups. 2574 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); 2575 if (!resultGroups.empty()) { 2576 // Interleave the groups excluding the last one, this one will be handled 2577 // separately. 
2578 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 2579 printResultGroup(resultGroups[i], 2580 resultGroups[i + 1] - resultGroups[i]); 2581 }); 2582 os << ", "; 2583 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 2584 2585 } else { 2586 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 2587 } 2588 2589 os << " = "; 2590 } 2591 2592 // If requested, always print the generic form. 2593 if (!printerFlags.shouldPrintGenericOpForm()) { 2594 // Check to see if this is a known operation. If so, use the registered 2595 // custom printer hook. 2596 if (auto opInfo = op->getRegisteredInfo()) { 2597 opInfo->printAssembly(op, *this, defaultDialectStack.back()); 2598 return; 2599 } 2600 // Otherwise try to dispatch to the dialect, if available. 2601 if (Dialect *dialect = op->getDialect()) { 2602 if (auto opPrinter = dialect->getOperationPrinter(op)) { 2603 // Print the op name first. 2604 StringRef name = op->getName().getStringRef(); 2605 name.consume_front((defaultDialectStack.back() + ".").str()); 2606 printEscapedString(name, os); 2607 // Print the rest of the op now. 2608 opPrinter(op, *this); 2609 return; 2610 } 2611 } 2612 } 2613 2614 // Otherwise print with the generic assembly form. 2615 printGenericOp(op, /*printOpName=*/true); 2616 } 2617 2618 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { 2619 if (printOpName) { 2620 os << '"'; 2621 printEscapedString(op->getName().getStringRef(), os); 2622 os << '"'; 2623 } 2624 os << '('; 2625 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 2626 os << ')'; 2627 2628 // For terminators, print the list of successors and their operands. 2629 if (op->getNumSuccessors() != 0) { 2630 os << '['; 2631 interleaveComma(op->getSuccessors(), 2632 [&](Block *successor) { printBlockName(successor); }); 2633 os << ']'; 2634 } 2635 2636 // Print regions. 
2637 if (op->getNumRegions() != 0) { 2638 os << " ("; 2639 interleaveComma(op->getRegions(), [&](Region ®ion) { 2640 printRegion(region, /*printEntryBlockArgs=*/true, 2641 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); 2642 }); 2643 os << ')'; 2644 } 2645 2646 auto attrs = op->getAttrs(); 2647 printOptionalAttrDict(attrs); 2648 2649 // Print the type signature of the operation. 2650 os << " : "; 2651 printFunctionalType(op); 2652 } 2653 2654 void OperationPrinter::printBlockName(Block *block) { 2655 os << state->getSSANameState().getBlockInfo(block).name; 2656 } 2657 2658 void OperationPrinter::print(Block *block, bool printBlockArgs, 2659 bool printBlockTerminator) { 2660 // Print the block label and argument list if requested. 2661 if (printBlockArgs) { 2662 os.indent(currentIndent); 2663 printBlockName(block); 2664 2665 // Print the argument list if non-empty. 2666 if (!block->args_empty()) { 2667 os << '('; 2668 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 2669 printValueID(arg); 2670 os << ": "; 2671 printType(arg.getType()); 2672 // TODO: We should allow location aliases on block arguments. 2673 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); 2674 }); 2675 os << ')'; 2676 } 2677 os << ':'; 2678 2679 // Print out some context information about the predecessors of this block. 2680 if (!block->getParent()) { 2681 os << " // block is not in a region!"; 2682 } else if (block->hasNoPredecessors()) { 2683 if (!block->isEntryBlock()) 2684 os << " // no predecessors"; 2685 } else if (auto *pred = block->getSinglePredecessor()) { 2686 os << " // pred: "; 2687 printBlockName(pred); 2688 } else { 2689 // We want to print the predecessors in a stable order, not in 2690 // whatever order the use-list is in, so gather and sort them. 
2691 SmallVector<BlockInfo, 4> predIDs; 2692 for (auto *pred : block->getPredecessors()) 2693 predIDs.push_back(state->getSSANameState().getBlockInfo(pred)); 2694 llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { 2695 return lhs.ordering < rhs.ordering; 2696 }); 2697 2698 os << " // " << predIDs.size() << " preds: "; 2699 2700 interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); 2701 } 2702 os << newLine; 2703 } 2704 2705 currentIndent += indentWidth; 2706 bool hasTerminator = 2707 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); 2708 auto range = llvm::make_range( 2709 block->begin(), 2710 std::prev(block->end(), 2711 (!hasTerminator || printBlockTerminator) ? 0 : 1)); 2712 for (auto &op : range) { 2713 print(&op); 2714 os << newLine; 2715 } 2716 currentIndent -= indentWidth; 2717 } 2718 2719 void OperationPrinter::printValueID(Value value, bool printResultNo, 2720 raw_ostream *streamOverride) const { 2721 state->getSSANameState().printValueID(value, printResultNo, 2722 streamOverride ? 
*streamOverride : os); 2723 } 2724 2725 void OperationPrinter::printSuccessor(Block *successor) { 2726 printBlockName(successor); 2727 } 2728 2729 void OperationPrinter::printSuccessorAndUseList(Block *successor, 2730 ValueRange succOperands) { 2731 printBlockName(successor); 2732 if (succOperands.empty()) 2733 return; 2734 2735 os << '('; 2736 interleaveComma(succOperands, 2737 [this](Value operand) { printValueID(operand); }); 2738 os << " : "; 2739 interleaveComma(succOperands, 2740 [this](Value operand) { printType(operand.getType()); }); 2741 os << ')'; 2742 } 2743 2744 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 2745 bool printBlockTerminators, 2746 bool printEmptyBlock) { 2747 os << "{" << newLine; 2748 if (!region.empty()) { 2749 auto restoreDefaultDialect = 2750 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); 2751 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) 2752 defaultDialectStack.push_back(iface.getDefaultDialect()); 2753 else 2754 defaultDialectStack.push_back(""); 2755 2756 auto *entryBlock = ®ion.front(); 2757 // Force printing the block header if printEmptyBlock is set and the block 2758 // is empty or if printEntryBlockArgs is set and there are arguments to 2759 // print. 2760 bool shouldAlwaysPrintBlockHeader = 2761 (printEmptyBlock && entryBlock->empty()) || 2762 (printEntryBlockArgs && entryBlock->getNumArguments() != 0); 2763 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); 2764 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 2765 print(&b); 2766 } 2767 os.indent(currentIndent) << "}"; 2768 } 2769 2770 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2771 ValueRange operands) { 2772 AffineMap map = mapAttr.getValue(); 2773 unsigned numDims = map.getNumDims(); 2774 auto printValueName = [&](unsigned pos, bool isSymbol) { 2775 unsigned index = isSymbol ? 
numDims + pos : pos; 2776 assert(index < operands.size()); 2777 if (isSymbol) 2778 os << "symbol("; 2779 printValueID(operands[index]); 2780 if (isSymbol) 2781 os << ')'; 2782 }; 2783 2784 interleaveComma(map.getResults(), [&](AffineExpr expr) { 2785 printAffineExpr(expr, printValueName); 2786 }); 2787 } 2788 2789 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, 2790 ValueRange dimOperands, 2791 ValueRange symOperands) { 2792 auto printValueName = [&](unsigned pos, bool isSymbol) { 2793 if (!isSymbol) 2794 return printValueID(dimOperands[pos]); 2795 os << "symbol("; 2796 printValueID(symOperands[pos]); 2797 os << ')'; 2798 }; 2799 printAffineExpr(expr, printValueName); 2800 } 2801 2802 //===----------------------------------------------------------------------===// 2803 // print and dump methods 2804 //===----------------------------------------------------------------------===// 2805 2806 void Attribute::print(raw_ostream &os) const { 2807 AsmPrinter::Impl(os).printAttribute(*this); 2808 } 2809 2810 void Attribute::dump() const { 2811 print(llvm::errs()); 2812 llvm::errs() << "\n"; 2813 } 2814 2815 void Type::print(raw_ostream &os) const { 2816 AsmPrinter::Impl(os).printType(*this); 2817 } 2818 2819 void Type::dump() const { print(llvm::errs()); } 2820 2821 void AffineMap::dump() const { 2822 print(llvm::errs()); 2823 llvm::errs() << "\n"; 2824 } 2825 2826 void IntegerSet::dump() const { 2827 print(llvm::errs()); 2828 llvm::errs() << "\n"; 2829 } 2830 2831 void AffineExpr::print(raw_ostream &os) const { 2832 if (!expr) { 2833 os << "<<NULL AFFINE EXPR>>"; 2834 return; 2835 } 2836 AsmPrinter::Impl(os).printAffineExpr(*this); 2837 } 2838 2839 void AffineExpr::dump() const { 2840 print(llvm::errs()); 2841 llvm::errs() << "\n"; 2842 } 2843 2844 void AffineMap::print(raw_ostream &os) const { 2845 if (!map) { 2846 os << "<<NULL AFFINE MAP>>"; 2847 return; 2848 } 2849 AsmPrinter::Impl(os).printAffineMap(*this); 2850 } 2851 2852 void 
IntegerSet::print(raw_ostream &os) const { 2853 AsmPrinter::Impl(os).printIntegerSet(*this); 2854 } 2855 2856 void Value::print(raw_ostream &os) { 2857 if (!impl) { 2858 os << "<<NULL VALUE>>"; 2859 return; 2860 } 2861 2862 if (auto *op = getDefiningOp()) 2863 return op->print(os); 2864 // TODO: Improve BlockArgument print'ing. 2865 BlockArgument arg = this->cast<BlockArgument>(); 2866 os << "<block argument> of type '" << arg.getType() 2867 << "' at index: " << arg.getArgNumber(); 2868 } 2869 void Value::print(raw_ostream &os, AsmState &state) { 2870 if (!impl) { 2871 os << "<<NULL VALUE>>"; 2872 return; 2873 } 2874 2875 if (auto *op = getDefiningOp()) 2876 return op->print(os, state); 2877 2878 // TODO: Improve BlockArgument print'ing. 2879 BlockArgument arg = this->cast<BlockArgument>(); 2880 os << "<block argument> of type '" << arg.getType() 2881 << "' at index: " << arg.getArgNumber(); 2882 } 2883 2884 void Value::dump() { 2885 print(llvm::errs()); 2886 llvm::errs() << "\n"; 2887 } 2888 2889 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 2890 // TODO: This doesn't necessarily capture all potential cases. 2891 // Currently, region arguments can be shadowed when printing the main 2892 // operation. If the IR hasn't been printed, this will produce the old SSA 2893 // name and not the shadowed name. 2894 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 2895 os); 2896 } 2897 2898 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { 2899 // If this is a top level operation, we also print aliases. 2900 if (!getParent() && !printerFlags.shouldUseLocalScope()) { 2901 AsmState state(this, printerFlags); 2902 state.getImpl().initializeAliases(this); 2903 print(os, state); 2904 return; 2905 } 2906 2907 // Find the operation to number from based upon the provided flags. 
2908 Operation *op = this; 2909 bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); 2910 do { 2911 // If we are printing local scope, stop at the first operation that is 2912 // isolated from above. 2913 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 2914 break; 2915 2916 // Otherwise, traverse up to the next parent. 2917 Operation *parentOp = op->getParentOp(); 2918 if (!parentOp) 2919 break; 2920 op = parentOp; 2921 } while (true); 2922 2923 AsmState state(op, printerFlags); 2924 print(os, state); 2925 } 2926 void Operation::print(raw_ostream &os, AsmState &state) { 2927 OperationPrinter printer(os, state.getImpl()); 2928 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) 2929 printer.printTopLevelOperation(this); 2930 else 2931 printer.print(this); 2932 } 2933 2934 void Operation::dump() { 2935 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 2936 llvm::errs() << "\n"; 2937 } 2938 2939 void Block::print(raw_ostream &os) { 2940 Operation *parentOp = getParentOp(); 2941 if (!parentOp) { 2942 os << "<<UNLINKED BLOCK>>\n"; 2943 return; 2944 } 2945 // Get the top-level op. 2946 while (auto *nextOp = parentOp->getParentOp()) 2947 parentOp = nextOp; 2948 2949 AsmState state(parentOp); 2950 print(os, state); 2951 } 2952 void Block::print(raw_ostream &os, AsmState &state) { 2953 OperationPrinter(os, state.getImpl()).print(this); 2954 } 2955 2956 void Block::dump() { print(llvm::errs()); } 2957 2958 /// Print out the name of the block without printing its body. 2959 void Block::printAsOperand(raw_ostream &os, bool printType) { 2960 Operation *parentOp = getParentOp(); 2961 if (!parentOp) { 2962 os << "<<UNLINKED BLOCK>>\n"; 2963 return; 2964 } 2965 AsmState state(parentOp); 2966 printAsOperand(os, state); 2967 } 2968 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 2969 OperationPrinter printer(os, state.getImpl()); 2970 printer.printBlockName(this); 2971 } 2972