1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the MLIR AsmPrinter class, which is used to implement 10 // the various print() methods on the core IR objects. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/AffineExpr.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/AsmState.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Operation.h" 25 #include "llvm/ADT/APFloat.h" 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/ADT/MapVector.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/ScopedHashTable.h" 30 #include "llvm/ADT/SetVector.h" 31 #include "llvm/ADT/SmallString.h" 32 #include "llvm/ADT/StringExtras.h" 33 #include "llvm/ADT/StringSet.h" 34 #include "llvm/ADT/TypeSwitch.h" 35 #include "llvm/Support/CommandLine.h" 36 #include "llvm/Support/Regex.h" 37 #include "llvm/Support/SaveAndRestore.h" 38 using namespace mlir; 39 using namespace mlir::detail; 40 41 void Identifier::print(raw_ostream &os) const { os << str(); } 42 43 void Identifier::dump() const { print(llvm::errs()); } 44 45 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 46 47 void OperationName::dump() const { print(llvm::errs()); } 48 49 DialectAsmPrinter::~DialectAsmPrinter() {} 50 51 //===--------------------------------------------------------------------===// 52 // OpAsmPrinter 53 
//===--------------------------------------------------------------------===// 54 55 OpAsmPrinter::~OpAsmPrinter() {} 56 57 void OpAsmPrinter::printFunctionalType(Operation *op) { 58 auto &os = getStream(); 59 os << '('; 60 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { 61 // Print the types of null values as <<NULL TYPE>>. 62 *this << (operand ? operand.getType() : Type()); 63 }); 64 os << ") -> "; 65 66 // Print the result list. We don't parenthesize single result types unless 67 // it is a function (avoiding a grammar ambiguity). 68 bool wrapped = op->getNumResults() != 1; 69 if (!wrapped && op->getResult(0).getType() && 70 op->getResult(0).getType().isa<FunctionType>()) 71 wrapped = true; 72 73 if (wrapped) 74 os << '('; 75 76 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { 77 // Print the types of null values as <<NULL TYPE>>. 78 *this << (result ? result.getType() : Type()); 79 }); 80 81 if (wrapped) 82 os << ')'; 83 } 84 85 //===--------------------------------------------------------------------===// 86 // Operation OpAsm interface. 87 //===--------------------------------------------------------------------===// 88 89 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 90 #include "mlir/IR/OpAsmInterface.cpp.inc" 91 92 //===----------------------------------------------------------------------===// 93 // OpPrintingFlags 94 //===----------------------------------------------------------------------===// 95 96 namespace { 97 /// This struct contains command line options that can be used to initialize 98 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need 99 /// for global command line options. 
100 struct AsmPrinterOptions { 101 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ 102 "mlir-print-elementsattrs-with-hex-if-larger", 103 llvm::cl::desc( 104 "Print DenseElementsAttrs with a hex string that have " 105 "more elements than the given upper limit (use -1 to disable)")}; 106 107 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ 108 "mlir-elide-elementsattrs-if-larger", 109 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " 110 "more elements than the given upper limit")}; 111 112 llvm::cl::opt<bool> printDebugInfoOpt{ 113 "mlir-print-debuginfo", llvm::cl::init(false), 114 llvm::cl::desc("Print debug info in MLIR output")}; 115 116 llvm::cl::opt<bool> printPrettyDebugInfoOpt{ 117 "mlir-pretty-debuginfo", llvm::cl::init(false), 118 llvm::cl::desc("Print pretty debug info in MLIR output")}; 119 120 // Use the generic op output form in the operation printer even if the custom 121 // form is defined. 122 llvm::cl::opt<bool> printGenericOpFormOpt{ 123 "mlir-print-op-generic", llvm::cl::init(false), 124 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; 125 126 llvm::cl::opt<bool> printLocalScopeOpt{ 127 "mlir-print-local-scope", llvm::cl::init(false), 128 llvm::cl::desc("Print assuming in local scope by default"), 129 llvm::cl::Hidden}; 130 }; 131 } // end anonymous namespace 132 133 static llvm::ManagedStatic<AsmPrinterOptions> clOptions; 134 135 /// Register a set of useful command-line options that can be used to configure 136 /// various flags within the AsmPrinter. 137 void mlir::registerAsmPrinterCLOptions() { 138 // Make sure that the options struct has been initialized. 139 *clOptions; 140 } 141 142 /// Initialize the printing flags with default supplied by the cl::opts above. 143 OpPrintingFlags::OpPrintingFlags() 144 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), 145 printGenericOpFormFlag(false), printLocalScope(false) { 146 // Initialize based upon command line options, if they are available. 
147 if (!clOptions.isConstructed()) 148 return; 149 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) 150 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; 151 printDebugInfoFlag = clOptions->printDebugInfoOpt; 152 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; 153 printGenericOpFormFlag = clOptions->printGenericOpFormOpt; 154 printLocalScope = clOptions->printLocalScopeOpt; 155 } 156 157 /// Enable the elision of large elements attributes, by printing a '...' 158 /// instead of the element data, when the number of elements is greater than 159 /// `largeElementLimit`. Note: The IR generated with this option is not 160 /// parsable. 161 OpPrintingFlags & 162 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { 163 elementsAttrElementLimit = largeElementLimit; 164 return *this; 165 } 166 167 /// Enable printing of debug information. If 'prettyForm' is set to true, 168 /// debug information is printed in a more readable 'pretty' form. 169 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { 170 printDebugInfoFlag = true; 171 printDebugInfoPrettyFormFlag = prettyForm; 172 return *this; 173 } 174 175 /// Always print operations in the generic form. 176 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { 177 printGenericOpFormFlag = true; 178 return *this; 179 } 180 181 /// Use local scope when printing the operation. This allows for using the 182 /// printer in a more localized and thread-safe setting, but may not necessarily 183 /// be identical of what the IR will look like when dumping the full module. 184 OpPrintingFlags &OpPrintingFlags::useLocalScope() { 185 printLocalScope = true; 186 return *this; 187 } 188 189 /// Return if the given ElementsAttr should be elided. 
190 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { 191 return elementsAttrElementLimit.hasValue() && 192 *elementsAttrElementLimit < int64_t(attr.getNumElements()) && 193 !attr.isa<SplatElementsAttr>(); 194 } 195 196 /// Return the size limit for printing large ElementsAttr. 197 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { 198 return elementsAttrElementLimit; 199 } 200 201 /// Return if debug information should be printed. 202 bool OpPrintingFlags::shouldPrintDebugInfo() const { 203 return printDebugInfoFlag; 204 } 205 206 /// Return if debug information should be printed in the pretty form. 207 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { 208 return printDebugInfoPrettyFormFlag; 209 } 210 211 /// Return if operations should be printed in the generic form. 212 bool OpPrintingFlags::shouldPrintGenericOpForm() const { 213 return printGenericOpFormFlag; 214 } 215 216 /// Return if the printer should use local scope when dumping the IR. 217 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } 218 219 /// Returns true if an ElementsAttr with the given number of elements should be 220 /// printed with hex. 221 static bool shouldPrintElementsAttrWithHex(int64_t numElements) { 222 // Check to see if a command line option was provided for the limit. 223 if (clOptions.isConstructed()) { 224 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { 225 // -1 is used to disable hex printing. 226 if (clOptions->printElementsAttrWithHexIfLarger == -1) 227 return false; 228 return numElements > clOptions->printElementsAttrWithHexIfLarger; 229 } 230 } 231 232 // Otherwise, default to printing with hex if the number of elements is >100. 
233 return numElements > 100; 234 } 235 236 //===----------------------------------------------------------------------===// 237 // NewLineCounter 238 //===----------------------------------------------------------------------===// 239 240 namespace { 241 /// This class is a simple formatter that emits a new line when inputted into a 242 /// stream, that enables counting the number of newlines emitted. This class 243 /// should be used whenever emitting newlines in the printer. 244 struct NewLineCounter { 245 unsigned curLine = 1; 246 }; 247 } // end anonymous namespace 248 249 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { 250 ++newLine.curLine; 251 return os << '\n'; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // AliasInitializer 256 //===----------------------------------------------------------------------===// 257 258 namespace { 259 /// This class represents a specific instance of a symbol Alias. 260 class SymbolAlias { 261 public: 262 SymbolAlias(StringRef name, bool isDeferrable) 263 : name(name), suffixIndex(0), hasSuffixIndex(false), 264 isDeferrable(isDeferrable) {} 265 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) 266 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true), 267 isDeferrable(isDeferrable) {} 268 269 /// Print this alias to the given stream. 270 void print(raw_ostream &os) const { 271 os << name; 272 if (hasSuffixIndex) 273 os << suffixIndex; 274 } 275 276 /// Returns true if this alias supports deferred resolution when parsing. 277 bool canBeDeferred() const { return isDeferrable; } 278 279 private: 280 /// The main name of the alias. 281 StringRef name; 282 /// The optional suffix index of the alias, if multiple aliases had the same 283 /// name. 284 uint32_t suffixIndex : 30; 285 /// A flag indicating whether this alias has a suffix or not. 
286 bool hasSuffixIndex : 1; 287 /// A flag indicating whether this alias may be deferred or not. 288 bool isDeferrable : 1; 289 }; 290 291 /// This class represents a utility that initializes the set of attribute and 292 /// type aliases, without the need to store the extra information within the 293 /// main AliasState class or pass it around via function arguments. 294 class AliasInitializer { 295 public: 296 AliasInitializer( 297 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, 298 llvm::BumpPtrAllocator &aliasAllocator) 299 : interfaces(interfaces), aliasAllocator(aliasAllocator), 300 aliasOS(aliasBuffer) {} 301 302 void initialize(Operation *op, const OpPrintingFlags &printerFlags, 303 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 304 llvm::MapVector<Type, SymbolAlias> &typeToAlias); 305 306 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is 307 /// set to true if the originator of this attribute can resolve the alias 308 /// after parsing has completed (e.g. in the case of operation locations). 309 void visit(Attribute attr, bool canBeDeferred = false); 310 311 /// Visit the given type to see if it has an alias. 312 void visit(Type type); 313 314 private: 315 /// Try to generate an alias for the provided symbol. If an alias is 316 /// generated, the provided alias mapping and reverse mapping are updated. 317 /// Returns success if an alias was generated, failure otherwise. 318 template <typename T> 319 LogicalResult 320 generateAlias(T symbol, 321 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol); 322 323 /// The set of asm interfaces within the context. 324 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; 325 326 /// Mapping between an alias and the set of symbols mapped to it. 327 llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr; 328 llvm::MapVector<StringRef, std::vector<Type>> aliasToType; 329 330 /// An allocator used for alias names. 
331 llvm::BumpPtrAllocator &aliasAllocator; 332 333 /// The set of visited attributes. 334 DenseSet<Attribute> visitedAttributes; 335 336 /// The set of attributes that have aliases *and* can be deferred. 337 DenseSet<Attribute> deferrableAttributes; 338 339 /// The set of visited types. 340 DenseSet<Type> visitedTypes; 341 342 /// Storage and stream used when generating an alias. 343 SmallString<32> aliasBuffer; 344 llvm::raw_svector_ostream aliasOS; 345 }; 346 347 /// This class implements a dummy OpAsmPrinter that doesn't print any output, 348 /// and merely collects the attributes and types that *would* be printed in a 349 /// normal print invocation so that we can generate proper aliases. This allows 350 /// for us to generate aliases only for the attributes and types that would be 351 /// in the output, and trims down unnecessary output. 352 class DummyAliasOperationPrinter : private OpAsmPrinter { 353 public: 354 explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags, 355 AliasInitializer &initializer) 356 : printerFlags(flags), initializer(initializer) {} 357 358 /// Print the given operation. 359 void print(Operation *op) { 360 // Visit the operation location. 361 if (printerFlags.shouldPrintDebugInfo()) 362 initializer.visit(op->getLoc(), /*canBeDeferred=*/true); 363 364 // If requested, always print the generic form. 365 if (!printerFlags.shouldPrintGenericOpForm()) { 366 // Check to see if this is a known operation. If so, use the registered 367 // custom printer hook. 368 if (auto *opInfo = op->getAbstractOperation()) { 369 opInfo->printAssembly(op, *this); 370 return; 371 } 372 } 373 374 // Otherwise print with the generic assembly form. 375 printGenericOp(op); 376 } 377 378 private: 379 /// Print the given operation in the generic form. 380 void printGenericOp(Operation *op) override { 381 // Consider nested operations for aliases. 
382 if (op->getNumRegions() != 0) { 383 for (Region ®ion : op->getRegions()) 384 printRegion(region, /*printEntryBlockArgs=*/true, 385 /*printBlockTerminators=*/true); 386 } 387 388 // Visit all the types used in the operation. 389 for (Type type : op->getOperandTypes()) 390 printType(type); 391 for (Type type : op->getResultTypes()) 392 printType(type); 393 394 // Consider the attributes of the operation for aliases. 395 for (const NamedAttribute &attr : op->getAttrs()) 396 printAttribute(attr.second); 397 } 398 399 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 400 /// block are not printed. If 'printBlockTerminator' is false, the terminator 401 /// operation of the block is not printed. 402 void print(Block *block, bool printBlockArgs = true, 403 bool printBlockTerminator = true) { 404 // Consider the types of the block arguments for aliases if 'printBlockArgs' 405 // is set to true. 406 if (printBlockArgs) { 407 for (Type type : block->getArgumentTypes()) 408 printType(type); 409 } 410 411 // Consider the operations within this block, ignoring the terminator if 412 // requested. 413 auto range = llvm::make_range( 414 block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1)); 415 for (Operation &op : range) 416 print(&op); 417 } 418 419 /// Print the given region. 420 void printRegion(Region ®ion, bool printEntryBlockArgs, 421 bool printBlockTerminators) override { 422 if (region.empty()) 423 return; 424 425 auto *entryBlock = ®ion.front(); 426 print(entryBlock, printEntryBlockArgs, printBlockTerminators); 427 for (Block &b : llvm::drop_begin(region, 1)) 428 print(&b); 429 } 430 431 /// Consider the given type to be printed for an alias. 432 void printType(Type type) override { initializer.visit(type); } 433 434 /// Consider the given attribute to be printed for an alias. 
435 void printAttribute(Attribute attr) override { initializer.visit(attr); } 436 void printAttributeWithoutType(Attribute attr) override { 437 printAttribute(attr); 438 } 439 440 /// Print the given set of attributes with names not included within 441 /// 'elidedAttrs'. 442 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 443 ArrayRef<StringRef> elidedAttrs = {}) override { 444 if (attrs.empty()) 445 return; 446 if (elidedAttrs.empty()) { 447 for (const NamedAttribute &attr : attrs) 448 printAttribute(attr.second); 449 return; 450 } 451 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 452 elidedAttrs.end()); 453 for (const NamedAttribute &attr : attrs) 454 if (!elidedAttrsSet.contains(attr.first.strref())) 455 printAttribute(attr.second); 456 } 457 void printOptionalAttrDictWithKeyword( 458 ArrayRef<NamedAttribute> attrs, 459 ArrayRef<StringRef> elidedAttrs = {}) override { 460 printOptionalAttrDict(attrs, elidedAttrs); 461 } 462 463 /// Return 'nulls' as the output stream, this will ignore any data fed to it. 464 raw_ostream &getStream() const override { return llvm::nulls(); } 465 466 /// The following are hooks of `OpAsmPrinter` that are not necessary for 467 /// determining potential aliases. 468 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} 469 void printNewline() override {} 470 void printOperand(Value) override {} 471 void printOperand(Value, raw_ostream &os) override { 472 // Users expect the output string to have at least the prefixed % to signal 473 // a value name. To maintain this invariant, emit a name even if it is 474 // guaranteed to go unused. 475 os << "%"; 476 } 477 void printSymbolName(StringRef) override {} 478 void printSuccessor(Block *) override {} 479 void printSuccessorAndUseList(Block *, ValueRange) override {} 480 void shadowRegionArgs(Region &, ValueRange) override {} 481 482 /// The printer flags to use when determining potential aliases. 
483 const OpPrintingFlags &printerFlags; 484 485 /// The initializer to use when identifying aliases. 486 AliasInitializer &initializer; 487 }; 488 } // end anonymous namespace 489 490 /// Sanitize the given name such that it can be used as a valid identifier. If 491 /// the string needs to be modified in any way, the provided buffer is used to 492 /// store the new copy, 493 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, 494 StringRef allowedPunctChars = "$._-", 495 bool allowTrailingDigit = true) { 496 assert(!name.empty() && "Shouldn't have an empty name here"); 497 498 auto copyNameToBuffer = [&] { 499 for (char ch : name) { 500 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) 501 buffer.push_back(ch); 502 else if (ch == ' ') 503 buffer.push_back('_'); 504 else 505 buffer.append(llvm::utohexstr((unsigned char)ch)); 506 } 507 }; 508 509 // Check to see if this name is valid. If it starts with a digit, then it 510 // could conflict with the autogenerated numeric ID's, so add an underscore 511 // prefix to avoid problems. 512 if (isdigit(name[0])) { 513 buffer.push_back('_'); 514 copyNameToBuffer(); 515 return buffer; 516 } 517 518 // If the name ends with a trailing digit, add a '_' to avoid potential 519 // conflicts with autogenerated ID's. 520 if (!allowTrailingDigit && isdigit(name.back())) { 521 copyNameToBuffer(); 522 buffer.push_back('_'); 523 return buffer; 524 } 525 526 // Check to see that the name consists of only valid identifier characters. 527 for (char ch : name) { 528 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { 529 copyNameToBuffer(); 530 return buffer; 531 } 532 } 533 534 // If there are no invalid characters, return the original name. 535 return name; 536 } 537 538 /// Given a collection of aliases and symbols, initialize a mapping from a 539 /// symbol to a given alias. 
540 template <typename T> 541 static void 542 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol, 543 llvm::MapVector<T, SymbolAlias> &symbolToAlias, 544 DenseSet<T> *deferrableAliases = nullptr) { 545 std::vector<std::pair<StringRef, std::vector<T>>> aliases = 546 aliasToSymbol.takeVector(); 547 llvm::array_pod_sort(aliases.begin(), aliases.end(), 548 [](const auto *lhs, const auto *rhs) { 549 return lhs->first.compare(rhs->first); 550 }); 551 552 for (auto &it : aliases) { 553 // If there is only one instance for this alias, use the name directly. 554 if (it.second.size() == 1) { 555 T symbol = it.second.front(); 556 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 557 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)}); 558 continue; 559 } 560 // Otherwise, add the index to the name. 561 for (int i = 0, e = it.second.size(); i < e; ++i) { 562 T symbol = it.second[i]; 563 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); 564 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)}); 565 } 566 } 567 } 568 569 void AliasInitializer::initialize( 570 Operation *op, const OpPrintingFlags &printerFlags, 571 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias, 572 llvm::MapVector<Type, SymbolAlias> &typeToAlias) { 573 // Use a dummy printer when walking the IR so that we can collect the 574 // attributes/types that will actually be used during printing when 575 // considering aliases. 576 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); 577 aliasPrinter.print(op); 578 579 // Initialize the aliases sorted by name. 
580 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); 581 initializeAliases(aliasToType, typeToAlias); 582 } 583 584 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { 585 if (!visitedAttributes.insert(attr).second) { 586 // If this attribute already has an alias and this instance can't be 587 // deferred, make sure that the alias isn't deferred. 588 if (!canBeDeferred) 589 deferrableAttributes.erase(attr); 590 return; 591 } 592 593 // Try to generate an alias for this attribute. 594 if (succeeded(generateAlias(attr, aliasToAttr))) { 595 if (canBeDeferred) 596 deferrableAttributes.insert(attr); 597 return; 598 } 599 600 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 601 for (Attribute element : arrayAttr.getValue()) 602 visit(element); 603 } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) { 604 for (const NamedAttribute &attr : dictAttr) 605 visit(attr.second); 606 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) { 607 visit(typeAttr.getValue()); 608 } 609 } 610 611 void AliasInitializer::visit(Type type) { 612 if (!visitedTypes.insert(type).second) 613 return; 614 615 // Try to generate an alias for this type. 616 if (succeeded(generateAlias(type, aliasToType))) 617 return; 618 619 // Visit several subtypes that contain types or attributes. 620 if (auto funcType = type.dyn_cast<FunctionType>()) { 621 // Visit input and result types for functions. 622 for (auto input : funcType.getInputs()) 623 visit(input); 624 for (auto result : funcType.getResults()) 625 visit(result); 626 } else if (auto shapedType = type.dyn_cast<ShapedType>()) { 627 visit(shapedType.getElementType()); 628 629 // Visit affine maps in memref type. 
630 if (auto memref = type.dyn_cast<MemRefType>()) 631 for (auto map : memref.getAffineMaps()) 632 visit(AffineMapAttr::get(map)); 633 } 634 } 635 636 template <typename T> 637 LogicalResult AliasInitializer::generateAlias( 638 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) { 639 SmallString<16> tempBuffer; 640 for (const auto &interface : interfaces) { 641 if (failed(interface.getAlias(symbol, aliasOS))) 642 continue; 643 StringRef name = aliasOS.str(); 644 assert(!name.empty() && "expected valid alias name"); 645 name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-", 646 /*allowTrailingDigit=*/false); 647 name = name.copy(aliasAllocator); 648 649 aliasToSymbol[name].push_back(symbol); 650 aliasBuffer.clear(); 651 return success(); 652 } 653 return failure(); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // AliasState 658 //===----------------------------------------------------------------------===// 659 660 namespace { 661 /// This class manages the state for type and attribute aliases. 662 class AliasState { 663 public: 664 // Initialize the internal aliases. 665 void 666 initialize(Operation *op, const OpPrintingFlags &printerFlags, 667 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 668 669 /// Get an alias for the given attribute if it has one and print it in `os`. 670 /// Returns success if an alias was printed, failure otherwise. 671 LogicalResult getAlias(Attribute attr, raw_ostream &os) const; 672 673 /// Get an alias for the given type if it has one and print it in `os`. 674 /// Returns success if an alias was printed, failure otherwise. 675 LogicalResult getAlias(Type ty, raw_ostream &os) const; 676 677 /// Print all of the referenced aliases that can not be resolved in a deferred 678 /// manner. 
679 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 680 printAliases(os, newLine, /*isDeferred=*/false); 681 } 682 683 /// Print all of the referenced aliases that support deferred resolution. 684 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { 685 printAliases(os, newLine, /*isDeferred=*/true); 686 } 687 688 private: 689 /// Print all of the referenced aliases that support the provided resolution 690 /// behavior. 691 void printAliases(raw_ostream &os, NewLineCounter &newLine, 692 bool isDeferred) const; 693 694 /// Mapping between attribute and alias. 695 llvm::MapVector<Attribute, SymbolAlias> attrToAlias; 696 /// Mapping between type and alias. 697 llvm::MapVector<Type, SymbolAlias> typeToAlias; 698 699 /// An allocator used for alias names. 700 llvm::BumpPtrAllocator aliasAllocator; 701 }; 702 } // end anonymous namespace 703 704 void AliasState::initialize( 705 Operation *op, const OpPrintingFlags &printerFlags, 706 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 707 AliasInitializer initializer(interfaces, aliasAllocator); 708 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); 709 } 710 711 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { 712 auto it = attrToAlias.find(attr); 713 if (it == attrToAlias.end()) 714 return failure(); 715 it->second.print(os << '#'); 716 return success(); 717 } 718 719 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { 720 auto it = typeToAlias.find(ty); 721 if (it == typeToAlias.end()) 722 return failure(); 723 724 it->second.print(os << '!'); 725 return success(); 726 } 727 728 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine, 729 bool isDeferred) const { 730 auto filterFn = [=](const auto &aliasIt) { 731 return aliasIt.second.canBeDeferred() == isDeferred; 732 }; 733 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) { 734 it.second.print(os << '#'); 
735 os << " = " << it.first << newLine; 736 } 737 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) { 738 it.second.print(os << '!'); 739 os << " = type " << it.first << newLine; 740 } 741 } 742 743 //===----------------------------------------------------------------------===// 744 // SSANameState 745 //===----------------------------------------------------------------------===// 746 747 namespace { 748 /// This class manages the state of SSA value names. 749 class SSANameState { 750 public: 751 /// A sentinel value used for values with names set. 752 enum : unsigned { NameSentinel = ~0U }; 753 754 SSANameState(Operation *op, 755 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 756 757 /// Print the SSA identifier for the given value to 'stream'. If 758 /// 'printResultNo' is true, it also presents the result number ('#' number) 759 /// of this value. 760 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; 761 762 /// Return the result indices for each of the result groups registered by this 763 /// operation, or empty if none exist. 764 ArrayRef<int> getOpResultGroups(Operation *op); 765 766 /// Get the ID for the given block. 767 unsigned getBlockID(Block *block); 768 769 /// Renumber the arguments for the specified region to the same names as the 770 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for 771 /// details. 772 void shadowRegionArgs(Region ®ion, ValueRange namesToUse); 773 774 private: 775 /// Number the SSA values within the given IR unit. 
776 void numberValuesInRegion( 777 Region ®ion, 778 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 779 void numberValuesInBlock( 780 Block &block, 781 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 782 void numberValuesInOp( 783 Operation &op, 784 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); 785 786 /// Given a result of an operation 'result', find the result group head 787 /// 'lookupValue' and the result of 'result' within that group in 788 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group 789 /// has more than 1 result. 790 void getResultIDAndNumber(OpResult result, Value &lookupValue, 791 Optional<int> &lookupResultNo) const; 792 793 /// Set a special value name for the given value. 794 void setValueName(Value value, StringRef name); 795 796 /// Uniques the given value name within the printer. If the given name 797 /// conflicts, it is automatically renamed. 798 StringRef uniqueValueName(StringRef name); 799 800 /// This is the value ID for each SSA value. If this returns NameSentinel, 801 /// then the valueID has an entry in valueNames. 802 DenseMap<Value, unsigned> valueIDs; 803 DenseMap<Value, StringRef> valueNames; 804 805 /// This is a map of operations that contain multiple named result groups, 806 /// i.e. there may be multiple names for the results of the operation. The 807 /// value of this map are the result numbers that start a result group. 808 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; 809 810 /// This is the block ID for each block in the current. 811 DenseMap<Block *, unsigned> blockIDs; 812 813 /// This keeps track of all of the non-numeric names that are in flight, 814 /// allowing us to check for duplicates. 815 /// Note: the value of the map is unused. 816 llvm::ScopedHashTable<StringRef, char> usedNames; 817 llvm::BumpPtrAllocator usedNameAllocator; 818 819 /// This is the next value ID to assign in numbering. 
820 unsigned nextValueID = 0; 821 /// This is the next ID to assign to a region entry block argument. 822 unsigned nextArgumentID = 0; 823 /// This is the next ID to assign when a name conflict is detected. 824 unsigned nextConflictID = 0; 825 }; 826 } // end anonymous namespace 827 828 SSANameState::SSANameState( 829 Operation *op, 830 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 831 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); 832 numberValuesInOp(*op, interfaces); 833 834 for (auto ®ion : op->getRegions()) 835 numberValuesInRegion(region, interfaces); 836 } 837 838 void SSANameState::printValueID(Value value, bool printResultNo, 839 raw_ostream &stream) const { 840 if (!value) { 841 stream << "<<NULL>>"; 842 return; 843 } 844 845 Optional<int> resultNo; 846 auto lookupValue = value; 847 848 // If this is an operation result, collect the head lookup value of the result 849 // group and the result number of 'result' within that group. 850 if (OpResult result = value.dyn_cast<OpResult>()) 851 getResultIDAndNumber(result, lookupValue, resultNo); 852 853 auto it = valueIDs.find(lookupValue); 854 if (it == valueIDs.end()) { 855 stream << "<<UNKNOWN SSA VALUE>>"; 856 return; 857 } 858 859 stream << '%'; 860 if (it->second != NameSentinel) { 861 stream << it->second; 862 } else { 863 auto nameIt = valueNames.find(lookupValue); 864 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 865 stream << nameIt->second; 866 } 867 868 if (resultNo.hasValue() && printResultNo) 869 stream << '#' << resultNo; 870 } 871 872 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { 873 auto it = opResultGroups.find(op); 874 return it == opResultGroups.end() ? ArrayRef<int>() : it->second; 875 } 876 877 unsigned SSANameState::getBlockID(Block *block) { 878 auto it = blockIDs.find(block); 879 return it != blockIDs.end() ? 
it->second : NameSentinel; 880 } 881 882 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { 883 assert(!region.empty() && "cannot shadow arguments of an empty region"); 884 assert(region.getNumArguments() == namesToUse.size() && 885 "incorrect number of names passed in"); 886 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && 887 "only KnownIsolatedFromAbove ops can shadow names"); 888 889 SmallVector<char, 16> nameStr; 890 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 891 auto nameToUse = namesToUse[i]; 892 if (nameToUse == nullptr) 893 continue; 894 auto nameToReplace = region.getArgument(i); 895 896 nameStr.clear(); 897 llvm::raw_svector_ostream nameStream(nameStr); 898 printValueID(nameToUse, /*printResultNo=*/true, nameStream); 899 900 // Entry block arguments should already have a pretty "arg" name. 901 assert(valueIDs[nameToReplace] == NameSentinel); 902 903 // Use the name without the leading %. 904 auto name = StringRef(nameStream.str()).drop_front(); 905 906 // Overwrite the name. 907 valueNames[nameToReplace] = name.copy(usedNameAllocator); 908 } 909 } 910 911 void SSANameState::numberValuesInRegion( 912 Region ®ion, 913 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 914 // Save the current value ids to allow for numbering values in sibling regions 915 // the same. 916 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); 917 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); 918 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); 919 920 // Push a new used names scope. 921 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); 922 923 // Number the values within this region in a breadth-first order. 924 unsigned nextBlockID = 0; 925 for (auto &block : region) { 926 // Each block gets a unique ID, and all of the operations within it get 927 // numbered as well. 
928 blockIDs[&block] = nextBlockID++; 929 numberValuesInBlock(block, interfaces); 930 } 931 932 // After that we traverse the nested regions. 933 // TODO: Rework this loop to not use recursion. 934 for (auto &block : region) { 935 for (auto &op : block) 936 for (auto &nestedRegion : op.getRegions()) 937 numberValuesInRegion(nestedRegion, interfaces); 938 } 939 } 940 941 void SSANameState::numberValuesInBlock( 942 Block &block, 943 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 944 auto setArgNameFn = [&](Value arg, StringRef name) { 945 assert(!valueIDs.count(arg) && "arg numbered multiple times"); 946 assert(arg.cast<BlockArgument>().getOwner() == &block && 947 "arg not defined in 'block'"); 948 setValueName(arg, name); 949 }; 950 951 bool isEntryBlock = block.isEntryBlock(); 952 if (isEntryBlock) { 953 if (auto *op = block.getParentOp()) { 954 if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) 955 asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); 956 } 957 } 958 959 // Number the block arguments. We give entry block arguments a special name 960 // 'arg'. 961 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); 962 llvm::raw_svector_ostream specialName(specialNameBuffer); 963 for (auto arg : block.getArguments()) { 964 if (valueIDs.count(arg)) 965 continue; 966 if (isEntryBlock) { 967 specialNameBuffer.resize(strlen("arg")); 968 specialName << nextArgumentID++; 969 } 970 setValueName(arg, specialName.str()); 971 } 972 973 // Number the operations in this block. 974 for (auto &op : block) 975 numberValuesInOp(op, interfaces); 976 } 977 978 void SSANameState::numberValuesInOp( 979 Operation &op, 980 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { 981 unsigned numResults = op.getNumResults(); 982 if (numResults == 0) 983 return; 984 Value resultBegin = op.getResult(0); 985 986 // Function used to set the special result names for the operation. 
987 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); 988 auto setResultNameFn = [&](Value result, StringRef name) { 989 assert(!valueIDs.count(result) && "result numbered multiple times"); 990 assert(result.getDefiningOp() == &op && "result not defined by 'op'"); 991 setValueName(result, name); 992 993 // Record the result number for groups not anchored at 0. 994 if (int resultNo = result.cast<OpResult>().getResultNumber()) 995 resultGroups.push_back(resultNo); 996 }; 997 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) 998 asmInterface.getAsmResultNames(setResultNameFn); 999 else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) 1000 asmInterface->getAsmResultNames(&op, setResultNameFn); 1001 1002 // If the first result wasn't numbered, give it a default number. 1003 if (valueIDs.try_emplace(resultBegin, nextValueID).second) 1004 ++nextValueID; 1005 1006 // If this operation has multiple result groups, mark it. 1007 if (resultGroups.size() != 1) { 1008 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); 1009 opResultGroups.try_emplace(&op, std::move(resultGroups)); 1010 } 1011 } 1012 1013 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, 1014 Optional<int> &lookupResultNo) const { 1015 Operation *owner = result.getOwner(); 1016 if (owner->getNumResults() == 1) 1017 return; 1018 int resultNo = result.getResultNumber(); 1019 1020 // If this operation has multiple result groups, we will need to find the 1021 // one corresponding to this result. 1022 auto resultGroupIt = opResultGroups.find(owner); 1023 if (resultGroupIt == opResultGroups.end()) { 1024 // If not, just use the first result. 1025 lookupResultNo = resultNo; 1026 lookupValue = owner->getResult(0); 1027 return; 1028 } 1029 1030 // Find the correct index using a binary search, as the groups are ordered. 
1031 ArrayRef<int> resultGroups = resultGroupIt->second; 1032 auto it = llvm::upper_bound(resultGroups, resultNo); 1033 int groupResultNo = 0, groupSize = 0; 1034 1035 // If there are no smaller elements, the last result group is the lookup. 1036 if (it == resultGroups.end()) { 1037 groupResultNo = resultGroups.back(); 1038 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); 1039 } else { 1040 // Otherwise, the previous element is the lookup. 1041 groupResultNo = *std::prev(it); 1042 groupSize = *it - groupResultNo; 1043 } 1044 1045 // We only record the result number for a group of size greater than 1. 1046 if (groupSize != 1) 1047 lookupResultNo = resultNo - groupResultNo; 1048 lookupValue = owner->getResult(groupResultNo); 1049 } 1050 1051 void SSANameState::setValueName(Value value, StringRef name) { 1052 // If the name is empty, the value uses the default numbering. 1053 if (name.empty()) { 1054 valueIDs[value] = nextValueID++; 1055 return; 1056 } 1057 1058 valueIDs[value] = NameSentinel; 1059 valueNames[value] = uniqueValueName(name); 1060 } 1061 1062 StringRef SSANameState::uniqueValueName(StringRef name) { 1063 SmallString<16> tmpBuffer; 1064 name = sanitizeIdentifier(name, tmpBuffer); 1065 1066 // Check to see if this name is already unique. 1067 if (!usedNames.count(name)) { 1068 name = name.copy(usedNameAllocator); 1069 } else { 1070 // Otherwise, we had a conflict - probe until we find a unique name. This 1071 // is guaranteed to terminate (and usually in a single iteration) because it 1072 // generates new names by incrementing nextConflictID. 
1073 SmallString<64> probeName(name); 1074 probeName.push_back('_'); 1075 while (true) { 1076 probeName += llvm::utostr(nextConflictID++); 1077 if (!usedNames.count(probeName)) { 1078 name = StringRef(probeName).copy(usedNameAllocator); 1079 break; 1080 } 1081 probeName.resize(name.size() + 1); 1082 } 1083 } 1084 1085 usedNames.insert(name, char()); 1086 return name; 1087 } 1088 1089 //===----------------------------------------------------------------------===// 1090 // AsmState 1091 //===----------------------------------------------------------------------===// 1092 1093 namespace mlir { 1094 namespace detail { 1095 class AsmStateImpl { 1096 public: 1097 explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap) 1098 : interfaces(op->getContext()), nameState(op, interfaces), 1099 locationMap(locationMap) {} 1100 1101 /// Initialize the alias state to enable the printing of aliases. 1102 void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) { 1103 aliasState.initialize(op, printerFlags, interfaces); 1104 } 1105 1106 /// Get an instance of the OpAsmDialectInterface for the given dialect, or 1107 /// null if one wasn't registered. 1108 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { 1109 return interfaces.getInterfaceFor(dialect); 1110 } 1111 1112 /// Get the state used for aliases. 1113 AliasState &getAliasState() { return aliasState; } 1114 1115 /// Get the state used for SSA names. 1116 SSANameState &getSSANameState() { return nameState; } 1117 1118 /// Register the location, line and column, within the buffer that the given 1119 /// operation was printed at. 1120 void registerOperationLocation(Operation *op, unsigned line, unsigned col) { 1121 if (locationMap) 1122 (*locationMap)[op] = std::make_pair(line, col); 1123 } 1124 1125 private: 1126 /// Collection of OpAsm interfaces implemented in the context. 
1127 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 1128 1129 /// The state used for attribute and type aliases. 1130 AliasState aliasState; 1131 1132 /// The state used for SSA value names. 1133 SSANameState nameState; 1134 1135 /// An optional location map to be populated. 1136 AsmState::LocationMap *locationMap; 1137 }; 1138 } // end namespace detail 1139 } // end namespace mlir 1140 1141 AsmState::AsmState(Operation *op, LocationMap *locationMap) 1142 : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {} 1143 AsmState::~AsmState() {} 1144 1145 //===----------------------------------------------------------------------===// 1146 // ModulePrinter 1147 //===----------------------------------------------------------------------===// 1148 1149 namespace { 1150 class ModulePrinter { 1151 public: 1152 ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, 1153 AsmStateImpl *state = nullptr) 1154 : os(os), printerFlags(flags), state(state) {} 1155 explicit ModulePrinter(ModulePrinter &printer) 1156 : os(printer.os), printerFlags(printer.printerFlags), 1157 state(printer.state) {} 1158 1159 /// Returns the output stream of the printer. 1160 raw_ostream &getStream() { return os; } 1161 1162 template <typename Container, typename UnaryFunctor> 1163 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { 1164 llvm::interleaveComma(c, os, each_fn); 1165 } 1166 1167 /// This enum describes the different kinds of elision for the type of an 1168 /// attribute when printing it. 1169 enum class AttrTypeElision { 1170 /// The type must not be elided, 1171 Never, 1172 /// The type may be elided when it matches the default used in the parser 1173 /// (for example i64 is the default for integer attributes). 1174 May, 1175 /// The type must be elided. 1176 Must 1177 }; 1178 1179 /// Print the given attribute. 
1180 void printAttribute(Attribute attr, 1181 AttrTypeElision typeElision = AttrTypeElision::Never); 1182 1183 void printType(Type type); 1184 1185 /// Print the given location to the stream. If `allowAlias` is true, this 1186 /// allows for the internal location to use an attribute alias. 1187 void printLocation(LocationAttr loc, bool allowAlias = false); 1188 1189 void printAffineMap(AffineMap map); 1190 void 1191 printAffineExpr(AffineExpr expr, 1192 function_ref<void(unsigned, bool)> printValueName = nullptr); 1193 void printAffineConstraint(AffineExpr expr, bool isEq); 1194 void printIntegerSet(IntegerSet set); 1195 1196 protected: 1197 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1198 ArrayRef<StringRef> elidedAttrs = {}, 1199 bool withKeyword = false); 1200 void printNamedAttribute(NamedAttribute attr); 1201 void printTrailingLocation(Location loc); 1202 void printLocationInternal(LocationAttr loc, bool pretty = false); 1203 1204 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1205 /// used instead of individual elements when the elements attr is large. 1206 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); 1207 1208 /// Print a dense string elements attribute. 1209 void printDenseStringElementsAttr(DenseStringElementsAttr attr); 1210 1211 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is 1212 /// used instead of individual elements when the elements attr is large. 1213 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, 1214 bool allowHex); 1215 1216 void printDialectAttribute(Attribute attr); 1217 void printDialectType(Type type); 1218 1219 /// This enum is used to represent the binding strength of the enclosing 1220 /// context that an AffineExprStorage is being printed in, so we can 1221 /// intelligently produce parens. 1222 enum class BindingStrength { 1223 Weak, // + and - 1224 Strong, // All other binary operators. 
1225 }; 1226 void printAffineExprInternal( 1227 AffineExpr expr, BindingStrength enclosingTightness, 1228 function_ref<void(unsigned, bool)> printValueName = nullptr); 1229 1230 /// The output stream for the printer. 1231 raw_ostream &os; 1232 1233 /// A set of flags to control the printer's behavior. 1234 OpPrintingFlags printerFlags; 1235 1236 /// An optional printer state for the module. 1237 AsmStateImpl *state; 1238 1239 /// A tracker for the number of new lines emitted during printing. 1240 NewLineCounter newLine; 1241 }; 1242 } // end anonymous namespace 1243 1244 void ModulePrinter::printTrailingLocation(Location loc) { 1245 // Check to see if we are printing debug information. 1246 if (!printerFlags.shouldPrintDebugInfo()) 1247 return; 1248 1249 os << " "; 1250 printLocation(loc, /*allowAlias=*/true); 1251 } 1252 1253 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { 1254 TypeSwitch<LocationAttr>(loc) 1255 .Case<OpaqueLoc>([&](OpaqueLoc loc) { 1256 printLocationInternal(loc.getFallbackLocation(), pretty); 1257 }) 1258 .Case<UnknownLoc>([&](UnknownLoc loc) { 1259 if (pretty) 1260 os << "[unknown]"; 1261 else 1262 os << "unknown"; 1263 }) 1264 .Case<FileLineColLoc>([&](FileLineColLoc loc) { 1265 if (pretty) { 1266 os << loc.getFilename(); 1267 } else { 1268 os << "\""; 1269 printEscapedString(loc.getFilename(), os); 1270 os << "\""; 1271 } 1272 os << ':' << loc.getLine() << ':' << loc.getColumn(); 1273 }) 1274 .Case<NameLoc>([&](NameLoc loc) { 1275 os << '\"'; 1276 printEscapedString(loc.getName(), os); 1277 os << '\"'; 1278 1279 // Print the child if it isn't unknown. 
1280 auto childLoc = loc.getChildLoc(); 1281 if (!childLoc.isa<UnknownLoc>()) { 1282 os << '('; 1283 printLocationInternal(childLoc, pretty); 1284 os << ')'; 1285 } 1286 }) 1287 .Case<CallSiteLoc>([&](CallSiteLoc loc) { 1288 Location caller = loc.getCaller(); 1289 Location callee = loc.getCallee(); 1290 if (!pretty) 1291 os << "callsite("; 1292 printLocationInternal(callee, pretty); 1293 if (pretty) { 1294 if (callee.isa<NameLoc>()) { 1295 if (caller.isa<FileLineColLoc>()) { 1296 os << " at "; 1297 } else { 1298 os << newLine << " at "; 1299 } 1300 } else { 1301 os << newLine << " at "; 1302 } 1303 } else { 1304 os << " at "; 1305 } 1306 printLocationInternal(caller, pretty); 1307 if (!pretty) 1308 os << ")"; 1309 }) 1310 .Case<FusedLoc>([&](FusedLoc loc) { 1311 if (!pretty) 1312 os << "fused"; 1313 if (Attribute metadata = loc.getMetadata()) 1314 os << '<' << metadata << '>'; 1315 os << '['; 1316 interleave( 1317 loc.getLocations(), 1318 [&](Location loc) { printLocationInternal(loc, pretty); }, 1319 [&]() { os << ", "; }); 1320 os << ']'; 1321 }); 1322 } 1323 1324 /// Print a floating point value in a way that the parser will be able to 1325 /// round-trip losslessly. 1326 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 1327 // We would like to output the FP constant value in exponential notation, 1328 // but we cannot do this if doing so will lose precision. Check here to 1329 // make sure that we only output it in exponential format if we can parse 1330 // the value back and get the same value. 1331 bool isInf = apValue.isInfinity(); 1332 bool isNaN = apValue.isNaN(); 1333 if (!isInf && !isNaN) { 1334 SmallString<128> strValue; 1335 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, 1336 /*TruncateZero=*/false); 1337 1338 // Check to make sure that the stringized number is not some string like 1339 // "Inf" or NaN, that atof will accept, but the lexer will not. 
Check 1340 // that the string matches the "[-+]?[0-9]" regex. 1341 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 1342 ((strValue[0] == '-' || strValue[0] == '+') && 1343 (strValue[1] >= '0' && strValue[1] <= '9'))) && 1344 "[-+]?[0-9] regex does not match!"); 1345 1346 // Parse back the stringized version and check that the value is equal 1347 // (i.e., there is no precision loss). 1348 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 1349 os << strValue; 1350 return; 1351 } 1352 1353 // If it is not, use the default format of APFloat instead of the 1354 // exponential notation. 1355 strValue.clear(); 1356 apValue.toString(strValue); 1357 1358 // Make sure that we can parse the default form as a float. 1359 if (StringRef(strValue).contains('.')) { 1360 os << strValue; 1361 return; 1362 } 1363 } 1364 1365 // Print special values in hexadecimal format. The sign bit should be included 1366 // in the literal. 1367 SmallVector<char, 16> str; 1368 APInt apInt = apValue.bitcastToAPInt(); 1369 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 1370 /*formatAsCLiteral=*/true); 1371 os << str; 1372 } 1373 1374 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) { 1375 if (printerFlags.shouldPrintDebugInfoPrettyForm()) 1376 return printLocationInternal(loc, /*pretty=*/true); 1377 1378 os << "loc("; 1379 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) 1380 printLocationInternal(loc); 1381 os << ')'; 1382 } 1383 1384 /// Returns true if the given dialect symbol data is simple enough to print in 1385 /// the pretty form, i.e. without the enclosing "". 1386 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 1387 // The name must start with an identifier. 1388 if (symName.empty() || !isalpha(symName.front())) 1389 return false; 1390 1391 // Ignore all the characters that are valid in an identifier in the symbol 1392 // name. 
1393 symName = symName.drop_while( 1394 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); 1395 if (symName.empty()) 1396 return true; 1397 1398 // If we got to an unexpected character, then it must be a <>. Check those 1399 // recursively. 1400 if (symName.front() != '<' || symName.back() != '>') 1401 return false; 1402 1403 SmallVector<char, 8> nestedPunctuation; 1404 do { 1405 // If we ran out of characters, then we had a punctuation mismatch. 1406 if (symName.empty()) 1407 return false; 1408 1409 auto c = symName.front(); 1410 symName = symName.drop_front(); 1411 1412 switch (c) { 1413 // We never allow null characters. This is an EOF indicator for the lexer 1414 // which we could handle, but isn't important for any known dialect. 1415 case '\0': 1416 return false; 1417 case '<': 1418 case '[': 1419 case '(': 1420 case '{': 1421 nestedPunctuation.push_back(c); 1422 continue; 1423 case '-': 1424 // Treat `->` as a special token. 1425 if (!symName.empty() && symName.front() == '>') { 1426 symName = symName.drop_front(); 1427 continue; 1428 } 1429 break; 1430 // Reject types with mismatched brackets. 1431 case '>': 1432 if (nestedPunctuation.pop_back_val() != '<') 1433 return false; 1434 break; 1435 case ']': 1436 if (nestedPunctuation.pop_back_val() != '[') 1437 return false; 1438 break; 1439 case ')': 1440 if (nestedPunctuation.pop_back_val() != '(') 1441 return false; 1442 break; 1443 case '}': 1444 if (nestedPunctuation.pop_back_val() != '{') 1445 return false; 1446 break; 1447 default: 1448 continue; 1449 } 1450 1451 // We're done when the punctuation is fully matched. 1452 } while (!nestedPunctuation.empty()); 1453 1454 // If there were extra characters, then we failed. 1455 return symName.empty(); 1456 } 1457 1458 /// Print the given dialect symbol to the stream. 
/// 'symPrefix' is "#" for attributes and "!" for types. Output is either
/// <prefix><dialect>.<symString> (pretty form) or
/// <prefix><dialect><"<symString>">.
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
                               StringRef dialectName, StringRef symString) {
  os << symPrefix << dialectName;

  // If this symbol name is simple enough, print it directly in pretty form,
  // otherwise, we wrap it in <"..."> (currently without escaping, see TODO).
  if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
    os << '.' << symString;
    return;
  }

  // TODO: escape the symbol name, it could contain " characters.
  os << "<\"" << symString << "\">";
}

/// Returns true if the given string can be represented as a bare identifier.
/// A bare identifier is a letter or '_' followed by letters, digits, '_', '$',
/// or '.'.
static bool isBareIdentifier(StringRef name) {
  assert(!name.empty() && "invalid name");

  // By making this unsigned, the value passed in to isalnum will always be
  // in the range 0-255. This is important when building with MSVC because
  // its implementation will assert. This situation can arise when dealing
  // with UTF-8 multibyte characters.
  unsigned char firstChar = static_cast<unsigned char>(name[0]);
  if (!isalpha(firstChar) && firstChar != '_')
    return false;
  return llvm::all_of(name.drop_front(), [](unsigned char c) {
    return isalnum(c) || c == '_' || c == '$' || c == '.';
  });
}

/// Print the given string as a symbol reference. A symbol reference is
/// represented as a string prefixed with '@'. The reference is surrounded with
/// ""'s and escaped if it has any special or non-printable characters in it.
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  assert(!symbolRef.empty() && "expected valid symbol reference");

  // If the symbol can be represented as a bare identifier, write it directly.
  if (isBareIdentifier(symbolRef)) {
    os << '@' << symbolRef;
    return;
  }

  // Otherwise, output the reference wrapped in quotes with proper escaping.
  os << "@\"";
  printEscapedString(symbolRef, os);
  os << '"';
}

// Print out a valid ElementsAttr that is succinct and can represent any
// potential shape/type, for use when eliding a large ElementsAttr.
//
// We choose to use an opaque ElementsAttr literal with conspicuous content to
// hopefully alert readers to the fact that this has been elided.
//
// Unfortunately, neither of the strings of an opaque ElementsAttr literal will
// accept the string "elided". The first string must be a registered dialect
// name and the latter must be a hex constant.
static void printElidedElementsAttr(raw_ostream &os) {
  os << R"(opaque<"", "0xDEADBEEF">)";
}

/// Print 'attr' in its textual form. An alias is printed instead when the
/// printer state provides one. Unless 'typeElision' is Must or the type is
/// NoneType, most kinds are followed by " : <type>". The exceptions are
/// booleans, unit, affine maps/sets, integers (signless i64 with May) and floats
/// (f64 with May), which elide it. Unrecognized attributes are delegated to
/// printDialectAttribute.
void ModulePrinter::printAttribute(Attribute attr,
                                   AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (state && succeeded(state->getAliasState().getAlias(attr, os)))
    return;

  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elides the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE(review): signless i1 already returned above, so the
    // isSignlessInteger(1) clause below can never be true here.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    os << '[';
    // Array elements may elide default (i64/f64) types.
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Nested references are printed as @root::@nested::...
    printSymbolReference(refAttr.getRootReference(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
      os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      // With no indices the body is empty, i.e. "sparse<>".
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);

  } else {
    return printDialectAttribute(attr);
  }

  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}

/// Print the integer element of a DenseElementsAttr.
/// 1-bit values print as true/false; wider values honor 'isSigned'.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 bool isSigned) {
  if (value.getBitWidth() == 1)
    os << (value.getBoolValue() ? "true" : "false");
  else
    value.print(os, isSigned);
}

/// Print the elements of a shaped value as nested, comma-separated bracket
/// lists following 'type's shape. 'printEltFn' is called with each flat
/// (row-major) element index. A splat prints only element 0, without brackets;
/// a zero-element type prints nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed by the previous bumpCounter() call.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close the brackets that are still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}

/// Dispatch to the string or int/float printer depending on the concrete kind
/// of 'attr'.
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
                                           bool allowHex) {
  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
    return printDenseStringElementsAttr(stringAttr);

  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
                                allowHex);
}

/// Print the elements of an int/float/complex dense attribute. When 'allowHex'
/// is set, the attribute is not a splat, and shouldPrintElementsAttrWithHex
/// accepts the element count, the raw data is printed as a quoted
/// little-endian hex string instead.
void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                                  bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  auto numElements = type.getNumElements();
  if (!attr.isSplat() && allowHex &&
      shouldPrintElementsAttrWithHex(numElements)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::support::endian::system_endianness() ==
        llvm::support::endianness::big) {
      // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      os << '"' << "0x"
         << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
         << "\"";
    } else {
      os << '"' << "0x"
         << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
    }

    return;
  }

  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    // Complex elements print as "(real,imag)".
    if (complexElementType.isa<IntegerType>()) {
      bool isSigned = !complexElementType.isUnsignedInteger();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(attr.getComplexIntValues().begin() + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, isSigned);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, isSigned);
        os << ")";
      });
    } else {
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(attr.getComplexFloatValues().begin() + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    bool isSigned = !elementType.isUnsignedInteger();
    auto intValues = attr.getIntValues();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(intValues.begin() + index), os, isSigned);
    });
  } else {
    assert(elementType.isa<FloatType>() && "unexpected element type");
    auto floatValues = attr.getFloatValues();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(floatValues.begin() + index), os);
    });
  }
}

/// Print each string element quoted and escaped, laid out by the attribute's
/// shape.
void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
  ArrayRef<StringRef> data = attr.getRawStringData();
  auto printFn = [&](unsigned index) {
    os << "\"";
    printEscapedString(data[index], os);
    os << "\"";
  };
  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}

/// Print 'type' (or its alias, when the printer state provides one).
void ModulePrinter::printType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Try to print an alias for this type.
  if (state && succeeded(state->getAliasState().getAlias(type, os)))
    return;

  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // Signedness prefix: "s" signed, "u" unsigned, none for signless.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        // A single non-function result is printed without parentheses.
        ArrayRef<Type> results = funcTy.getResults();
        if (results.size() == 1 && !results[0].isa<FunctionType>()) {
          // NOTE(review): this uses raw_ostream's operator<< rather than
          // printType, presumably bypassing this printer's alias state —
          // confirm whether that is intended.
          os << results[0];
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        os << "vector<";
        for (int64_t dim : vectorTy.getShape())
          os << dim << 'x';
        // NOTE(review): element type printed via operator<<, not printType.
        os << vectorTy.getElementType() << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        for (int64_t dim : tensorTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        // NOTE(review): element type printed via operator<<, not printType.
        os << tensorTy.getElementType() << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        for (int64_t dim : memrefTy.getShape()) {
          if (ShapedType::isDynamic(dim))
            os << '?';
          else
            os << dim;
          os << 'x';
        }
        printType(memrefTy.getElementType());
        for (auto map : memrefTy.getAffineMaps()) {
          os << ", ";
          printAttribute(AffineMapAttr::get(map));
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpaceAsInt())
          os << ", " << memrefTy.getMemorySpaceAsInt();
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
1898 if (memrefTy.getMemorySpaceAsInt()) 1899 os << ", " << memrefTy.getMemorySpaceAsInt(); 1900 os << '>'; 1901 }) 1902 .Case<ComplexType>([&](ComplexType complexTy) { 1903 os << "complex<"; 1904 printType(complexTy.getElementType()); 1905 os << '>'; 1906 }) 1907 .Case<TupleType>([&](TupleType tupleTy) { 1908 os << "tuple<"; 1909 interleaveComma(tupleTy.getTypes(), 1910 [&](Type type) { printType(type); }); 1911 os << '>'; 1912 }) 1913 .Case<NoneType>([&](Type) { os << "none"; }) 1914 .Default([&](Type type) { return printDialectType(type); }); 1915 } 1916 1917 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1918 ArrayRef<StringRef> elidedAttrs, 1919 bool withKeyword) { 1920 // If there are no attributes, then there is nothing to be done. 1921 if (attrs.empty()) 1922 return; 1923 1924 // Functor used to print a filtered attribute list. 1925 auto printFilteredAttributesFn = [&](auto filteredAttrs) { 1926 // Print the 'attributes' keyword if necessary. 1927 if (withKeyword) 1928 os << " attributes"; 1929 1930 // Otherwise, print them all out in braces. 1931 os << " {"; 1932 interleaveComma(filteredAttrs, 1933 [&](NamedAttribute attr) { printNamedAttribute(attr); }); 1934 os << '}'; 1935 }; 1936 1937 // If no attributes are elided, we can directly print with no filtering. 1938 if (elidedAttrs.empty()) 1939 return printFilteredAttributesFn(attrs); 1940 1941 // Otherwise, filter out any attributes that shouldn't be included. 
1942 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), 1943 elidedAttrs.end()); 1944 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 1945 return !elidedAttrsSet.contains(attr.first.strref()); 1946 }); 1947 if (!filteredAttrs.empty()) 1948 printFilteredAttributesFn(filteredAttrs); 1949 } 1950 1951 void ModulePrinter::printNamedAttribute(NamedAttribute attr) { 1952 if (isBareIdentifier(attr.first)) { 1953 os << attr.first; 1954 } else { 1955 os << '"'; 1956 printEscapedString(attr.first.strref(), os); 1957 os << '"'; 1958 } 1959 1960 // Pretty printing elides the attribute value for unit attributes. 1961 if (attr.second.isa<UnitAttr>()) 1962 return; 1963 1964 os << " = "; 1965 printAttribute(attr.second); 1966 } 1967 1968 //===----------------------------------------------------------------------===// 1969 // CustomDialectAsmPrinter 1970 //===----------------------------------------------------------------------===// 1971 1972 namespace { 1973 /// This class provides the main specialization of the DialectAsmPrinter that is 1974 /// used to provide support for print attributes and types. This hooks allows 1975 /// for dialects to hook into the main ModulePrinter. 1976 struct CustomDialectAsmPrinter : public DialectAsmPrinter { 1977 public: 1978 CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} 1979 ~CustomDialectAsmPrinter() override {} 1980 1981 raw_ostream &getStream() const override { return printer.getStream(); } 1982 1983 /// Print the given attribute to the stream. 1984 void printAttribute(Attribute attr) override { printer.printAttribute(attr); } 1985 1986 /// Print the given floating point value in a stablized form. 1987 void printFloat(const APFloat &value) override { 1988 printFloatValue(value, getStream()); 1989 } 1990 1991 /// Print the given type to the stream. 1992 void printType(Type type) override { printer.printType(type); } 1993 1994 /// The main module printer. 
1995 ModulePrinter &printer; 1996 }; 1997 } // end anonymous namespace 1998 1999 void ModulePrinter::printDialectAttribute(Attribute attr) { 2000 auto &dialect = attr.getDialect(); 2001 2002 // Ask the dialect to serialize the attribute to a string. 2003 std::string attrName; 2004 { 2005 llvm::raw_string_ostream attrNameStr(attrName); 2006 ModulePrinter subPrinter(attrNameStr, printerFlags, state); 2007 CustomDialectAsmPrinter printer(subPrinter); 2008 dialect.printAttribute(attr, printer); 2009 } 2010 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 2011 } 2012 2013 void ModulePrinter::printDialectType(Type type) { 2014 auto &dialect = type.getDialect(); 2015 2016 // Ask the dialect to serialize the type to a string. 2017 std::string typeName; 2018 { 2019 llvm::raw_string_ostream typeNameStr(typeName); 2020 ModulePrinter subPrinter(typeNameStr, printerFlags, state); 2021 CustomDialectAsmPrinter printer(subPrinter); 2022 dialect.printType(type, printer); 2023 } 2024 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 2025 } 2026 2027 //===----------------------------------------------------------------------===// 2028 // Affine expressions and maps 2029 //===----------------------------------------------------------------------===// 2030 2031 void ModulePrinter::printAffineExpr( 2032 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { 2033 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 2034 } 2035 2036 void ModulePrinter::printAffineExprInternal( 2037 AffineExpr expr, BindingStrength enclosingTightness, 2038 function_ref<void(unsigned, bool)> printValueName) { 2039 const char *binopSpelling = nullptr; 2040 switch (expr.getKind()) { 2041 case AffineExprKind::SymbolId: { 2042 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 2043 if (printValueName) 2044 printValueName(pos, /*isSymbol=*/true); 2045 else 2046 os << 's' << pos; 2047 return; 2048 } 2049 case AffineExprKind::DimId: { 2050 
unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 2051 if (printValueName) 2052 printValueName(pos, /*isSymbol=*/false); 2053 else 2054 os << 'd' << pos; 2055 return; 2056 } 2057 case AffineExprKind::Constant: 2058 os << expr.cast<AffineConstantExpr>().getValue(); 2059 return; 2060 case AffineExprKind::Add: 2061 binopSpelling = " + "; 2062 break; 2063 case AffineExprKind::Mul: 2064 binopSpelling = " * "; 2065 break; 2066 case AffineExprKind::FloorDiv: 2067 binopSpelling = " floordiv "; 2068 break; 2069 case AffineExprKind::CeilDiv: 2070 binopSpelling = " ceildiv "; 2071 break; 2072 case AffineExprKind::Mod: 2073 binopSpelling = " mod "; 2074 break; 2075 } 2076 2077 auto binOp = expr.cast<AffineBinaryOpExpr>(); 2078 AffineExpr lhsExpr = binOp.getLHS(); 2079 AffineExpr rhsExpr = binOp.getRHS(); 2080 2081 // Handle tightly binding binary operators. 2082 if (binOp.getKind() != AffineExprKind::Add) { 2083 if (enclosingTightness == BindingStrength::Strong) 2084 os << '('; 2085 2086 // Pretty print multiplication with -1. 2087 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 2088 if (rhsConst && binOp.getKind() == AffineExprKind::Mul && 2089 rhsConst.getValue() == -1) { 2090 os << "-"; 2091 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2092 if (enclosingTightness == BindingStrength::Strong) 2093 os << ')'; 2094 return; 2095 } 2096 2097 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 2098 2099 os << binopSpelling; 2100 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 2101 2102 if (enclosingTightness == BindingStrength::Strong) 2103 os << ')'; 2104 return; 2105 } 2106 2107 // Print out special "pretty" forms for add. 2108 if (enclosingTightness == BindingStrength::Strong) 2109 os << '('; 2110 2111 // Pretty print addition to a product that has a negative operand as a 2112 // subtraction. 
2113 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 2114 if (rhs.getKind() == AffineExprKind::Mul) { 2115 AffineExpr rrhsExpr = rhs.getRHS(); 2116 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 2117 if (rrhs.getValue() == -1) { 2118 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2119 printValueName); 2120 os << " - "; 2121 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 2122 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2123 printValueName); 2124 } else { 2125 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 2126 printValueName); 2127 } 2128 2129 if (enclosingTightness == BindingStrength::Strong) 2130 os << ')'; 2131 return; 2132 } 2133 2134 if (rrhs.getValue() < -1) { 2135 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 2136 printValueName); 2137 os << " - "; 2138 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 2139 printValueName); 2140 os << " * " << -rrhs.getValue(); 2141 if (enclosingTightness == BindingStrength::Strong) 2142 os << ')'; 2143 return; 2144 } 2145 } 2146 } 2147 } 2148 2149 // Pretty print addition to a negative number as a subtraction. 2150 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 2151 if (rhsConst.getValue() < 0) { 2152 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2153 os << " - " << -rhsConst.getValue(); 2154 if (enclosingTightness == BindingStrength::Strong) 2155 os << ')'; 2156 return; 2157 } 2158 } 2159 2160 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 2161 2162 os << " + "; 2163 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 2164 2165 if (enclosingTightness == BindingStrength::Strong) 2166 os << ')'; 2167 } 2168 2169 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { 2170 printAffineExprInternal(expr, BindingStrength::Weak); 2171 isEq ? 
os << " == 0" : os << " >= 0"; 2172 } 2173 2174 void ModulePrinter::printAffineMap(AffineMap map) { 2175 // Dimension identifiers. 2176 os << '('; 2177 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 2178 os << 'd' << i << ", "; 2179 if (map.getNumDims() >= 1) 2180 os << 'd' << map.getNumDims() - 1; 2181 os << ')'; 2182 2183 // Symbolic identifiers. 2184 if (map.getNumSymbols() != 0) { 2185 os << '['; 2186 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 2187 os << 's' << i << ", "; 2188 if (map.getNumSymbols() >= 1) 2189 os << 's' << map.getNumSymbols() - 1; 2190 os << ']'; 2191 } 2192 2193 // Result affine expressions. 2194 os << " -> ("; 2195 interleaveComma(map.getResults(), 2196 [&](AffineExpr expr) { printAffineExpr(expr); }); 2197 os << ')'; 2198 } 2199 2200 void ModulePrinter::printIntegerSet(IntegerSet set) { 2201 // Dimension identifiers. 2202 os << '('; 2203 for (unsigned i = 1; i < set.getNumDims(); ++i) 2204 os << 'd' << i - 1 << ", "; 2205 if (set.getNumDims() >= 1) 2206 os << 'd' << set.getNumDims() - 1; 2207 os << ')'; 2208 2209 // Symbolic identifiers. 2210 if (set.getNumSymbols() != 0) { 2211 os << '['; 2212 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 2213 os << 's' << i << ", "; 2214 if (set.getNumSymbols() >= 1) 2215 os << 's' << set.getNumSymbols() - 1; 2216 os << ']'; 2217 } 2218 2219 // Print constraints. 
2220 os << " : ("; 2221 int numConstraints = set.getNumConstraints(); 2222 for (int i = 1; i < numConstraints; ++i) { 2223 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 2224 os << ", "; 2225 } 2226 if (numConstraints >= 1) 2227 printAffineConstraint(set.getConstraint(numConstraints - 1), 2228 set.isEq(numConstraints - 1)); 2229 os << ')'; 2230 } 2231 2232 //===----------------------------------------------------------------------===// 2233 // OperationPrinter 2234 //===----------------------------------------------------------------------===// 2235 2236 namespace { 2237 /// This class contains the logic for printing operations, regions, and blocks. 2238 class OperationPrinter : public ModulePrinter, private OpAsmPrinter { 2239 public: 2240 explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, 2241 AsmStateImpl &state) 2242 : ModulePrinter(os, flags, &state) {} 2243 2244 /// Print the given top-level operation. 2245 void printTopLevelOperation(Operation *op); 2246 2247 /// Print the given operation with its indent and location. 2248 void print(Operation *op); 2249 /// Print the bare location, not including indentation/location/etc. 2250 void printOperation(Operation *op); 2251 /// Print the given operation in the generic form. 2252 void printGenericOp(Operation *op) override; 2253 2254 /// Print the name of the given block. 2255 void printBlockName(Block *block); 2256 2257 /// Print the given block. If 'printBlockArgs' is false, the arguments of the 2258 /// block are not printed. If 'printBlockTerminator' is false, the terminator 2259 /// operation of the block is not printed. 2260 void print(Block *block, bool printBlockArgs = true, 2261 bool printBlockTerminator = true); 2262 2263 /// Print the ID of the given value, optionally with its result number. 
2264 void printValueID(Value value, bool printResultNo = true, 2265 raw_ostream *streamOverride = nullptr) const; 2266 2267 //===--------------------------------------------------------------------===// 2268 // OpAsmPrinter methods 2269 //===--------------------------------------------------------------------===// 2270 2271 /// Return the current stream of the printer. 2272 raw_ostream &getStream() const override { return os; } 2273 2274 /// Print a newline and indent the printer to the start of the current 2275 /// operation. 2276 void printNewline() override { 2277 os << newLine; 2278 os.indent(currentIndent); 2279 } 2280 2281 /// Print the given type. 2282 void printType(Type type) override { ModulePrinter::printType(type); } 2283 2284 /// Print the given attribute. 2285 void printAttribute(Attribute attr) override { 2286 ModulePrinter::printAttribute(attr); 2287 } 2288 2289 /// Print the given attribute without its type. The corresponding parser must 2290 /// provide a valid type for the attribute. 2291 void printAttributeWithoutType(Attribute attr) override { 2292 ModulePrinter::printAttribute(attr, AttrTypeElision::Must); 2293 } 2294 2295 /// Print the ID for the given value. 2296 void printOperand(Value value) override { printValueID(value); } 2297 void printOperand(Value value, raw_ostream &os) override { 2298 printValueID(value, /*printResultNo=*/true, &os); 2299 } 2300 2301 /// Print an optional attribute dictionary with a given set of elided values. 2302 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 2303 ArrayRef<StringRef> elidedAttrs = {}) override { 2304 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); 2305 } 2306 void printOptionalAttrDictWithKeyword( 2307 ArrayRef<NamedAttribute> attrs, 2308 ArrayRef<StringRef> elidedAttrs = {}) override { 2309 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, 2310 /*withKeyword=*/true); 2311 } 2312 2313 /// Print the given successor. 
2314 void printSuccessor(Block *successor) override; 2315 2316 /// Print an operation successor with the operands used for the block 2317 /// arguments. 2318 void printSuccessorAndUseList(Block *successor, 2319 ValueRange succOperands) override; 2320 2321 /// Print the given region. 2322 void printRegion(Region ®ion, bool printEntryBlockArgs, 2323 bool printBlockTerminators) override; 2324 2325 /// Renumber the arguments for the specified region to the same names as the 2326 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 2327 /// operations. If any entry in namesToUse is null, the corresponding 2328 /// argument name is left alone. 2329 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { 2330 state->getSSANameState().shadowRegionArgs(region, namesToUse); 2331 } 2332 2333 /// Print the given affine map with the symbol and dimension operands printed 2334 /// inline with the map. 2335 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2336 ValueRange operands) override; 2337 2338 /// Print the given string as a symbol reference. 2339 void printSymbolName(StringRef symbolRef) override { 2340 ::printSymbolReference(symbolRef, os); 2341 } 2342 2343 private: 2344 /// The number of spaces used for indenting nested operations. 2345 const static unsigned indentWidth = 2; 2346 2347 // This is the current indentation level for nested structures. 2348 unsigned currentIndent = 0; 2349 }; 2350 } // end anonymous namespace 2351 2352 void OperationPrinter::printTopLevelOperation(Operation *op) { 2353 // Output the aliases at the top level that can't be deferred. 2354 state->getAliasState().printNonDeferredAliases(os, newLine); 2355 2356 // Print the module. 2357 print(op); 2358 os << newLine; 2359 2360 // Output the aliases at the top level that can be deferred. 2361 state->getAliasState().printDeferredAliases(os, newLine); 2362 } 2363 2364 void OperationPrinter::print(Operation *op) { 2365 // Track the location of this operation. 
2366 state->registerOperationLocation(op, newLine.curLine, currentIndent); 2367 2368 os.indent(currentIndent); 2369 printOperation(op); 2370 printTrailingLocation(op->getLoc()); 2371 } 2372 2373 void OperationPrinter::printOperation(Operation *op) { 2374 if (size_t numResults = op->getNumResults()) { 2375 auto printResultGroup = [&](size_t resultNo, size_t resultCount) { 2376 printValueID(op->getResult(resultNo), /*printResultNo=*/false); 2377 if (resultCount > 1) 2378 os << ':' << resultCount; 2379 }; 2380 2381 // Check to see if this operation has multiple result groups. 2382 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); 2383 if (!resultGroups.empty()) { 2384 // Interleave the groups excluding the last one, this one will be handled 2385 // separately. 2386 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { 2387 printResultGroup(resultGroups[i], 2388 resultGroups[i + 1] - resultGroups[i]); 2389 }); 2390 os << ", "; 2391 printResultGroup(resultGroups.back(), numResults - resultGroups.back()); 2392 2393 } else { 2394 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); 2395 } 2396 2397 os << " = "; 2398 } 2399 2400 // If requested, always print the generic form. 2401 if (!printerFlags.shouldPrintGenericOpForm()) { 2402 // Check to see if this is a known operation. If so, use the registered 2403 // custom printer hook. 2404 if (auto *opInfo = op->getAbstractOperation()) { 2405 opInfo->printAssembly(op, *this); 2406 return; 2407 } 2408 } 2409 2410 // Otherwise print with the generic assembly form. 2411 printGenericOp(op); 2412 } 2413 2414 void OperationPrinter::printGenericOp(Operation *op) { 2415 os << '"'; 2416 printEscapedString(op->getName().getStringRef(), os); 2417 os << "\"("; 2418 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); 2419 os << ')'; 2420 2421 // For terminators, print the list of successors and their operands. 
2422 if (op->getNumSuccessors() != 0) { 2423 os << '['; 2424 interleaveComma(op->getSuccessors(), 2425 [&](Block *successor) { printBlockName(successor); }); 2426 os << ']'; 2427 } 2428 2429 // Print regions. 2430 if (op->getNumRegions() != 0) { 2431 os << " ("; 2432 interleaveComma(op->getRegions(), [&](Region ®ion) { 2433 printRegion(region, /*printEntryBlockArgs=*/true, 2434 /*printBlockTerminators=*/true); 2435 }); 2436 os << ')'; 2437 } 2438 2439 auto attrs = op->getAttrs(); 2440 printOptionalAttrDict(attrs); 2441 2442 // Print the type signature of the operation. 2443 os << " : "; 2444 printFunctionalType(op); 2445 } 2446 2447 void OperationPrinter::printBlockName(Block *block) { 2448 auto id = state->getSSANameState().getBlockID(block); 2449 if (id != SSANameState::NameSentinel) 2450 os << "^bb" << id; 2451 else 2452 os << "^INVALIDBLOCK"; 2453 } 2454 2455 void OperationPrinter::print(Block *block, bool printBlockArgs, 2456 bool printBlockTerminator) { 2457 // Print the block label and argument list if requested. 2458 if (printBlockArgs) { 2459 os.indent(currentIndent); 2460 printBlockName(block); 2461 2462 // Print the argument list if non-empty. 2463 if (!block->args_empty()) { 2464 os << '('; 2465 interleaveComma(block->getArguments(), [&](BlockArgument arg) { 2466 printValueID(arg); 2467 os << ": "; 2468 printType(arg.getType()); 2469 }); 2470 os << ')'; 2471 } 2472 os << ':'; 2473 2474 // Print out some context information about the predecessors of this block. 2475 if (!block->getParent()) { 2476 os << " // block is not in a region!"; 2477 } else if (block->hasNoPredecessors()) { 2478 os << " // no predecessors"; 2479 } else if (auto *pred = block->getSinglePredecessor()) { 2480 os << " // pred: "; 2481 printBlockName(pred); 2482 } else { 2483 // We want to print the predecessors in increasing numeric order, not in 2484 // whatever order the use-list is in, so gather and sort them. 
2485 SmallVector<std::pair<unsigned, Block *>, 4> predIDs; 2486 for (auto *pred : block->getPredecessors()) 2487 predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); 2488 llvm::array_pod_sort(predIDs.begin(), predIDs.end()); 2489 2490 os << " // " << predIDs.size() << " preds: "; 2491 2492 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { 2493 printBlockName(pred.second); 2494 }); 2495 } 2496 os << newLine; 2497 } 2498 2499 currentIndent += indentWidth; 2500 auto range = llvm::make_range( 2501 block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1)); 2502 for (auto &op : range) { 2503 print(&op); 2504 os << newLine; 2505 } 2506 currentIndent -= indentWidth; 2507 } 2508 2509 void OperationPrinter::printValueID(Value value, bool printResultNo, 2510 raw_ostream *streamOverride) const { 2511 state->getSSANameState().printValueID(value, printResultNo, 2512 streamOverride ? *streamOverride : os); 2513 } 2514 2515 void OperationPrinter::printSuccessor(Block *successor) { 2516 printBlockName(successor); 2517 } 2518 2519 void OperationPrinter::printSuccessorAndUseList(Block *successor, 2520 ValueRange succOperands) { 2521 printBlockName(successor); 2522 if (succOperands.empty()) 2523 return; 2524 2525 os << '('; 2526 interleaveComma(succOperands, 2527 [this](Value operand) { printValueID(operand); }); 2528 os << " : "; 2529 interleaveComma(succOperands, 2530 [this](Value operand) { printType(operand.getType()); }); 2531 os << ')'; 2532 } 2533 2534 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, 2535 bool printBlockTerminators) { 2536 os << " {" << newLine; 2537 if (!region.empty()) { 2538 auto *entryBlock = ®ion.front(); 2539 print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, 2540 printBlockTerminators); 2541 for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) 2542 print(&b); 2543 } 2544 os.indent(currentIndent) << "}"; 2545 } 2546 2547 void 
OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, 2548 ValueRange operands) { 2549 AffineMap map = mapAttr.getValue(); 2550 unsigned numDims = map.getNumDims(); 2551 auto printValueName = [&](unsigned pos, bool isSymbol) { 2552 unsigned index = isSymbol ? numDims + pos : pos; 2553 assert(index < operands.size()); 2554 if (isSymbol) 2555 os << "symbol("; 2556 printValueID(operands[index]); 2557 if (isSymbol) 2558 os << ')'; 2559 }; 2560 2561 interleaveComma(map.getResults(), [&](AffineExpr expr) { 2562 printAffineExpr(expr, printValueName); 2563 }); 2564 } 2565 2566 //===----------------------------------------------------------------------===// 2567 // print and dump methods 2568 //===----------------------------------------------------------------------===// 2569 2570 void Attribute::print(raw_ostream &os) const { 2571 ModulePrinter(os).printAttribute(*this); 2572 } 2573 2574 void Attribute::dump() const { 2575 print(llvm::errs()); 2576 llvm::errs() << "\n"; 2577 } 2578 2579 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); } 2580 2581 void Type::dump() { print(llvm::errs()); } 2582 2583 void AffineMap::dump() const { 2584 print(llvm::errs()); 2585 llvm::errs() << "\n"; 2586 } 2587 2588 void IntegerSet::dump() const { 2589 print(llvm::errs()); 2590 llvm::errs() << "\n"; 2591 } 2592 2593 void AffineExpr::print(raw_ostream &os) const { 2594 if (!expr) { 2595 os << "<<NULL AFFINE EXPR>>"; 2596 return; 2597 } 2598 ModulePrinter(os).printAffineExpr(*this); 2599 } 2600 2601 void AffineExpr::dump() const { 2602 print(llvm::errs()); 2603 llvm::errs() << "\n"; 2604 } 2605 2606 void AffineMap::print(raw_ostream &os) const { 2607 if (!map) { 2608 os << "<<NULL AFFINE MAP>>"; 2609 return; 2610 } 2611 ModulePrinter(os).printAffineMap(*this); 2612 } 2613 2614 void IntegerSet::print(raw_ostream &os) const { 2615 ModulePrinter(os).printIntegerSet(*this); 2616 } 2617 2618 void Value::print(raw_ostream &os) { 2619 if (auto *op = getDefiningOp()) 
2620 return op->print(os); 2621 // TODO: Improve this. 2622 BlockArgument arg = this->cast<BlockArgument>(); 2623 os << "<block argument> of type '" << arg.getType() 2624 << "' at index: " << arg.getArgNumber() << '\n'; 2625 } 2626 void Value::print(raw_ostream &os, AsmState &state) { 2627 if (auto *op = getDefiningOp()) 2628 return op->print(os, state); 2629 2630 // TODO: Improve this. 2631 BlockArgument arg = this->cast<BlockArgument>(); 2632 os << "<block argument> of type '" << arg.getType() 2633 << "' at index: " << arg.getArgNumber() << '\n'; 2634 } 2635 2636 void Value::dump() { 2637 print(llvm::errs()); 2638 llvm::errs() << "\n"; 2639 } 2640 2641 void Value::printAsOperand(raw_ostream &os, AsmState &state) { 2642 // TODO: This doesn't necessarily capture all potential cases. 2643 // Currently, region arguments can be shadowed when printing the main 2644 // operation. If the IR hasn't been printed, this will produce the old SSA 2645 // name and not the shadowed name. 2646 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, 2647 os); 2648 } 2649 2650 void Operation::print(raw_ostream &os, OpPrintingFlags flags) { 2651 // If this is a top level operation, we also print aliases. 2652 if (!getParent() && !flags.shouldUseLocalScope()) { 2653 AsmState state(this); 2654 state.getImpl().initializeAliases(this, flags); 2655 print(os, state, flags); 2656 return; 2657 } 2658 2659 // Find the operation to number from based upon the provided flags. 2660 Operation *op = this; 2661 bool shouldUseLocalScope = flags.shouldUseLocalScope(); 2662 do { 2663 // If we are printing local scope, stop at the first operation that is 2664 // isolated from above. 2665 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 2666 break; 2667 2668 // Otherwise, traverse up to the next parent. 
2669 Operation *parentOp = op->getParentOp(); 2670 if (!parentOp) 2671 break; 2672 op = parentOp; 2673 } while (true); 2674 2675 AsmState state(op); 2676 print(os, state, flags); 2677 } 2678 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { 2679 OperationPrinter printer(os, flags, state.getImpl()); 2680 if (!getParent() && !flags.shouldUseLocalScope()) 2681 printer.printTopLevelOperation(this); 2682 else 2683 printer.print(this); 2684 } 2685 2686 void Operation::dump() { 2687 print(llvm::errs(), OpPrintingFlags().useLocalScope()); 2688 llvm::errs() << "\n"; 2689 } 2690 2691 void Block::print(raw_ostream &os) { 2692 Operation *parentOp = getParentOp(); 2693 if (!parentOp) { 2694 os << "<<UNLINKED BLOCK>>\n"; 2695 return; 2696 } 2697 // Get the top-level op. 2698 while (auto *nextOp = parentOp->getParentOp()) 2699 parentOp = nextOp; 2700 2701 AsmState state(parentOp); 2702 print(os, state); 2703 } 2704 void Block::print(raw_ostream &os, AsmState &state) { 2705 OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); 2706 } 2707 2708 void Block::dump() { print(llvm::errs()); } 2709 2710 /// Print out the name of the block without printing its body. 2711 void Block::printAsOperand(raw_ostream &os, bool printType) { 2712 Operation *parentOp = getParentOp(); 2713 if (!parentOp) { 2714 os << "<<UNLINKED BLOCK>>\n"; 2715 return; 2716 } 2717 AsmState state(parentOp); 2718 printAsOperand(os, state); 2719 } 2720 void Block::printAsOperand(raw_ostream &os, AsmState &state) { 2721 OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); 2722 printer.printBlockName(this); 2723 } 2724