1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "mlir/IR/Block.h" 11 #include "mlir/IR/BlockAndValueMapping.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "mlir/Transforms/Utils.h" 16 #include "llvm/ADT/SetVector.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/FormatVariadic.h" 20 #include "llvm/Support/SaveAndRestore.h" 21 #include "llvm/Support/ScopedPrinter.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 #define DEBUG_TYPE "dialect-conversion" 27 28 /// Recursively collect all of the operations to convert from within 'region'. 29 /// If 'target' is nonnull, operations that are recursively legal have their 30 /// regions pre-filtered to avoid considering them for legalization. 31 static LogicalResult 32 computeConversionSet(iterator_range<Region::iterator> region, 33 Location regionLoc, std::vector<Operation *> &toConvert, 34 ConversionTarget *target = nullptr) { 35 if (llvm::empty(region)) 36 return success(); 37 38 // Traverse starting from the entry block. 39 SmallVector<Block *, 16> worklist(1, &*region.begin()); 40 DenseSet<Block *> visitedBlocks; 41 visitedBlocks.insert(worklist.front()); 42 while (!worklist.empty()) { 43 Block *block = worklist.pop_back_val(); 44 45 // Compute the conversion set of each of the nested operations. 46 for (Operation &op : *block) { 47 toConvert.emplace_back(&op); 48 49 // Don't check this operation's children for conversion if the operation 50 // is recursively legal. 
51 auto legalityInfo = target ? target->isLegal(&op) 52 : Optional<ConversionTarget::LegalOpDetails>(); 53 if (legalityInfo && legalityInfo->isRecursivelyLegal) 54 continue; 55 for (auto ®ion : op.getRegions()) { 56 if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), 57 toConvert, target))) 58 return failure(); 59 } 60 } 61 62 // Recurse to children that haven't been visited. 63 for (Block *succ : block->getSuccessors()) 64 if (visitedBlocks.insert(succ).second) 65 worklist.push_back(succ); 66 } 67 68 // Check that all blocks in the region were visited. 69 if (llvm::any_of(llvm::drop_begin(region, 1), 70 [&](Block &block) { return !visitedBlocks.count(&block); })) 71 return emitError(regionLoc, "unreachable blocks were not converted"); 72 return success(); 73 } 74 75 /// A utility function to log a successful result for the given reason. 76 template <typename... Args> 77 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 78 LLVM_DEBUG({ 79 os.unindent(); 80 os.startLine() << "} -> SUCCESS"; 81 if (!fmt.empty()) 82 os.getOStream() << " : " 83 << llvm::formatv(fmt.data(), std::forward<Args>(args)...); 84 os.getOStream() << "\n"; 85 }); 86 } 87 88 /// A utility function to log a failure result for the given reason. 89 template <typename... Args> 90 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 91 LLVM_DEBUG({ 92 os.unindent(); 93 os.startLine() << "} -> FAILURE : " 94 << llvm::formatv(fmt.data(), std::forward<Args>(args)...) 95 << "\n"; 96 }); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // ConversionValueMapping 101 //===----------------------------------------------------------------------===// 102 103 namespace { 104 /// This class wraps a BlockAndValueMapping to provide recursive lookup 105 /// functionality, i.e. we will traverse if the mapped value also has a mapping. 
106 struct ConversionValueMapping { 107 /// Lookup a mapped value within the map. If a mapping for the provided value 108 /// does not exist then return the provided value. If `desiredType` is 109 /// non-null, returns the most recently mapped value with that type. If an 110 /// operand of that type does not exist, defaults to normal behavior. 111 Value lookupOrDefault(Value from, Type desiredType = nullptr) const; 112 113 /// Lookup a mapped value within the map, or return null if a mapping does not 114 /// exist. If a mapping exists, this follows the same behavior of 115 /// `lookupOrDefault`. 116 Value lookupOrNull(Value from) const; 117 118 /// Map a value to the one provided. 119 void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } 120 121 /// Drop the last mapping for the given value. 122 void erase(Value value) { mapping.erase(value); } 123 124 private: 125 /// Current value mappings. 126 BlockAndValueMapping mapping; 127 }; 128 } // end anonymous namespace 129 130 Value ConversionValueMapping::lookupOrDefault(Value from, 131 Type desiredType) const { 132 // If there was no desired type, simply find the leaf value. 133 if (!desiredType) { 134 // If this value had a valid mapping, unmap that value as well in the case 135 // that it was also replaced. 136 while (auto mappedValue = mapping.lookupOrNull(from)) 137 from = mappedValue; 138 return from; 139 } 140 141 // Otherwise, try to find the deepest value that has the desired type. 142 Value desiredValue; 143 do { 144 if (from.getType() == desiredType) 145 desiredValue = from; 146 147 Value mappedValue = mapping.lookupOrNull(from); 148 if (!mappedValue) 149 break; 150 from = mappedValue; 151 } while (true); 152 153 // If the desired value was found use it, otherwise default to the leaf value. 154 return desiredValue ? desiredValue : from; 155 } 156 157 Value ConversionValueMapping::lookupOrNull(Value from) const { 158 Value result = lookupOrDefault(from); 159 return result == from ? 
nullptr : result; 160 } 161 162 //===----------------------------------------------------------------------===// 163 // ArgConverter 164 //===----------------------------------------------------------------------===// 165 namespace { 166 /// This class provides a simple interface for converting the types of block 167 /// arguments. This is done by creating a new block that contains the new legal 168 /// types and extracting the block that contains the old illegal types to allow 169 /// for undoing pending rewrites in the case of failure. 170 struct ArgConverter { 171 ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {} 172 173 /// This structure contains the information pertaining to an argument that has 174 /// been converted. 175 struct ConvertedArgInfo { 176 ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, 177 Value castValue = nullptr) 178 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} 179 180 /// The start index of in the new argument list that contains arguments that 181 /// replace the original. 182 unsigned newArgIdx; 183 184 /// The number of arguments that replaced the original argument. 185 unsigned newArgSize; 186 187 /// The cast value that was created to cast from the new arguments to the 188 /// old. This only used if 'newArgSize' > 1. 189 Value castValue; 190 }; 191 192 /// This structure contains information pertaining to a block that has had its 193 /// signature converted. 194 struct ConvertedBlockInfo { 195 ConvertedBlockInfo(Block *origBlock, TypeConverter &converter) 196 : origBlock(origBlock), converter(&converter) {} 197 198 /// The original block that was requested to have its signature converted. 199 Block *origBlock; 200 201 /// The conversion information for each of the arguments. The information is 202 /// None if the argument was dropped during conversion. 203 SmallVector<Optional<ConvertedArgInfo>, 1> argInfo; 204 205 /// The type converter used to convert the arguments. 
206 TypeConverter *converter; 207 }; 208 209 /// Return if the signature of the given block has already been converted. 210 bool hasBeenConverted(Block *block) const { 211 return conversionInfo.count(block) || convertedBlocks.count(block); 212 } 213 214 /// Set the type converter to use for the given region. 215 void setConverter(Region *region, TypeConverter *typeConverter) { 216 assert(typeConverter && "expected valid type converter"); 217 regionToConverter[region] = typeConverter; 218 } 219 220 /// Return the type converter to use for the given region, or null if there 221 /// isn't one. 222 TypeConverter *getConverter(Region *region) { 223 return regionToConverter.lookup(region); 224 } 225 226 //===--------------------------------------------------------------------===// 227 // Rewrite Application 228 //===--------------------------------------------------------------------===// 229 230 /// Erase any rewrites registered for the blocks within the given operation 231 /// which is about to be removed. This merely drops the rewrites without 232 /// undoing them. 233 void notifyOpRemoved(Operation *op); 234 235 /// Cleanup and undo any generated conversions for the arguments of block. 236 /// This method replaces the new block with the original, reverting the IR to 237 /// its original state. 238 void discardRewrites(Block *block); 239 240 /// Fully replace uses of the old arguments with the new. 241 void applyRewrites(ConversionValueMapping &mapping); 242 243 /// Materialize any necessary conversions for converted arguments that have 244 /// live users, using the provided `findLiveUser` to search for a user that 245 /// survives the conversion process. 
246 LogicalResult 247 materializeLiveConversions(ConversionValueMapping &mapping, 248 OpBuilder &builder, 249 function_ref<Operation *(Value)> findLiveUser); 250 251 //===--------------------------------------------------------------------===// 252 // Conversion 253 //===--------------------------------------------------------------------===// 254 255 /// Attempt to convert the signature of the given block, if successful a new 256 /// block is returned containing the new arguments. Returns `block` if it did 257 /// not require conversion. 258 FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter, 259 ConversionValueMapping &mapping); 260 261 /// Apply the given signature conversion on the given block. The new block 262 /// containing the updated signature is returned. If no conversions were 263 /// necessary, e.g. if the block has no arguments, `block` is returned. 264 /// `converter` is used to generate any necessary cast operations that 265 /// translate between the origin argument types and those specified in the 266 /// signature conversion. 267 Block *applySignatureConversion( 268 Block *block, TypeConverter &converter, 269 TypeConverter::SignatureConversion &signatureConversion, 270 ConversionValueMapping &mapping); 271 272 /// Insert a new conversion into the cache. 273 void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); 274 275 /// A collection of blocks that have had their arguments converted. This is a 276 /// map from the new replacement block, back to the original block. 277 llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; 278 279 /// The set of original blocks that were converted. 280 DenseSet<Block *> convertedBlocks; 281 282 /// A mapping from valid regions, to those containing the original blocks of a 283 /// conversion. 
284 DenseMap<Region *, std::unique_ptr<Region>> regionMapping; 285 286 /// A mapping of regions to type converters that should be used when 287 /// converting the arguments of blocks within that region. 288 DenseMap<Region *, TypeConverter *> regionToConverter; 289 290 /// The pattern rewriter to use when materializing conversions. 291 PatternRewriter &rewriter; 292 }; 293 } // end anonymous namespace 294 295 //===----------------------------------------------------------------------===// 296 // Rewrite Application 297 298 void ArgConverter::notifyOpRemoved(Operation *op) { 299 if (conversionInfo.empty()) 300 return; 301 302 for (Region ®ion : op->getRegions()) { 303 for (Block &block : region) { 304 // Drop any rewrites from within. 305 for (Operation &nestedOp : block) 306 if (nestedOp.getNumRegions()) 307 notifyOpRemoved(&nestedOp); 308 309 // Check if this block was converted. 310 auto it = conversionInfo.find(&block); 311 if (it == conversionInfo.end()) 312 continue; 313 314 // Drop all uses of the original arguments and delete the original block. 315 Block *origBlock = it->second.origBlock; 316 for (BlockArgument arg : origBlock->getArguments()) 317 arg.dropAllUses(); 318 conversionInfo.erase(it); 319 } 320 } 321 } 322 323 void ArgConverter::discardRewrites(Block *block) { 324 auto it = conversionInfo.find(block); 325 if (it == conversionInfo.end()) 326 return; 327 Block *origBlock = it->second.origBlock; 328 329 // Drop all uses of the new block arguments and replace uses of the new block. 330 for (int i = block->getNumArguments() - 1; i >= 0; --i) 331 block->getArgument(i).dropAllUses(); 332 block->replaceAllUsesWith(origBlock); 333 334 // Move the operations back the original block and the delete the new block. 
335 origBlock->getOperations().splice(origBlock->end(), block->getOperations()); 336 origBlock->moveBefore(block); 337 block->erase(); 338 339 convertedBlocks.erase(origBlock); 340 conversionInfo.erase(it); 341 } 342 343 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { 344 for (auto &info : conversionInfo) { 345 ConvertedBlockInfo &blockInfo = info.second; 346 Block *origBlock = blockInfo.origBlock; 347 348 // Process the remapping for each of the original arguments. 349 for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 350 Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i]; 351 BlockArgument origArg = origBlock->getArgument(i); 352 353 // Handle the case of a 1->0 value mapping. 354 if (!argInfo) { 355 if (Value newArg = mapping.lookupOrNull(origArg)) 356 origArg.replaceAllUsesWith(newArg); 357 continue; 358 } 359 360 // Otherwise this is a 1->1+ value mapping. 361 Value castValue = argInfo->castValue; 362 assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); 363 364 // If the argument is still used, replace it with the generated cast. 365 if (!origArg.use_empty()) 366 origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue)); 367 } 368 } 369 } 370 371 LogicalResult ArgConverter::materializeLiveConversions( 372 ConversionValueMapping &mapping, OpBuilder &builder, 373 function_ref<Operation *(Value)> findLiveUser) { 374 for (auto &info : conversionInfo) { 375 Block *newBlock = info.first; 376 ConvertedBlockInfo &blockInfo = info.second; 377 Block *origBlock = blockInfo.origBlock; 378 379 // Process the remapping for each of the original arguments. 380 for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 381 // FIXME: We should run the below checks even if the type conversion was 382 // 1->N, but a lot of existing lowering rely on the block argument being 383 // blindly replaced. Those usages should be updated, and this if should be 384 // removed. 
385 if (blockInfo.argInfo[i]) 386 continue; 387 388 // If the type of this argument changed and the argument is still live, we 389 // need to materialize a conversion. 390 BlockArgument origArg = origBlock->getArgument(i); 391 auto argReplacementValue = mapping.lookupOrDefault(origArg); 392 bool isDroppedArg = argReplacementValue == origArg; 393 if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg) 394 continue; 395 Operation *liveUser = findLiveUser(origArg); 396 if (!liveUser) 397 continue; 398 399 if (OpResult result = argReplacementValue.dyn_cast<OpResult>()) 400 rewriter.setInsertionPointAfter(result.getOwner()); 401 else 402 rewriter.setInsertionPointToStart(newBlock); 403 Value newArg = blockInfo.converter->materializeSourceConversion( 404 rewriter, origArg.getLoc(), origArg.getType(), 405 isDroppedArg ? ValueRange() : ValueRange(argReplacementValue)); 406 if (!newArg) { 407 InFlightDiagnostic diag = 408 emitError(origArg.getLoc()) 409 << "failed to materialize conversion for block argument #" << i 410 << " that remained live after conversion, type was " 411 << origArg.getType(); 412 if (!isDroppedArg) 413 diag << ", with target type " << argReplacementValue.getType(); 414 diag.attachNote(liveUser->getLoc()) 415 << "see existing live user here: " << *liveUser; 416 return failure(); 417 } 418 mapping.map(origArg, newArg); 419 } 420 } 421 return success(); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // Conversion 426 427 FailureOr<Block *> 428 ArgConverter::convertSignature(Block *block, TypeConverter &converter, 429 ConversionValueMapping &mapping) { 430 // Check if the block was already converted. If the block is detached, 431 // conservatively assume it is going to be deleted. 432 if (hasBeenConverted(block) || !block->getParent()) 433 return block; 434 435 // Try to convert the signature for the block with the provided converter. 
436 if (auto conversion = converter.convertBlockSignature(block)) 437 return applySignatureConversion(block, converter, *conversion, mapping); 438 return failure(); 439 } 440 441 Block *ArgConverter::applySignatureConversion( 442 Block *block, TypeConverter &converter, 443 TypeConverter::SignatureConversion &signatureConversion, 444 ConversionValueMapping &mapping) { 445 // If no arguments are being changed or added, there is nothing to do. 446 unsigned origArgCount = block->getNumArguments(); 447 auto convertedTypes = signatureConversion.getConvertedTypes(); 448 if (origArgCount == 0 && convertedTypes.empty()) 449 return block; 450 451 // Split the block at the beginning to get a new block to use for the updated 452 // signature. 453 Block *newBlock = block->splitBlock(block->begin()); 454 block->replaceAllUsesWith(newBlock); 455 456 SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes)); 457 ArrayRef<Value> newArgs(newArgRange); 458 459 // Remap each of the original arguments as determined by the signature 460 // conversion. 461 ConvertedBlockInfo info(block, converter); 462 info.argInfo.resize(origArgCount); 463 464 OpBuilder::InsertionGuard guard(rewriter); 465 rewriter.setInsertionPointToStart(newBlock); 466 for (unsigned i = 0; i != origArgCount; ++i) { 467 auto inputMap = signatureConversion.getInputMapping(i); 468 if (!inputMap) 469 continue; 470 BlockArgument origArg = block->getArgument(i); 471 472 // If inputMap->replacementValue is not nullptr, then the argument is 473 // dropped and a replacement value is provided to be the remappedValue. 474 if (inputMap->replacementValue) { 475 assert(inputMap->size == 0 && 476 "invalid to provide a replacement value when the argument isn't " 477 "dropped"); 478 mapping.map(origArg, inputMap->replacementValue); 479 continue; 480 } 481 482 // Otherwise, this is a 1->1+ mapping. Call into the provided type converter 483 // to pack the new values. 
For 1->1 mappings, if there is no materialization 484 // provided, use the argument directly instead. 485 auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); 486 Value newArg = converter.materializeArgumentConversion( 487 rewriter, origArg.getLoc(), origArg.getType(), replArgs); 488 if (!newArg) { 489 assert(replArgs.size() == 1 && 490 "couldn't materialize the result of 1->N conversion"); 491 newArg = replArgs.front(); 492 } 493 mapping.map(origArg, newArg); 494 info.argInfo[i] = 495 ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); 496 } 497 498 // Remove the original block from the region and return the new one. 499 insertConversion(newBlock, std::move(info)); 500 return newBlock; 501 } 502 503 void ArgConverter::insertConversion(Block *newBlock, 504 ConvertedBlockInfo &&info) { 505 // Get a region to insert the old block. 506 Region *region = newBlock->getParent(); 507 std::unique_ptr<Region> &mappedRegion = regionMapping[region]; 508 if (!mappedRegion) 509 mappedRegion = std::make_unique<Region>(region->getParentOp()); 510 511 // Move the original block to the mapped region and emplace the conversion. 512 mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), 513 info.origBlock->getIterator()); 514 convertedBlocks.insert(info.origBlock); 515 conversionInfo.insert({newBlock, std::move(info)}); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // Rewriter and Translation State 520 //===----------------------------------------------------------------------===// 521 namespace { 522 /// This class contains a snapshot of the current conversion rewriter state. 523 /// This is useful when saving and undoing a set of rewrites. 
524 struct RewriterState { 525 RewriterState(unsigned numCreatedOps, unsigned numReplacements, 526 unsigned numArgReplacements, unsigned numBlockActions, 527 unsigned numIgnoredOperations, unsigned numRootUpdates) 528 : numCreatedOps(numCreatedOps), numReplacements(numReplacements), 529 numArgReplacements(numArgReplacements), 530 numBlockActions(numBlockActions), 531 numIgnoredOperations(numIgnoredOperations), 532 numRootUpdates(numRootUpdates) {} 533 534 /// The current number of created operations. 535 unsigned numCreatedOps; 536 537 /// The current number of replacements queued. 538 unsigned numReplacements; 539 540 /// The current number of argument replacements queued. 541 unsigned numArgReplacements; 542 543 /// The current number of block actions performed. 544 unsigned numBlockActions; 545 546 /// The current number of ignored operations. 547 unsigned numIgnoredOperations; 548 549 /// The current number of operations that were updated in place. 550 unsigned numRootUpdates; 551 }; 552 553 /// The state of an operation that was updated by a pattern in-place. This 554 /// contains all of the necessary information to reconstruct an operation that 555 /// was updated in place. 556 class OperationTransactionState { 557 public: 558 OperationTransactionState() = default; 559 OperationTransactionState(Operation *op) 560 : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()), 561 operands(op->operand_begin(), op->operand_end()), 562 successors(op->successor_begin(), op->successor_end()) {} 563 564 /// Discard the transaction state and reset the state of the original 565 /// operation. 566 void resetOperation() const { 567 op->setLoc(loc); 568 op->setAttrs(attrs); 569 op->setOperands(operands); 570 for (auto it : llvm::enumerate(successors)) 571 op->setSuccessor(it.value(), it.index()); 572 } 573 574 /// Return the original operation of this state. 
575 Operation *getOperation() const { return op; } 576 577 private: 578 Operation *op; 579 LocationAttr loc; 580 MutableDictionaryAttr attrs; 581 SmallVector<Value, 8> operands; 582 SmallVector<Block *, 2> successors; 583 }; 584 585 /// This class represents one requested operation replacement via 'replaceOp' or 586 /// 'eraseOp`. 587 struct OpReplacement { 588 OpReplacement() = default; 589 OpReplacement(TypeConverter *converter) : converter(converter) {} 590 591 /// An optional type converter that can be used to materialize conversions 592 /// between the new and old values if necessary. 593 TypeConverter *converter = nullptr; 594 }; 595 596 /// The kind of the block action performed during the rewrite. Actions can be 597 /// undone if the conversion fails. 598 enum class BlockActionKind { 599 Create, 600 Erase, 601 Merge, 602 Move, 603 Split, 604 TypeConversion 605 }; 606 607 /// Original position of the given block in its parent region. During undo 608 /// actions, the block needs to be placed after `insertAfterBlock`. 609 struct BlockPosition { 610 Region *region; 611 Block *insertAfterBlock; 612 }; 613 614 /// Information needed to undo the merge actions. 615 /// - the source block, and 616 /// - the Operation that was the last operation in the dest block before the 617 /// merge (could be null if the dest block was empty). 618 struct MergeInfo { 619 Block *sourceBlock; 620 Operation *destBlockLastInst; 621 }; 622 623 /// The storage class for an undoable block action (one of BlockActionKind), 624 /// contains the information necessary to undo this action. 
625 struct BlockAction { 626 static BlockAction getCreate(Block *block) { 627 return {BlockActionKind::Create, block, {}}; 628 } 629 static BlockAction getErase(Block *block, BlockPosition originalPosition) { 630 return {BlockActionKind::Erase, block, {originalPosition}}; 631 } 632 static BlockAction getMerge(Block *block, Block *sourceBlock) { 633 BlockAction action{BlockActionKind::Merge, block, {}}; 634 action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()}; 635 return action; 636 } 637 static BlockAction getMove(Block *block, BlockPosition originalPosition) { 638 return {BlockActionKind::Move, block, {originalPosition}}; 639 } 640 static BlockAction getSplit(Block *block, Block *originalBlock) { 641 BlockAction action{BlockActionKind::Split, block, {}}; 642 action.originalBlock = originalBlock; 643 return action; 644 } 645 static BlockAction getTypeConversion(Block *block) { 646 return BlockAction{BlockActionKind::TypeConversion, block, {}}; 647 } 648 649 // The action kind. 650 BlockActionKind kind; 651 652 // A pointer to the block that was created by the action. 653 Block *block; 654 655 union { 656 // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and 657 // contains a pointer to the region that originally contained the block as 658 // well as the position of the block in that region. 659 BlockPosition originalPosition; 660 // In use if kind == BlockActionKind::Split and contains a pointer to the 661 // block that was split into two parts. 662 Block *originalBlock; 663 // In use if kind == BlockActionKind::Merge, and contains the information 664 // needed to undo the merge. 
665 MergeInfo mergeInfo; 666 }; 667 }; 668 } // end anonymous namespace 669 670 //===----------------------------------------------------------------------===// 671 // ConversionPatternRewriterImpl 672 //===----------------------------------------------------------------------===// 673 namespace mlir { 674 namespace detail { 675 struct ConversionPatternRewriterImpl { 676 ConversionPatternRewriterImpl(PatternRewriter &rewriter) 677 : argConverter(rewriter) {} 678 679 /// Cleanup and destroy any generated rewrite operations. This method is 680 /// invoked when the conversion process fails. 681 void discardRewrites(); 682 683 /// Apply all requested operation rewrites. This method is invoked when the 684 /// conversion process succeeds. 685 void applyRewrites(); 686 687 //===--------------------------------------------------------------------===// 688 // State Management 689 //===--------------------------------------------------------------------===// 690 691 /// Return the current state of the rewriter. 692 RewriterState getCurrentState(); 693 694 /// Reset the state of the rewriter to a previously saved point. 695 void resetState(RewriterState state); 696 697 /// Erase any blocks that were unlinked from their regions and stored in block 698 /// actions. 699 void eraseDanglingBlocks(); 700 701 /// Undo the block actions (motions, splits) one by one in reverse order until 702 /// "numActionsToKeep" actions remains. 703 void undoBlockActions(unsigned numActionsToKeep = 0); 704 705 /// Remap the given operands to those with potentially different types. The 706 /// provided type converter is used to ensure that the remapped types are 707 /// legal. Returns success if the operands could be remapped, failure 708 /// otherwise. 
709 LogicalResult remapValues(Location loc, PatternRewriter &rewriter, 710 TypeConverter *converter, 711 Operation::operand_range operands, 712 SmallVectorImpl<Value> &remapped); 713 714 /// Returns true if the given operation is ignored, and does not need to be 715 /// converted. 716 bool isOpIgnored(Operation *op) const; 717 718 /// Recursively marks the nested operations under 'op' as ignored. This 719 /// removes them from being considered for legalization. 720 void markNestedOpsIgnored(Operation *op); 721 722 //===--------------------------------------------------------------------===// 723 // Type Conversion 724 //===--------------------------------------------------------------------===// 725 726 /// Convert the signature of the given block. 727 FailureOr<Block *> convertBlockSignature( 728 Block *block, TypeConverter &converter, 729 TypeConverter::SignatureConversion *conversion = nullptr); 730 731 /// Apply a signature conversion on the given region. 732 Block * 733 applySignatureConversion(Region *region, 734 TypeConverter::SignatureConversion &conversion); 735 736 /// Convert the types of block arguments within the given region. 737 FailureOr<Block *> 738 convertRegionTypes(Region *region, TypeConverter &converter, 739 TypeConverter::SignatureConversion *entryConversion); 740 741 //===--------------------------------------------------------------------===// 742 // Rewriter Notification Hooks 743 //===--------------------------------------------------------------------===// 744 745 /// PatternRewriter hook for replacing the results of an operation. 746 void notifyOpReplaced(Operation *op, ValueRange newValues); 747 748 /// Notifies that a block is about to be erased. 749 void notifyBlockIsBeingErased(Block *block); 750 751 /// Notifies that a block was created. 752 void notifyCreatedBlock(Block *block); 753 754 /// Notifies that a block was split. 
755 void notifySplitBlock(Block *block, Block *continuation); 756 757 /// Notifies that `block` is being merged with `srcBlock`. 758 void notifyBlocksBeingMerged(Block *block, Block *srcBlock); 759 760 /// Notifies that the blocks of a region are about to be moved. 761 void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, 762 Region::iterator before); 763 764 /// Notifies that the blocks of a region were cloned into another. 765 void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks, 766 Location origRegionLoc); 767 768 /// Notifies that a pattern match failed for the given reason. 769 LogicalResult 770 notifyMatchFailure(Location loc, 771 function_ref<void(Diagnostic &)> reasonCallback); 772 773 //===--------------------------------------------------------------------===// 774 // State 775 //===--------------------------------------------------------------------===// 776 777 // Mapping between replaced values that differ in type. This happens when 778 // replacing a value with one of a different type. 779 ConversionValueMapping mapping; 780 781 /// Utility used to convert block arguments. 782 ArgConverter argConverter; 783 784 /// Ordered vector of all of the newly created operations during conversion. 785 std::vector<Operation *> createdOps; 786 787 /// Ordered map of requested operation replacements. 788 llvm::MapVector<Operation *, OpReplacement> replacements; 789 790 /// Ordered vector of any requested block argument replacements. 791 SmallVector<BlockArgument, 4> argReplacements; 792 793 /// Ordered list of block operations (creations, splits, motions). 794 SmallVector<BlockAction, 4> blockActions; 795 796 /// A set of operations that should no longer be considered for legalization, 797 /// but were not directly replace/erased/etc. by a pattern. These are 798 /// generally child operations of other operations who were 799 /// replaced/erased/etc. 
This is not meant to be an exhaustive list of all 800 /// operations, but the minimal set that can be used to detect if a given 801 /// operation should be `ignored`. For example, we may add the operations that 802 /// define non-empty regions to the set, but not any of the others. This 803 /// simplifies the amount of memory needed as we can query if the parent 804 /// operation was ignored. 805 llvm::SetVector<Operation *> ignoredOps; 806 807 /// A transaction state for each of operations that were updated in-place. 808 SmallVector<OperationTransactionState, 4> rootUpdates; 809 810 /// A vector of indices into `replacements` of operations that were replaced 811 /// with values with different result types than the original operation, e.g. 812 /// 1->N conversion of some kind. 813 SmallVector<unsigned, 4> operationsWithChangedResults; 814 815 /// A default type converter, used when block conversions do not have one 816 /// explicitly provided. 817 TypeConverter defaultTypeConverter; 818 819 /// The current conversion pattern that is being rewritten, or nullptr if 820 /// called from outside of a conversion pattern rewrite. 821 const ConversionPattern *currentConversionPattern = nullptr; 822 823 #ifndef NDEBUG 824 /// A set of operations that have pending updates. This tracking isn't 825 /// strictly necessary, and is thus only active during debug builds for extra 826 /// verification. 827 SmallPtrSet<Operation *, 1> pendingRootUpdates; 828 829 /// A logger used to emit diagnostics during the conversion process. 830 llvm::ScopedPrinter logger{llvm::dbgs()}; 831 #endif 832 }; 833 } // end namespace detail 834 } // end namespace mlir 835 836 /// Detach any operations nested in the given operation from their parent 837 /// blocks, and erase the given operation. This can be used when the nested 838 /// operations are scheduled for erasure themselves, so deleting the regions of 839 /// the given operation together with their content would result in double-free. 
840 /// This happens, for example, when rolling back op creation in the reverse 841 /// order and if the nested ops were created before the parent op. This function 842 /// does not need to collect nested ops recursively because it is expected to 843 /// also be called for each nested op when it is about to be deleted. 844 static void detachNestedAndErase(Operation *op) { 845 for (Region ®ion : op->getRegions()) { 846 for (Block &block : region.getBlocks()) { 847 while (!block.getOperations().empty()) 848 block.getOperations().remove(block.getOperations().begin()); 849 block.dropAllDefinedValueUses(); 850 } 851 } 852 op->erase(); 853 } 854 855 void ConversionPatternRewriterImpl::discardRewrites() { 856 // Reset any operations that were updated in place. 857 for (auto &state : rootUpdates) 858 state.resetOperation(); 859 860 undoBlockActions(); 861 862 // Remove any newly created ops. 863 for (auto *op : llvm::reverse(createdOps)) 864 detachNestedAndErase(op); 865 } 866 867 void ConversionPatternRewriterImpl::applyRewrites() { 868 // Apply all of the rewrites replacements requested during conversion. 869 for (auto &repl : replacements) { 870 for (OpResult result : repl.first->getResults()) 871 if (Value newValue = mapping.lookupOrNull(result)) 872 result.replaceAllUsesWith(newValue); 873 874 // If this operation defines any regions, drop any pending argument 875 // rewrites. 876 if (repl.first->getNumRegions()) 877 argConverter.notifyOpRemoved(repl.first); 878 } 879 880 // Apply all of the requested argument replacements. 881 for (BlockArgument arg : argReplacements) { 882 Value repl = mapping.lookupOrDefault(arg); 883 if (repl.isa<BlockArgument>()) { 884 arg.replaceAllUsesWith(repl); 885 continue; 886 } 887 888 // If the replacement value is an operation, we check to make sure that we 889 // don't replace uses that are within the parent operation of the 890 // replacement value. 
891 Operation *replOp = repl.cast<OpResult>().getOwner(); 892 Block *replBlock = replOp->getBlock(); 893 arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { 894 Operation *user = operand.getOwner(); 895 return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); 896 }); 897 } 898 899 // In a second pass, erase all of the replaced operations in reverse. This 900 // allows processing nested operations before their parent region is 901 // destroyed. 902 for (auto &repl : llvm::reverse(replacements)) 903 repl.first->erase(); 904 905 argConverter.applyRewrites(mapping); 906 907 // Now that the ops have been erased, also erase dangling blocks. 908 eraseDanglingBlocks(); 909 } 910 911 //===----------------------------------------------------------------------===// 912 // State Management 913 914 RewriterState ConversionPatternRewriterImpl::getCurrentState() { 915 return RewriterState(createdOps.size(), replacements.size(), 916 argReplacements.size(), blockActions.size(), 917 ignoredOps.size(), rootUpdates.size()); 918 } 919 920 void ConversionPatternRewriterImpl::resetState(RewriterState state) { 921 // Reset any operations that were updated in place. 922 for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) 923 rootUpdates[i].resetOperation(); 924 rootUpdates.resize(state.numRootUpdates); 925 926 // Reset any replaced arguments. 927 for (BlockArgument replacedArg : 928 llvm::drop_begin(argReplacements, state.numArgReplacements)) 929 mapping.erase(replacedArg); 930 argReplacements.resize(state.numArgReplacements); 931 932 // Undo any block actions. 933 undoBlockActions(state.numBlockActions); 934 935 // Reset any replaced operations and undo any saved mappings. 
936 for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) 937 for (auto result : repl.first->getResults()) 938 mapping.erase(result); 939 while (replacements.size() != state.numReplacements) 940 replacements.pop_back(); 941 942 // Pop all of the newly created operations. 943 while (createdOps.size() != state.numCreatedOps) { 944 detachNestedAndErase(createdOps.back()); 945 createdOps.pop_back(); 946 } 947 948 // Pop all of the recorded ignored operations that are no longer valid. 949 while (ignoredOps.size() != state.numIgnoredOperations) 950 ignoredOps.pop_back(); 951 952 // Reset operations with changed results. 953 while (!operationsWithChangedResults.empty() && 954 operationsWithChangedResults.back() >= state.numReplacements) 955 operationsWithChangedResults.pop_back(); 956 } 957 958 void ConversionPatternRewriterImpl::eraseDanglingBlocks() { 959 for (auto &action : blockActions) 960 if (action.kind == BlockActionKind::Erase) 961 delete action.block; 962 } 963 964 void ConversionPatternRewriterImpl::undoBlockActions( 965 unsigned numActionsToKeep) { 966 for (auto &action : 967 llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { 968 switch (action.kind) { 969 // Delete the created block. 970 case BlockActionKind::Create: { 971 // Unlink all of the operations within this block, they will be deleted 972 // separately. 973 auto &blockOps = action.block->getOperations(); 974 while (!blockOps.empty()) 975 blockOps.remove(blockOps.begin()); 976 action.block->dropAllDefinedValueUses(); 977 action.block->erase(); 978 break; 979 } 980 // Put the block (owned by action) back into its original position. 981 case BlockActionKind::Erase: { 982 auto &blockList = action.originalPosition.region->getBlocks(); 983 Block *insertAfterBlock = action.originalPosition.insertAfterBlock; 984 blockList.insert((insertAfterBlock 985 ? 
std::next(Region::iterator(insertAfterBlock)) 986 : blockList.begin()), 987 action.block); 988 break; 989 } 990 // Split the block at the position which was originally the end of the 991 // destination block (owned by action), and put the instructions back into 992 // the block used before the merge. 993 case BlockActionKind::Merge: { 994 Block *sourceBlock = action.mergeInfo.sourceBlock; 995 Block::iterator splitPoint = 996 (action.mergeInfo.destBlockLastInst 997 ? ++Block::iterator(action.mergeInfo.destBlockLastInst) 998 : action.block->begin()); 999 sourceBlock->getOperations().splice(sourceBlock->begin(), 1000 action.block->getOperations(), 1001 splitPoint, action.block->end()); 1002 break; 1003 } 1004 // Move the block back to its original position. 1005 case BlockActionKind::Move: { 1006 Region *originalRegion = action.originalPosition.region; 1007 Block *insertAfterBlock = action.originalPosition.insertAfterBlock; 1008 originalRegion->getBlocks().splice( 1009 (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock)) 1010 : originalRegion->end()), 1011 action.block->getParent()->getBlocks(), action.block); 1012 break; 1013 } 1014 // Merge back the block that was split out. 1015 case BlockActionKind::Split: { 1016 action.originalBlock->getOperations().splice( 1017 action.originalBlock->end(), action.block->getOperations()); 1018 action.block->dropAllDefinedValueUses(); 1019 action.block->erase(); 1020 break; 1021 } 1022 // Undo the type conversion. 
1023 case BlockActionKind::TypeConversion: { 1024 argConverter.discardRewrites(action.block); 1025 break; 1026 } 1027 } 1028 } 1029 blockActions.resize(numActionsToKeep); 1030 } 1031 1032 LogicalResult ConversionPatternRewriterImpl::remapValues( 1033 Location loc, PatternRewriter &rewriter, TypeConverter *converter, 1034 Operation::operand_range operands, SmallVectorImpl<Value> &remapped) { 1035 remapped.reserve(llvm::size(operands)); 1036 1037 SmallVector<Type, 1> legalTypes; 1038 for (auto it : llvm::enumerate(operands)) { 1039 Value operand = it.value(); 1040 Type origType = operand.getType(); 1041 1042 // If a converter was provided, get the desired legal types for this 1043 // operand. 1044 Type desiredType; 1045 if (converter) { 1046 // If there is no legal conversion, fail to match this pattern. 1047 legalTypes.clear(); 1048 if (failed(converter->convertType(origType, legalTypes))) { 1049 return notifyMatchFailure(loc, [=](Diagnostic &diag) { 1050 diag << "unable to convert type for operand #" << it.index() 1051 << ", type was " << origType; 1052 }); 1053 } 1054 // TODO: There currently isn't any mechanism to do 1->N type conversion 1055 // via the PatternRewriter replacement API, so for now we just ignore it. 1056 if (legalTypes.size() == 1) 1057 desiredType = legalTypes.front(); 1058 } else { 1059 // TODO: What we should do here is just set `desiredType` to `origType` 1060 // and then handle the necessary type conversions after the conversion 1061 // process has finished. Unfortunately a lot of patterns currently rely on 1062 // receiving the new operands even if the types change, so we keep the 1063 // original behavior here for now until all of the patterns relying on 1064 // this get updated. 1065 } 1066 Value newOperand = mapping.lookupOrDefault(operand, desiredType); 1067 1068 // Handle the case where the conversion was 1->1 and the new operand type 1069 // isn't legal. 
1070 Type newOperandType = newOperand.getType(); 1071 if (converter && desiredType && newOperandType != desiredType) { 1072 // Attempt to materialize a conversion for this new value. 1073 newOperand = converter->materializeTargetConversion( 1074 rewriter, loc, desiredType, newOperand); 1075 if (!newOperand) { 1076 return notifyMatchFailure(loc, [=](Diagnostic &diag) { 1077 diag << "unable to materialize a conversion for " 1078 "operand #" 1079 << it.index() << ", from " << newOperandType << " to " 1080 << desiredType; 1081 }); 1082 } 1083 } 1084 remapped.push_back(newOperand); 1085 } 1086 return success(); 1087 } 1088 1089 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { 1090 // Check to see if this operation was replaced or its parent ignored. 1091 return replacements.count(op) || ignoredOps.count(op->getParentOp()); 1092 } 1093 1094 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { 1095 // Walk this operation and collect nested operations that define non-empty 1096 // regions. We mark such operations as 'ignored' so that we know we don't have 1097 // to convert them, or their nested ops. 1098 if (op->getNumRegions() == 0) 1099 return; 1100 op->walk([&](Operation *op) { 1101 if (llvm::any_of(op->getRegions(), 1102 [](Region ®ion) { return !region.empty(); })) 1103 ignoredOps.insert(op); 1104 }); 1105 } 1106 1107 //===----------------------------------------------------------------------===// 1108 // Type Conversion 1109 1110 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( 1111 Block *block, TypeConverter &converter, 1112 TypeConverter::SignatureConversion *conversion) { 1113 FailureOr<Block *> result = 1114 conversion ? 
argConverter.applySignatureConversion(block, converter, 1115 *conversion, mapping) 1116 : argConverter.convertSignature(block, converter, mapping); 1117 if (Block *newBlock = result.getValue()) { 1118 if (newBlock != block) 1119 blockActions.push_back(BlockAction::getTypeConversion(newBlock)); 1120 } 1121 return result; 1122 } 1123 1124 Block *ConversionPatternRewriterImpl::applySignatureConversion( 1125 Region *region, TypeConverter::SignatureConversion &conversion) { 1126 if (!region->empty()) { 1127 return *convertBlockSignature(®ion->front(), defaultTypeConverter, 1128 &conversion); 1129 } 1130 return nullptr; 1131 } 1132 1133 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( 1134 Region *region, TypeConverter &converter, 1135 TypeConverter::SignatureConversion *entryConversion) { 1136 argConverter.setConverter(region, &converter); 1137 if (region->empty()) 1138 return nullptr; 1139 1140 // Convert the arguments of each block within the region. 1141 FailureOr<Block *> newEntry = 1142 convertBlockSignature(®ion->front(), converter, entryConversion); 1143 for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) 1144 if (failed(convertBlockSignature(&block, converter))) 1145 return failure(); 1146 return newEntry; 1147 } 1148 1149 //===----------------------------------------------------------------------===// 1150 // Rewriter Notification Hooks 1151 1152 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, 1153 ValueRange newValues) { 1154 assert(newValues.size() == op->getNumResults()); 1155 assert(!replacements.count(op) && "operation was already replaced"); 1156 1157 // Track if any of the results changed, e.g. erased and replaced with null. 1158 bool resultChanged = false; 1159 1160 // Create mappings for each of the new result values. 
1161 Value newValue, result; 1162 for (auto it : llvm::zip(newValues, op->getResults())) { 1163 std::tie(newValue, result) = it; 1164 if (!newValue) { 1165 resultChanged = true; 1166 continue; 1167 } 1168 // Remap, and check for any result type changes. 1169 mapping.map(result, newValue); 1170 resultChanged |= (newValue.getType() != result.getType()); 1171 } 1172 if (resultChanged) 1173 operationsWithChangedResults.push_back(replacements.size()); 1174 1175 // Record the requested operation replacement. 1176 TypeConverter *converter = nullptr; 1177 if (currentConversionPattern) 1178 converter = currentConversionPattern->getTypeConverter(); 1179 replacements.insert(std::make_pair(op, OpReplacement(converter))); 1180 1181 // Mark this operation as recursively ignored so that we don't need to 1182 // convert any nested operations. 1183 markNestedOpsIgnored(op); 1184 } 1185 1186 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { 1187 Region *region = block->getParent(); 1188 Block *origPrevBlock = block->getPrevNode(); 1189 blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock})); 1190 } 1191 1192 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) { 1193 blockActions.push_back(BlockAction::getCreate(block)); 1194 } 1195 1196 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, 1197 Block *continuation) { 1198 blockActions.push_back(BlockAction::getSplit(continuation, block)); 1199 } 1200 1201 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block, 1202 Block *srcBlock) { 1203 blockActions.push_back(BlockAction::getMerge(block, srcBlock)); 1204 } 1205 1206 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( 1207 Region ®ion, Region &parent, Region::iterator before) { 1208 if (region.empty()) 1209 return; 1210 Block *laterBlock = ®ion.back(); 1211 for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) { 1212 blockActions.push_back( 1213 
BlockAction::getMove(laterBlock, {®ion, &earlierBlock})); 1214 laterBlock = &earlierBlock; 1215 } 1216 blockActions.push_back(BlockAction::getMove(laterBlock, {®ion, nullptr})); 1217 } 1218 1219 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( 1220 iterator_range<Region::iterator> &blocks, Location origRegionLoc) { 1221 for (Block &block : blocks) 1222 blockActions.push_back(BlockAction::getCreate(&block)); 1223 1224 // Compute the conversion set for the inlined region. 1225 auto result = computeConversionSet(blocks, origRegionLoc, createdOps); 1226 1227 // This original region has already had its conversion set computed, so there 1228 // shouldn't be any new failures. 1229 (void)result; 1230 assert(succeeded(result) && "expected region to have no unreachable blocks"); 1231 } 1232 1233 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure( 1234 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1235 LLVM_DEBUG({ 1236 Diagnostic diag(loc, DiagnosticSeverity::Remark); 1237 reasonCallback(diag); 1238 logger.startLine() << "** Failure : " << diag.str() << "\n"; 1239 }); 1240 return failure(); 1241 } 1242 1243 //===----------------------------------------------------------------------===// 1244 // ConversionPatternRewriter 1245 //===----------------------------------------------------------------------===// 1246 1247 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) 1248 : PatternRewriter(ctx), 1249 impl(new detail::ConversionPatternRewriterImpl(*this)) {} 1250 ConversionPatternRewriter::~ConversionPatternRewriter() {} 1251 1252 /// PatternRewriter hook for replacing the results of an operation. 
1253 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 1254 LLVM_DEBUG({ 1255 impl->logger.startLine() 1256 << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1257 }); 1258 impl->notifyOpReplaced(op, newValues); 1259 } 1260 1261 /// PatternRewriter hook for erasing a dead operation. The uses of this 1262 /// operation *must* be made dead by the end of the conversion process, 1263 /// otherwise an assert will be issued. 1264 void ConversionPatternRewriter::eraseOp(Operation *op) { 1265 LLVM_DEBUG({ 1266 impl->logger.startLine() 1267 << "** Erase : '" << op->getName() << "'(" << op << ")\n"; 1268 }); 1269 SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); 1270 impl->notifyOpReplaced(op, nullRepls); 1271 } 1272 1273 void ConversionPatternRewriter::eraseBlock(Block *block) { 1274 impl->notifyBlockIsBeingErased(block); 1275 1276 // Mark all ops for erasure. 1277 for (Operation &op : *block) 1278 eraseOp(&op); 1279 1280 // Unlink the block from its parent region. The block is kept in the block 1281 // action and will be actually destroyed when rewrites are applied. This 1282 // allows us to keep the operations in the block live and undo the removal by 1283 // re-inserting the block. 
1284 block->getParent()->getBlocks().remove(block); 1285 } 1286 1287 Block *ConversionPatternRewriter::applySignatureConversion( 1288 Region *region, TypeConverter::SignatureConversion &conversion) { 1289 return impl->applySignatureConversion(region, conversion); 1290 } 1291 1292 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( 1293 Region *region, TypeConverter &converter, 1294 TypeConverter::SignatureConversion *entryConversion) { 1295 return impl->convertRegionTypes(region, converter, entryConversion); 1296 } 1297 1298 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, 1299 Value to) { 1300 LLVM_DEBUG({ 1301 Operation *parentOp = from.getOwner()->getParentOp(); 1302 impl->logger.startLine() << "** Replace Argument : '" << from 1303 << "'(in region of '" << parentOp->getName() 1304 << "'(" << from.getOwner()->getParentOp() << ")\n"; 1305 }); 1306 impl->argReplacements.push_back(from); 1307 impl->mapping.map(impl->mapping.lookupOrDefault(from), to); 1308 } 1309 1310 /// Return the converted value that replaces 'key'. Return 'key' if there is 1311 /// no such a converted value. 1312 Value ConversionPatternRewriter::getRemappedValue(Value key) { 1313 return impl->mapping.lookupOrDefault(key); 1314 } 1315 1316 /// PatternRewriter hook for creating a new block with the given arguments. 1317 void ConversionPatternRewriter::notifyBlockCreated(Block *block) { 1318 impl->notifyCreatedBlock(block); 1319 } 1320 1321 /// PatternRewriter hook for splitting a block into two parts. 1322 Block *ConversionPatternRewriter::splitBlock(Block *block, 1323 Block::iterator before) { 1324 auto *continuation = PatternRewriter::splitBlock(block, before); 1325 impl->notifySplitBlock(block, continuation); 1326 return continuation; 1327 } 1328 1329 /// PatternRewriter hook for merging a block into another. 
1330 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, 1331 ValueRange argValues) { 1332 impl->notifyBlocksBeingMerged(dest, source); 1333 assert(llvm::all_of(source->getPredecessors(), 1334 [dest](Block *succ) { return succ == dest; }) && 1335 "expected 'source' to have no predecessors or only 'dest'"); 1336 assert(argValues.size() == source->getNumArguments() && 1337 "incorrect # of argument replacement values"); 1338 for (auto it : llvm::zip(source->getArguments(), argValues)) 1339 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); 1340 dest->getOperations().splice(dest->end(), source->getOperations()); 1341 eraseBlock(source); 1342 } 1343 1344 /// PatternRewriter hook for moving blocks out of a region. 1345 void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, 1346 Region &parent, 1347 Region::iterator before) { 1348 impl->notifyRegionIsBeingInlinedBefore(region, parent, before); 1349 PatternRewriter::inlineRegionBefore(region, parent, before); 1350 } 1351 1352 /// PatternRewriter hook for cloning blocks of one region into another. 1353 void ConversionPatternRewriter::cloneRegionBefore( 1354 Region ®ion, Region &parent, Region::iterator before, 1355 BlockAndValueMapping &mapping) { 1356 if (region.empty()) 1357 return; 1358 PatternRewriter::cloneRegionBefore(region, parent, before, mapping); 1359 1360 // Collect the range of the cloned blocks. 1361 auto clonedBeginIt = mapping.lookup(®ion.front())->getIterator(); 1362 auto clonedBlocks = llvm::make_range(clonedBeginIt, before); 1363 impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc()); 1364 } 1365 1366 /// PatternRewriter hook for creating a new operation. 
1367 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) { 1368 LLVM_DEBUG({ 1369 impl->logger.startLine() 1370 << "** Insert : '" << op->getName() << "'(" << op << ")\n"; 1371 }); 1372 impl->createdOps.push_back(op); 1373 } 1374 1375 /// PatternRewriter hook for updating the root operation in-place. 1376 void ConversionPatternRewriter::startRootUpdate(Operation *op) { 1377 #ifndef NDEBUG 1378 impl->pendingRootUpdates.insert(op); 1379 #endif 1380 impl->rootUpdates.emplace_back(op); 1381 } 1382 1383 /// PatternRewriter hook for updating the root operation in-place. 1384 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { 1385 // There is nothing to do here, we only need to track the operation at the 1386 // start of the update. 1387 #ifndef NDEBUG 1388 assert(impl->pendingRootUpdates.erase(op) && 1389 "operation did not have a pending in-place update"); 1390 #endif 1391 } 1392 1393 /// PatternRewriter hook for updating the root operation in-place. 1394 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { 1395 #ifndef NDEBUG 1396 assert(impl->pendingRootUpdates.erase(op) && 1397 "operation did not have a pending in-place update"); 1398 #endif 1399 // Erase the last update for this operation. 1400 auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; 1401 auto &rootUpdates = impl->rootUpdates; 1402 auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); 1403 rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it)); 1404 } 1405 1406 /// PatternRewriter hook for notifying match failure reasons. 1407 LogicalResult ConversionPatternRewriter::notifyMatchFailure( 1408 Operation *op, function_ref<void(Diagnostic &)> reasonCallback) { 1409 return impl->notifyMatchFailure(op->getLoc(), reasonCallback); 1410 } 1411 1412 /// Return a reference to the internal implementation. 
1413 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 1414 return *impl; 1415 } 1416 1417 //===----------------------------------------------------------------------===// 1418 // ConversionPattern 1419 //===----------------------------------------------------------------------===// 1420 1421 /// Attempt to match and rewrite the IR root at the specified operation. 1422 LogicalResult 1423 ConversionPattern::matchAndRewrite(Operation *op, 1424 PatternRewriter &rewriter) const { 1425 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 1426 auto &rewriterImpl = dialectRewriter.getImpl(); 1427 1428 // Track the current conversion pattern in the rewriter. 1429 assert(!rewriterImpl.currentConversionPattern && 1430 "already inside of a pattern rewrite"); 1431 llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard( 1432 rewriterImpl.currentConversionPattern, this); 1433 1434 // Remap the operands of the operation. 1435 SmallVector<Value, 4> operands; 1436 if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter, 1437 getTypeConverter(), op->getOperands(), 1438 operands))) { 1439 return failure(); 1440 } 1441 return matchAndRewrite(op, operands, dialectRewriter); 1442 } 1443 1444 //===----------------------------------------------------------------------===// 1445 // OperationLegalizer 1446 //===----------------------------------------------------------------------===// 1447 1448 namespace { 1449 /// A set of rewrite patterns that can be used to legalize a given operation. 1450 using LegalizationPatterns = SmallVector<const Pattern *, 1>; 1451 1452 /// This class defines a recursive operation legalizer. 1453 class OperationLegalizer { 1454 public: 1455 using LegalizationAction = ConversionTarget::LegalizationAction; 1456 1457 OperationLegalizer(ConversionTarget &targetInfo, 1458 const FrozenRewritePatternList &patterns); 1459 1460 /// Returns true if the given operation is known to be illegal on the target. 
1461 bool isIllegal(Operation *op) const; 1462 1463 /// Attempt to legalize the given operation. Returns success if the operation 1464 /// was legalized, failure otherwise. 1465 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 1466 1467 /// Returns the conversion target in use by the legalizer. 1468 ConversionTarget &getTarget() { return target; } 1469 1470 private: 1471 /// Attempt to legalize the given operation by folding it. 1472 LogicalResult legalizeWithFold(Operation *op, 1473 ConversionPatternRewriter &rewriter); 1474 1475 /// Attempt to legalize the given operation by applying a pattern. Returns 1476 /// success if the operation was legalized, failure otherwise. 1477 LogicalResult legalizeWithPattern(Operation *op, 1478 ConversionPatternRewriter &rewriter); 1479 1480 /// Return true if the given pattern may be applied to the given operation, 1481 /// false otherwise. 1482 bool canApplyPattern(Operation *op, const Pattern &pattern, 1483 ConversionPatternRewriter &rewriter); 1484 1485 /// Legalize the resultant IR after successfully applying the given pattern. 1486 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, 1487 ConversionPatternRewriter &rewriter, 1488 RewriterState &curState); 1489 1490 /// Legalizes the actions registered during the execution of a pattern. 
1491 LogicalResult legalizePatternBlockActions(Operation *op, 1492 ConversionPatternRewriter &rewriter, 1493 ConversionPatternRewriterImpl &impl, 1494 RewriterState &state, 1495 RewriterState &newState); 1496 LogicalResult legalizePatternCreatedOperations( 1497 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 1498 RewriterState &state, RewriterState &newState); 1499 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, 1500 ConversionPatternRewriterImpl &impl, 1501 RewriterState &state, 1502 RewriterState &newState); 1503 1504 //===--------------------------------------------------------------------===// 1505 // Cost Model 1506 //===--------------------------------------------------------------------===// 1507 1508 /// Build an optimistic legalization graph given the provided patterns. This 1509 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with 1510 /// patterns for operations that are not directly legal, but may be 1511 /// transitively legal for the current target given the provided patterns. 1512 void buildLegalizationGraph( 1513 LegalizationPatterns &anyOpLegalizerPatterns, 1514 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1515 1516 /// Compute the benefit of each node within the computed legalization graph. 1517 /// This orders the patterns within 'legalizerPatterns' based upon two 1518 /// criteria: 1519 /// 1) Prefer patterns that have the lowest legalization depth, i.e. 1520 /// represent the more direct mapping to the target. 1521 /// 2) When comparing patterns with the same legalization depth, prefer the 1522 /// pattern with the highest PatternBenefit. This allows for users to 1523 /// prefer specific legalizations over others. 
1524 void computeLegalizationGraphBenefit( 1525 LegalizationPatterns &anyOpLegalizerPatterns, 1526 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1527 1528 /// Compute the legalization depth when legalizing an operation of the given 1529 /// type. 1530 unsigned computeOpLegalizationDepth( 1531 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1532 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1533 1534 /// Apply the conversion cost model to the given set of patterns, and return 1535 /// the smallest legalization depth of any of the patterns. See 1536 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. 1537 unsigned applyCostModelToPatterns( 1538 LegalizationPatterns &patterns, 1539 DenseMap<OperationName, unsigned> &minOpPatternDepth, 1540 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1541 1542 /// The current set of patterns that have been applied. 1543 SmallPtrSet<const Pattern *, 8> appliedPatterns; 1544 1545 /// The legalization information provided by the target. 1546 ConversionTarget ⌖ 1547 1548 /// The pattern applicator to use for conversions. 1549 PatternApplicator applicator; 1550 }; 1551 } // namespace 1552 1553 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, 1554 const FrozenRewritePatternList &patterns) 1555 : target(targetInfo), applicator(patterns) { 1556 // The set of patterns that can be applied to illegal operations to transform 1557 // them into legal ones. 1558 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 1559 LegalizationPatterns anyOpLegalizerPatterns; 1560 1561 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); 1562 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); 1563 } 1564 1565 bool OperationLegalizer::isIllegal(Operation *op) const { 1566 // Check if the target explicitly marked this operation as illegal. 
1567 return target.getOpAction(op->getName()) == LegalizationAction::Illegal; 1568 } 1569 1570 LogicalResult 1571 OperationLegalizer::legalize(Operation *op, 1572 ConversionPatternRewriter &rewriter) { 1573 #ifndef NDEBUG 1574 const char *logLineComment = 1575 "//===-------------------------------------------===//\n"; 1576 1577 auto &rewriterImpl = rewriter.getImpl(); 1578 #endif 1579 LLVM_DEBUG({ 1580 auto &os = rewriterImpl.logger; 1581 os.getOStream() << "\n"; 1582 os.startLine() << logLineComment; 1583 os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op 1584 << ") {\n"; 1585 os.indent(); 1586 1587 // If the operation has no regions, just print it here. 1588 if (op->getNumRegions() == 0) { 1589 op->print(os.startLine(), OpPrintingFlags().printGenericOpForm()); 1590 os.getOStream() << "\n\n"; 1591 } 1592 }); 1593 1594 // Check if this operation is legal on the target. 1595 if (auto legalityInfo = target.isLegal(op)) { 1596 LLVM_DEBUG({ 1597 logSuccess( 1598 rewriterImpl.logger, "operation marked legal by the target{0}", 1599 legalityInfo->isRecursivelyLegal 1600 ? "; NOTE: operation is recursively legal; skipping internals" 1601 : ""); 1602 rewriterImpl.logger.startLine() << logLineComment; 1603 }); 1604 1605 // If this operation is recursively legal, mark its children as ignored so 1606 // that we don't consider them for legalization. 1607 if (legalityInfo->isRecursivelyLegal) 1608 rewriter.getImpl().markNestedOpsIgnored(op); 1609 return success(); 1610 } 1611 1612 // Check to see if the operation is ignored and doesn't need to be converted. 1613 if (rewriter.getImpl().isOpIgnored(op)) { 1614 LLVM_DEBUG({ 1615 logSuccess(rewriterImpl.logger, 1616 "operation marked 'ignored' during conversion"); 1617 rewriterImpl.logger.startLine() << logLineComment; 1618 }); 1619 return success(); 1620 } 1621 1622 // If the operation isn't legal, try to fold it in-place. 
1623 // TODO: Should we always try to do this, even if the op is 1624 // already legal? 1625 if (succeeded(legalizeWithFold(op, rewriter))) { 1626 LLVM_DEBUG({ 1627 logSuccess(rewriterImpl.logger, "operation was folded"); 1628 rewriterImpl.logger.startLine() << logLineComment; 1629 }); 1630 return success(); 1631 } 1632 1633 // Otherwise, we need to apply a legalization pattern to this operation. 1634 if (succeeded(legalizeWithPattern(op, rewriter))) { 1635 LLVM_DEBUG({ 1636 logSuccess(rewriterImpl.logger, ""); 1637 rewriterImpl.logger.startLine() << logLineComment; 1638 }); 1639 return success(); 1640 } 1641 1642 LLVM_DEBUG({ 1643 logFailure(rewriterImpl.logger, "no matched legalization pattern"); 1644 rewriterImpl.logger.startLine() << logLineComment; 1645 }); 1646 return failure(); 1647 } 1648 1649 LogicalResult 1650 OperationLegalizer::legalizeWithFold(Operation *op, 1651 ConversionPatternRewriter &rewriter) { 1652 auto &rewriterImpl = rewriter.getImpl(); 1653 RewriterState curState = rewriterImpl.getCurrentState(); 1654 1655 LLVM_DEBUG({ 1656 rewriterImpl.logger.startLine() << "* Fold {\n"; 1657 rewriterImpl.logger.indent(); 1658 }); 1659 1660 // Try to fold the operation. 1661 SmallVector<Value, 2> replacementValues; 1662 rewriter.setInsertionPoint(op); 1663 if (failed(rewriter.tryFold(op, replacementValues))) { 1664 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); 1665 return failure(); 1666 } 1667 1668 // Insert a replacement for 'op' with the folded replacement values. 1669 rewriter.replaceOp(op, replacementValues); 1670 1671 // Recursively legalize any new constant operations. 
1672 for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); 1673 i != e; ++i) { 1674 Operation *cstOp = rewriterImpl.createdOps[i]; 1675 if (failed(legalize(cstOp, rewriter))) { 1676 LLVM_DEBUG(logFailure(rewriterImpl.logger, 1677 "generated constant '{0}' was illegal", 1678 cstOp->getName())); 1679 rewriterImpl.resetState(curState); 1680 return failure(); 1681 } 1682 } 1683 1684 LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); 1685 return success(); 1686 } 1687 1688 LogicalResult 1689 OperationLegalizer::legalizeWithPattern(Operation *op, 1690 ConversionPatternRewriter &rewriter) { 1691 auto &rewriterImpl = rewriter.getImpl(); 1692 1693 // Functor that returns if the given pattern may be applied. 1694 auto canApply = [&](const Pattern &pattern) { 1695 return canApplyPattern(op, pattern, rewriter); 1696 }; 1697 1698 // Functor that cleans up the rewriter state after a pattern failed to match. 1699 RewriterState curState = rewriterImpl.getCurrentState(); 1700 auto onFailure = [&](const Pattern &pattern) { 1701 LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); 1702 rewriterImpl.resetState(curState); 1703 appliedPatterns.erase(&pattern); 1704 }; 1705 1706 // Functor that performs additional legalization when a pattern is 1707 // successfully applied. 1708 auto onSuccess = [&](const Pattern &pattern) { 1709 auto result = legalizePatternResult(op, pattern, rewriter, curState); 1710 appliedPatterns.erase(&pattern); 1711 if (failed(result)) 1712 rewriterImpl.resetState(curState); 1713 return result; 1714 }; 1715 1716 // Try to match and rewrite a pattern on this operation. 
1717 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, 1718 onSuccess); 1719 } 1720 1721 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, 1722 ConversionPatternRewriter &rewriter) { 1723 LLVM_DEBUG({ 1724 auto &os = rewriter.getImpl().logger; 1725 os.getOStream() << "\n"; 1726 os.startLine() << "* Pattern : '" << op->getName() << " -> ("; 1727 llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs()); 1728 os.getOStream() << ")' {\n"; 1729 os.indent(); 1730 }); 1731 1732 // Ensure that we don't cycle by not allowing the same pattern to be 1733 // applied twice in the same recursion stack if it is not known to be safe. 1734 if (!pattern.hasBoundedRewriteRecursion() && 1735 !appliedPatterns.insert(&pattern).second) { 1736 LLVM_DEBUG( 1737 logFailure(rewriter.getImpl().logger, "pattern was already applied")); 1738 return false; 1739 } 1740 return true; 1741 } 1742 1743 LogicalResult 1744 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, 1745 ConversionPatternRewriter &rewriter, 1746 RewriterState &curState) { 1747 auto &impl = rewriter.getImpl(); 1748 1749 #ifndef NDEBUG 1750 assert(impl.pendingRootUpdates.empty() && "dangling root updates"); 1751 #endif 1752 1753 // Check that the root was either replaced or updated in place. 1754 auto replacedRoot = [&] { 1755 return llvm::any_of( 1756 llvm::drop_begin(impl.replacements, curState.numReplacements), 1757 [op](auto &it) { return it.first == op; }); 1758 }; 1759 auto updatedRootInPlace = [&] { 1760 return llvm::any_of( 1761 llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), 1762 [op](auto &state) { return state.getOperation() == op; }); 1763 }; 1764 (void)replacedRoot; 1765 (void)updatedRootInPlace; 1766 assert((replacedRoot() || updatedRootInPlace()) && 1767 "expected pattern to replace the root operation"); 1768 1769 // Legalize each of the actions registered during application. 
RewriterState newState = impl.getCurrentState();
  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
                                         newState)) ||
      failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
      failed(legalizePatternCreatedOperations(rewriter, impl, curState,
                                              newState))) {
    return failure();
  }

  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
  return success();
}

/// Legalize the block actions recorded between `state` and `newState`. For
/// each affected block that has arguments and lives outside of `op`, either
/// convert the block signature directly (when its region has a registered type
/// converter) or re-legalize the block's parent operation. Returns failure if
/// a signature conversion or a re-legalization fails.
LogicalResult OperationLegalizer::legalizePatternBlockActions(
    Operation *op, ConversionPatternRewriter &rewriter,
    ConversionPatternRewriterImpl &impl, RewriterState &state,
    RewriterState &newState) {
  // Operations that must not be re-legalized from here; populated lazily.
  SmallPtrSet<Operation *, 16> operationsToIgnore;

  // If the pattern moved or created any blocks, make sure the types of block
  // arguments get legalized.
  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
       ++i) {
    auto &action = impl.blockActions[i];
    if (action.kind == BlockActionKind::TypeConversion ||
        action.kind == BlockActionKind::Erase)
      continue;
    // Only check blocks outside of the current operation.
    Operation *parentOp = action.block->getParentOp();
    if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
      continue;

    // If the region of the block has a type converter, try to convert the block
    // directly.
    if (auto *converter =
            impl.argConverter.getConverter(action.block->getParent())) {
      if (failed(impl.convertBlockSignature(action.block, *converter))) {
        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                           "block"));
        return failure();
      }
      continue;
    }

    // Otherwise, check that this operation isn't one generated by this pattern.
    // This is because we will attempt to legalize the parent operation, and
    // blocks in regions created by this pattern will already be legalized later
    // on. If we haven't built the set yet, build it now.
    if (operationsToIgnore.empty()) {
      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
                            .drop_front(state.numCreatedOps);
      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
    }

    // If this operation should be considered for re-legalization, try it.
    // The insertion also ensures each parent is attempted at most once.
    if (operationsToIgnore.insert(parentOp).second &&
        failed(legalize(parentOp, rewriter))) {
      LLVM_DEBUG(logFailure(
          impl.logger, "operation '{0}'({1}) became illegal after block action",
          parentOp->getName(), parentOp));
      return failure();
    }
  }
  return success();
}

/// Legalize every operation recorded in `impl.createdOps` between `state` and
/// `newState`. Returns failure as soon as one of them cannot be legalized.
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
    RewriterState &state, RewriterState &newState) {
  // Elements are re-read by index on every iteration rather than through an
  // iterator. NOTE(review): presumably because the recursive `legalize` call
  // may append to `impl.createdOps` -- confirm before switching to a range-for.
  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
    Operation *op = impl.createdOps[i];
    if (failed(legalize(op, rewriter))) {
      LLVM_DEBUG(logFailure(impl.logger,
                            "generated operation '{0}'({1}) was illegal",
                            op->getName(), op));
      return failure();
    }
  }
  return success();
}

/// Legalize every operation that was updated in place between `state` and
/// `newState`. Returns failure as soon as one of them cannot be legalized.
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
    RewriterState &state, RewriterState &newState) {
  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
    Operation *op = impl.rootUpdates[i].getOperation();
    if (failed(legalize(op, rewriter))) {
      LLVM_DEBUG(logFailure(impl.logger,
                            "operation updated in-place '{0}' was illegal",
                            op->getName()));
      return failure();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Cost Model

/// Partition the applicator's patterns: patterns without a specific root kind
/// are appended to `anyOpLegalizerPatterns`, and rooted patterns are recorded
/// per root operation in `legalizerPatterns`.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
  // A mapping between an operation and any currently invalid patterns it has.
  DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
  // A worklist of patterns to consider for legality.
  llvm::SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    Optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    // Keep every rooted pattern without filtering.
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  // Fixed-point iteration: a pattern is accepted once all of the operations it
  // generates are either not illegal on the target or already have an accepted
  // pattern of their own.
  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          Optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operation are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}

/// Compute a legalization depth for each operation in `legalizerPatterns`,
/// apply the cost model to the patterns in `anyOpLegalizerPatterns`, and
/// install a cost model on the pattern applicator derived from the resulting
/// per-operation pattern order.
void OperationLegalizer::computeLegalizationGraphBenefit(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // The smallest pattern depth, when legalizing an operation.
  DenseMap<OperationName, unsigned> minOpPatternDepth;

  // For each operation that is transitively legal, compute a cost for it.
  for (auto &opIt : legalizerPatterns)
    if (!minOpPatternDepth.count(opIt.first))
      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
                                 legalizerPatterns);

  // Apply the cost model to the patterns that can match any operation. Those
  // with a specific operation type are already resolved when computing the op
  // legalization depth.
1953 if (!anyOpLegalizerPatterns.empty()) 1954 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, 1955 legalizerPatterns); 1956 1957 // Apply a cost model to the pattern applicator. We order patterns first by 1958 // depth then benefit. `legalizerPatterns` contains per-op patterns by 1959 // decreasing benefit. 1960 applicator.applyCostModel([&](const Pattern &pattern) { 1961 ArrayRef<const Pattern *> orderedPatternList; 1962 if (Optional<OperationName> rootName = pattern.getRootKind()) 1963 orderedPatternList = legalizerPatterns[*rootName]; 1964 else 1965 orderedPatternList = anyOpLegalizerPatterns; 1966 1967 // If the pattern is not found, then it was removed and cannot be matched. 1968 auto it = llvm::find(orderedPatternList, &pattern); 1969 if (it == orderedPatternList.end()) 1970 return PatternBenefit::impossibleToMatch(); 1971 1972 // Patterns found earlier in the list have higher benefit. 1973 return PatternBenefit(std::distance(it, orderedPatternList.end())); 1974 }); 1975 } 1976 1977 unsigned OperationLegalizer::computeOpLegalizationDepth( 1978 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1979 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 1980 // Check for existing depth. 1981 auto depthIt = minOpPatternDepth.find(op); 1982 if (depthIt != minOpPatternDepth.end()) 1983 return depthIt->second; 1984 1985 // If a mapping for this operation does not exist, then this operation 1986 // is always legal. Return 0 as the depth for a directly legal operation. 1987 auto opPatternsIt = legalizerPatterns.find(op); 1988 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) 1989 return 0u; 1990 1991 // Record this initial depth in case we encounter this op again when 1992 // recursively computing the depth. 1993 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max()); 1994 1995 // Apply the cost model to the operation patterns, and update the minimum 1996 // depth. 
1997 unsigned minDepth = applyCostModelToPatterns( 1998 opPatternsIt->second, minOpPatternDepth, legalizerPatterns); 1999 minOpPatternDepth[op] = minDepth; 2000 return minDepth; 2001 } 2002 2003 unsigned OperationLegalizer::applyCostModelToPatterns( 2004 LegalizationPatterns &patterns, 2005 DenseMap<OperationName, unsigned> &minOpPatternDepth, 2006 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2007 unsigned minDepth = std::numeric_limits<unsigned>::max(); 2008 2009 // Compute the depth for each pattern within the set. 2010 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; 2011 patternsByDepth.reserve(patterns.size()); 2012 for (const Pattern *pattern : patterns) { 2013 unsigned depth = 0; 2014 for (auto generatedOp : pattern->getGeneratedOps()) { 2015 unsigned generatedOpDepth = computeOpLegalizationDepth( 2016 generatedOp, minOpPatternDepth, legalizerPatterns); 2017 depth = std::max(depth, generatedOpDepth + 1); 2018 } 2019 patternsByDepth.emplace_back(pattern, depth); 2020 2021 // Update the minimum depth of the pattern list. 2022 minDepth = std::min(minDepth, depth); 2023 } 2024 2025 // If the operation only has one legalization pattern, there is no need to 2026 // sort them. 2027 if (patternsByDepth.size() == 1) 2028 return minDepth; 2029 2030 // Sort the patterns by those likely to be the most beneficial. 2031 llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(), 2032 [](const std::pair<const Pattern *, unsigned> *lhs, 2033 const std::pair<const Pattern *, unsigned> *rhs) { 2034 // First sort by the smaller pattern legalization 2035 // depth. 2036 if (lhs->second != rhs->second) 2037 return llvm::array_pod_sort_comparator<unsigned>( 2038 &lhs->second, &rhs->second); 2039 2040 // Then sort by the larger pattern benefit. 
2041 auto lhsBenefit = lhs->first->getBenefit(); 2042 auto rhsBenefit = rhs->first->getBenefit(); 2043 return llvm::array_pod_sort_comparator<PatternBenefit>( 2044 &rhsBenefit, &lhsBenefit); 2045 }); 2046 2047 // Update the legalization pattern to use the new sorted list. 2048 patterns.clear(); 2049 for (auto &patternIt : patternsByDepth) 2050 patterns.push_back(patternIt.first); 2051 return minDepth; 2052 } 2053 2054 //===----------------------------------------------------------------------===// 2055 // OperationConverter 2056 //===----------------------------------------------------------------------===// 2057 namespace { 2058 enum OpConversionMode { 2059 // In this mode, the conversion will ignore failed conversions to allow 2060 // illegal operations to co-exist in the IR. 2061 Partial, 2062 2063 // In this mode, all operations must be legal for the given target for the 2064 // conversion to succeed. 2065 Full, 2066 2067 // In this mode, operations are analyzed for legality. No actual rewrites are 2068 // applied to the operations on success. 2069 Analysis, 2070 }; 2071 2072 // This class converts operations to a given conversion target via a set of 2073 // rewrite patterns. The conversion behaves differently depending on the 2074 // conversion mode. 2075 struct OperationConverter { 2076 explicit OperationConverter(ConversionTarget &target, 2077 const FrozenRewritePatternList &patterns, 2078 OpConversionMode mode, 2079 DenseSet<Operation *> *trackedOps = nullptr) 2080 : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} 2081 2082 /// Converts the given operations to the conversion target. 2083 LogicalResult convertOperations(ArrayRef<Operation *> ops); 2084 2085 private: 2086 /// Converts an operation with the given rewriter. 2087 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 2088 2089 /// This method is called after the conversion process to legalize any 2090 /// remaining artifacts and complete the conversion. 
2091 LogicalResult finalize(ConversionPatternRewriter &rewriter); 2092 2093 /// Legalize the types of converted block arguments. 2094 LogicalResult 2095 legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, 2096 ConversionPatternRewriterImpl &rewriterImpl); 2097 2098 /// Legalize an operation result that was marked as "erased". 2099 LogicalResult 2100 legalizeErasedResult(Operation *op, OpResult result, 2101 ConversionPatternRewriterImpl &rewriterImpl); 2102 2103 /// Legalize an operation result that was replaced with a value of a different 2104 /// type. 2105 LogicalResult 2106 legalizeChangedResultType(Operation *op, OpResult result, Value newValue, 2107 TypeConverter *replConverter, 2108 ConversionPatternRewriter &rewriter, 2109 ConversionPatternRewriterImpl &rewriterImpl); 2110 2111 /// The legalizer to use when converting operations. 2112 OperationLegalizer opLegalizer; 2113 2114 /// The conversion mode to use when legalizing operations. 2115 OpConversionMode mode; 2116 2117 /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, 2118 /// this is populated with ops found to be legalizable to the target. 2119 /// When mode == OpConversionMode::Partial, this is populated with ops found 2120 /// *not* to be legalizable to the target. 2121 DenseSet<Operation *> *trackedOps; 2122 }; 2123 } // end anonymous namespace 2124 2125 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 2126 Operation *op) { 2127 // Legalize the given operation. 2128 if (failed(opLegalizer.legalize(op, rewriter))) { 2129 // Handle the case of a failed conversion for each of the different modes. 2130 // Full conversions expect all operations to be converted. 2131 if (mode == OpConversionMode::Full) 2132 return op->emitError() 2133 << "failed to legalize operation '" << op->getName() << "'"; 2134 // Partial conversions allow conversions to fail iff the operation was not 2135 // explicitly marked as illegal. 
If the user provided a nonlegalizableOps 2136 // set, non-legalizable ops are included. 2137 if (mode == OpConversionMode::Partial) { 2138 if (opLegalizer.isIllegal(op)) 2139 return op->emitError() 2140 << "failed to legalize operation '" << op->getName() 2141 << "' that was explicitly marked illegal"; 2142 if (trackedOps) 2143 trackedOps->insert(op); 2144 } 2145 } else if (mode == OpConversionMode::Analysis) { 2146 // Analysis conversions don't fail if any operations fail to legalize, 2147 // they are only interested in the operations that were successfully 2148 // legalized. 2149 trackedOps->insert(op); 2150 } 2151 return success(); 2152 } 2153 2154 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { 2155 if (ops.empty()) 2156 return success(); 2157 ConversionTarget &target = opLegalizer.getTarget(); 2158 2159 // Compute the set of operations and blocks to convert. 2160 std::vector<Operation *> toConvert; 2161 for (auto *op : ops) { 2162 toConvert.emplace_back(op); 2163 for (auto ®ion : op->getRegions()) 2164 if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), 2165 toConvert, &target))) 2166 return failure(); 2167 } 2168 2169 // Convert each operation and discard rewrites on failure. 2170 ConversionPatternRewriter rewriter(ops.front()->getContext()); 2171 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2172 for (auto *op : toConvert) 2173 if (failed(convert(rewriter, op))) 2174 return rewriterImpl.discardRewrites(), failure(); 2175 2176 // Now that all of the operations have been converted, finalize the conversion 2177 // process to ensure any lingering conversion artifacts are cleaned up and 2178 // legalized. 2179 if (failed(finalize(rewriter))) 2180 return rewriterImpl.discardRewrites(), failure(); 2181 2182 // After a successful conversion, apply rewrites if this is not an analysis 2183 // conversion. 
if (mode == OpConversionMode::Analysis)
    rewriterImpl.discardRewrites();
  else
    rewriterImpl.applyRewrites();
  return success();
}

/// Complete the conversion: legalize converted block arguments, then handle
/// every replaced operation whose results were erased (replaced with null) or
/// replaced by a value of a different type. Returns failure if any of these
/// steps fails.
LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();

  // Legalize converted block arguments.
  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
    return failure();

  // Process requested operation replacements.
  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
       i != e; ++i) {
    // Each entry is an index into `rewriterImpl.replacements`.
    unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
    auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
    for (OpResult result : repl.first->getResults()) {
      Value newValue = rewriterImpl.mapping.lookupOrNull(result);

      // If the operation result was replaced with null, all of the uses of this
      // value should be replaced.
      if (!newValue) {
        if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
          return failure();
        continue;
      }

      // Otherwise, check to see if the type of the result changed.
      if (result.getType() == newValue.getType())
        continue;

      // Legalize this result.
      rewriter.setInsertionPoint(repl.first);
      if (failed(legalizeChangedResultType(repl.first, result, newValue,
                                           repl.second.converter, rewriter,
                                           rewriterImpl)))
        return failure();

      // Update the end iterator for this loop in the case it was updated
      // when legalizing generated conversion operations.
      e = rewriterImpl.operationsWithChangedResults.size();
    }
  }
  return success();
}

/// Materialize conversions for converted block arguments that still have live
/// users, then legalize the operations created by that materialization.
LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
    ConversionPatternRewriter &rewriter,
    ConversionPatternRewriterImpl &rewriterImpl) {
  // Functor used to check if all users of a value will be dead after
  // conversion. Returns the first user that is not ignored, or null if there
  // is none.
  auto findLiveUser = [&](Value val) {
    auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
      return rewriterImpl.isOpIgnored(user);
    });
    return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
  };

  // Materialize any necessary conversions for converted block arguments that
  // are still live.
  // Remember how many ops existed beforehand so that only the newly created
  // ones are legalized below.
  size_t numCreatedOps = rewriterImpl.createdOps.size();
  if (failed(rewriterImpl.argConverter.materializeLiveConversions(
          rewriterImpl.mapping, rewriter, findLiveUser)))
    return failure();

  // Legalize any newly created operations during argument materialization.
  for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
    if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
      return rewriterImpl.createdOps[i]->emitError()
             << "failed to legalize conversion operation generated for block "
                "argument that remained live after conversion";
    }
  }
  return success();
}

/// Verify that a result of `op` that was replaced with null has no remaining
/// live (non-ignored) users; emit an error pointing at the first live user
/// otherwise.
LogicalResult OperationConverter::legalizeErasedResult(
    Operation *op, OpResult result,
    ConversionPatternRewriterImpl &rewriterImpl) {
  // If the operation result was replaced with null, all of the uses of this
  // value should be replaced.
2269 auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { 2270 return rewriterImpl.isOpIgnored(user); 2271 }); 2272 if (liveUserIt != result.user_end()) { 2273 InFlightDiagnostic diag = op->emitError("failed to legalize operation '") 2274 << op->getName() << "' marked as erased"; 2275 diag.attachNote(liveUserIt->getLoc()) 2276 << "found live user of result #" << result.getResultNumber() << ": " 2277 << *liveUserIt; 2278 return failure(); 2279 } 2280 return success(); 2281 } 2282 2283 LogicalResult OperationConverter::legalizeChangedResultType( 2284 Operation *op, OpResult result, Value newValue, 2285 TypeConverter *replConverter, ConversionPatternRewriter &rewriter, 2286 ConversionPatternRewriterImpl &rewriterImpl) { 2287 // Walk the users of this value to see if there are any live users that 2288 // weren't replaced during conversion. 2289 auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { 2290 return rewriterImpl.isOpIgnored(user); 2291 }); 2292 if (liveUserIt == result.user_end()) 2293 return success(); 2294 2295 // If the replacement has a type converter, attempt to materialize a 2296 // conversion back to the original type. 2297 if (!replConverter) { 2298 // TODO: We should emit an error here, similarly to the case where the 2299 // result is replaced with null. Unfortunately a lot of existing 2300 // patterns rely on this behavior, so until those patterns are updated 2301 // we keep the legacy behavior here of just forwarding the new value. 2302 return success(); 2303 } 2304 2305 // Track the number of created operations so that new ones can be legalized. 2306 size_t numCreatedOps = rewriterImpl.createdOps.size(); 2307 2308 // Materialize a conversion for this live result value. 
2309 Type resultType = result.getType(); 2310 Value convertedValue = replConverter->materializeSourceConversion( 2311 rewriter, op->getLoc(), resultType, newValue); 2312 if (!convertedValue) { 2313 InFlightDiagnostic diag = op->emitError() 2314 << "failed to materialize conversion for result #" 2315 << result.getResultNumber() << " of operation '" 2316 << op->getName() 2317 << "' that remained live after conversion"; 2318 diag.attachNote(liveUserIt->getLoc()) 2319 << "see existing live user here: " << *liveUserIt; 2320 return failure(); 2321 } 2322 2323 // Legalize all of the newly created conversion operations. 2324 for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) { 2325 if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { 2326 return op->emitError("failed to legalize conversion operation generated ") 2327 << "for result #" << result.getResultNumber() << " of operation '" 2328 << op->getName() << "' that remained live after conversion"; 2329 } 2330 } 2331 2332 rewriterImpl.mapping.map(result, convertedValue); 2333 return success(); 2334 } 2335 2336 //===----------------------------------------------------------------------===// 2337 // Type Conversion 2338 //===----------------------------------------------------------------------===// 2339 2340 /// Remap an input of the original signature with a new set of types. The 2341 /// new types are appended to the new signature conversion. 2342 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 2343 ArrayRef<Type> types) { 2344 assert(!types.empty() && "expected valid types"); 2345 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 2346 addInputs(types); 2347 } 2348 2349 /// Append new input types to the signature conversion, this should only be 2350 /// used if the new types are not intended to remap an existing input. 
2351 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 2352 assert(!types.empty() && 2353 "1->0 type remappings don't need to be added explicitly"); 2354 argTypes.append(types.begin(), types.end()); 2355 } 2356 2357 /// Remap an input of the original signature with a range of types in the 2358 /// new signature. 2359 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2360 unsigned newInputNo, 2361 unsigned newInputCount) { 2362 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2363 assert(newInputCount != 0 && "expected valid input count"); 2364 remappedInputs[origInputNo] = 2365 InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; 2366 } 2367 2368 /// Remap an input of the original signature to another `replacementValue` 2369 /// value. This would make the signature converter drop this argument. 2370 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2371 Value replacementValue) { 2372 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2373 remappedInputs[origInputNo] = 2374 InputMapping{origInputNo, /*size=*/0, replacementValue}; 2375 } 2376 2377 /// This hooks allows for converting a type. 2378 LogicalResult TypeConverter::convertType(Type t, 2379 SmallVectorImpl<Type> &results) { 2380 auto existingIt = cachedDirectConversions.find(t); 2381 if (existingIt != cachedDirectConversions.end()) { 2382 if (existingIt->second) 2383 results.push_back(existingIt->second); 2384 return success(existingIt->second != nullptr); 2385 } 2386 auto multiIt = cachedMultiConversions.find(t); 2387 if (multiIt != cachedMultiConversions.end()) { 2388 results.append(multiIt->second.begin(), multiIt->second.end()); 2389 return success(); 2390 } 2391 2392 // Walk the added converters in reverse order to apply the most recently 2393 // registered first. 
2394 size_t currentCount = results.size(); 2395 for (ConversionCallbackFn &converter : llvm::reverse(conversions)) { 2396 if (Optional<LogicalResult> result = converter(t, results)) { 2397 if (!succeeded(*result)) { 2398 cachedDirectConversions.try_emplace(t, nullptr); 2399 return failure(); 2400 } 2401 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); 2402 if (newTypes.size() == 1) 2403 cachedDirectConversions.try_emplace(t, newTypes.front()); 2404 else 2405 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); 2406 return success(); 2407 } 2408 } 2409 return failure(); 2410 } 2411 2412 /// This hook simplifies defining 1-1 type conversions. This function returns 2413 /// the type to convert to on success, and a null type on failure. 2414 Type TypeConverter::convertType(Type t) { 2415 // Use the multi-type result version to convert the type. 2416 SmallVector<Type, 1> results; 2417 if (failed(convertType(t, results))) 2418 return nullptr; 2419 2420 // Check to ensure that only one type was produced. 2421 return results.size() == 1 ? results.front() : nullptr; 2422 } 2423 2424 /// Convert the given set of types, filling 'results' as necessary. This 2425 /// returns failure if the conversion of any of the types fails, success 2426 /// otherwise. 2427 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types, 2428 SmallVectorImpl<Type> &results) { 2429 for (auto type : types) 2430 if (failed(convertType(type, results))) 2431 return failure(); 2432 return success(); 2433 } 2434 2435 /// Return true if the given type is legal for this type converter, i.e. the 2436 /// type converts to itself. 2437 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; } 2438 /// Return true if the given operation has legal operand and result types. 
2439 bool TypeConverter::isLegal(Operation *op) { 2440 return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); 2441 } 2442 2443 /// Return true if the types of block arguments within the region are legal. 2444 bool TypeConverter::isLegal(Region *region) { 2445 return llvm::all_of(*region, [this](Block &block) { 2446 return isLegal(block.getArgumentTypes()); 2447 }); 2448 } 2449 2450 /// Return true if the inputs and outputs of the given function type are 2451 /// legal. 2452 bool TypeConverter::isSignatureLegal(FunctionType ty) { 2453 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); 2454 } 2455 2456 /// This hook allows for converting a specific argument of a signature. 2457 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 2458 SignatureConversion &result) { 2459 // Try to convert the given input type. 2460 SmallVector<Type, 1> convertedTypes; 2461 if (failed(convertType(type, convertedTypes))) 2462 return failure(); 2463 2464 // If this argument is being dropped, there is nothing left to do. 2465 if (convertedTypes.empty()) 2466 return success(); 2467 2468 // Otherwise, add the new inputs. 
2469 result.addInputs(inputNo, convertedTypes); 2470 return success(); 2471 } 2472 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types, 2473 SignatureConversion &result, 2474 unsigned origInputOffset) { 2475 for (unsigned i = 0, e = types.size(); i != e; ++i) 2476 if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) 2477 return failure(); 2478 return success(); 2479 } 2480 2481 Value TypeConverter::materializeConversion( 2482 MutableArrayRef<MaterializationCallbackFn> materializations, 2483 OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) { 2484 for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) 2485 if (Optional<Value> result = fn(builder, resultType, inputs, loc)) 2486 return result.getValue(); 2487 return nullptr; 2488 } 2489 2490 /// This function converts the type signature of the given block, by invoking 2491 /// 'convertSignatureArg' for each argument. This function should return a valid 2492 /// conversion for the signature on success, None otherwise. 2493 auto TypeConverter::convertBlockSignature(Block *block) 2494 -> Optional<SignatureConversion> { 2495 SignatureConversion conversion(block->getNumArguments()); 2496 if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) 2497 return llvm::None; 2498 return conversion; 2499 } 2500 2501 /// Create a default conversion pattern that rewrites the type signature of a 2502 /// FuncOp. 2503 namespace { 2504 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> { 2505 FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 2506 : OpConversionPattern(converter, ctx) {} 2507 2508 /// Hook for derived classes to implement combined matching and rewriting. 2509 LogicalResult 2510 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 2511 ConversionPatternRewriter &rewriter) const override { 2512 FunctionType type = funcOp.getType(); 2513 2514 // Convert the original function types. 
2515 TypeConverter::SignatureConversion result(type.getNumInputs()); 2516 SmallVector<Type, 1> newResults; 2517 if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || 2518 failed(typeConverter->convertTypes(type.getResults(), newResults)) || 2519 failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter, 2520 &result))) 2521 return failure(); 2522 2523 // Update the function signature in-place. 2524 rewriter.updateRootInPlace(funcOp, [&] { 2525 funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults, 2526 funcOp.getContext())); 2527 }); 2528 return success(); 2529 } 2530 }; 2531 } // end anonymous namespace 2532 2533 void mlir::populateFuncOpTypeConversionPattern( 2534 OwningRewritePatternList &patterns, MLIRContext *ctx, 2535 TypeConverter &converter) { 2536 patterns.insert<FuncOpSignatureConversion>(ctx, converter); 2537 } 2538 2539 //===----------------------------------------------------------------------===// 2540 // ConversionTarget 2541 //===----------------------------------------------------------------------===// 2542 2543 /// Register a legality action for the given operation. 2544 void ConversionTarget::setOpAction(OperationName op, 2545 LegalizationAction action) { 2546 legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None}; 2547 } 2548 2549 /// Register a legality action for the given dialects. 2550 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 2551 LegalizationAction action) { 2552 for (StringRef dialect : dialectNames) 2553 legalDialects[dialect] = action; 2554 } 2555 2556 /// Get the legality action for the given operation. 2557 auto ConversionTarget::getOpAction(OperationName op) const 2558 -> Optional<LegalizationAction> { 2559 Optional<LegalizationInfo> info = getOpInfo(op); 2560 return info ? 
info->action : Optional<LegalizationAction>(); 2561 } 2562 2563 /// If the given operation instance is legal on this target, a structure 2564 /// containing legality information is returned. If the operation is not legal, 2565 /// None is returned. 2566 auto ConversionTarget::isLegal(Operation *op) const 2567 -> Optional<LegalOpDetails> { 2568 Optional<LegalizationInfo> info = getOpInfo(op->getName()); 2569 if (!info) 2570 return llvm::None; 2571 2572 // Returns true if this operation instance is known to be legal. 2573 auto isOpLegal = [&] { 2574 // Handle dynamic legality either with the provided legality function, or 2575 // the default hook on the derived instance. 2576 if (info->action == LegalizationAction::Dynamic) 2577 return info->legalityFn ? (*info->legalityFn)(op) 2578 : isDynamicallyLegal(op); 2579 2580 // Otherwise, the operation is only legal if it was marked 'Legal'. 2581 return info->action == LegalizationAction::Legal; 2582 }; 2583 if (!isOpLegal()) 2584 return llvm::None; 2585 2586 // This operation is legal, compute any additional legality information. 2587 LegalOpDetails legalityDetails; 2588 if (info->isRecursivelyLegal) { 2589 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); 2590 if (legalityFnIt != opRecursiveLegalityFns.end()) 2591 legalityDetails.isRecursivelyLegal = legalityFnIt->second(op); 2592 else 2593 legalityDetails.isRecursivelyLegal = true; 2594 } 2595 return legalityDetails; 2596 } 2597 2598 /// Set the dynamic legality callback for the given operation. 
2599 void ConversionTarget::setLegalityCallback( 2600 OperationName name, const DynamicLegalityCallbackFn &callback) { 2601 assert(callback && "expected valid legality callback"); 2602 auto infoIt = legalOperations.find(name); 2603 assert(infoIt != legalOperations.end() && 2604 infoIt->second.action == LegalizationAction::Dynamic && 2605 "expected operation to already be marked as dynamically legal"); 2606 infoIt->second.legalityFn = callback; 2607 } 2608 2609 /// Set the recursive legality callback for the given operation and mark the 2610 /// operation as recursively legal. 2611 void ConversionTarget::markOpRecursivelyLegal( 2612 OperationName name, const DynamicLegalityCallbackFn &callback) { 2613 auto infoIt = legalOperations.find(name); 2614 assert(infoIt != legalOperations.end() && 2615 infoIt->second.action != LegalizationAction::Illegal && 2616 "expected operation to already be marked as legal"); 2617 infoIt->second.isRecursivelyLegal = true; 2618 if (callback) 2619 opRecursiveLegalityFns[name] = callback; 2620 else 2621 opRecursiveLegalityFns.erase(name); 2622 } 2623 2624 /// Set the dynamic legality callback for the given dialects. 2625 void ConversionTarget::setLegalityCallback( 2626 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 2627 assert(callback && "expected valid legality callback"); 2628 for (StringRef dialect : dialects) 2629 dialectLegalityFns[dialect] = callback; 2630 } 2631 2632 /// Get the legalization information for the given operation. 2633 auto ConversionTarget::getOpInfo(OperationName op) const 2634 -> Optional<LegalizationInfo> { 2635 // Check for info for this specific operation. 2636 auto it = legalOperations.find(op); 2637 if (it != legalOperations.end()) 2638 return it->second; 2639 // Check for info for the parent dialect. 
2640 auto dialectIt = legalDialects.find(op.getDialect()); 2641 if (dialectIt != legalDialects.end()) { 2642 Optional<DynamicLegalityCallbackFn> callback; 2643 auto dialectFn = dialectLegalityFns.find(op.getDialect()); 2644 if (dialectFn != dialectLegalityFns.end()) 2645 callback = dialectFn->second; 2646 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, 2647 callback}; 2648 } 2649 // Otherwise, check if we mark unknown operations as dynamic. 2650 if (unknownOpsDynamicallyLegal) 2651 return LegalizationInfo{LegalizationAction::Dynamic, 2652 /*isRecursivelyLegal=*/false, unknownLegalityFn}; 2653 return llvm::None; 2654 } 2655 2656 //===----------------------------------------------------------------------===// 2657 // Op Conversion Entry Points 2658 //===----------------------------------------------------------------------===// 2659 2660 /// Apply a partial conversion on the given operations and all nested 2661 /// operations. This method converts as many operations to the target as 2662 /// possible, ignoring operations that failed to legalize. This method only 2663 /// returns failure if there ops explicitly marked as illegal. 2664 /// If an `unconvertedOps` set is provided, all operations that are found not 2665 /// to be legalizable to the given `target` are placed within that set. (Note 2666 /// that if there is an op explicitly marked as illegal, the conversion 2667 /// terminates and the `unconvertedOps` set will not necessarily be complete.) 
2668 LogicalResult 2669 mlir::applyPartialConversion(ArrayRef<Operation *> ops, 2670 ConversionTarget &target, 2671 const FrozenRewritePatternList &patterns, 2672 DenseSet<Operation *> *unconvertedOps) { 2673 OperationConverter opConverter(target, patterns, OpConversionMode::Partial, 2674 unconvertedOps); 2675 return opConverter.convertOperations(ops); 2676 } 2677 LogicalResult 2678 mlir::applyPartialConversion(Operation *op, ConversionTarget &target, 2679 const FrozenRewritePatternList &patterns, 2680 DenseSet<Operation *> *unconvertedOps) { 2681 return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, 2682 unconvertedOps); 2683 } 2684 2685 /// Apply a complete conversion on the given operations, and all nested 2686 /// operations. This method will return failure if the conversion of any 2687 /// operation fails. 2688 LogicalResult 2689 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 2690 const FrozenRewritePatternList &patterns) { 2691 OperationConverter opConverter(target, patterns, OpConversionMode::Full); 2692 return opConverter.convertOperations(ops); 2693 } 2694 LogicalResult 2695 mlir::applyFullConversion(Operation *op, ConversionTarget &target, 2696 const FrozenRewritePatternList &patterns) { 2697 return applyFullConversion(llvm::makeArrayRef(op), target, patterns); 2698 } 2699 2700 /// Apply an analysis conversion on the given operations, and all nested 2701 /// operations. This method analyzes which operations would be successfully 2702 /// converted to the target if a conversion was applied. All operations that 2703 /// were found to be legalizable to the given 'target' are placed within the 2704 /// provided 'convertedOps' set; note that no actual rewrites are applied to the 2705 /// operations on success and only pre-existing operations are added to the set. 
2706 LogicalResult 2707 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops, 2708 ConversionTarget &target, 2709 const FrozenRewritePatternList &patterns, 2710 DenseSet<Operation *> &convertedOps) { 2711 OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, 2712 &convertedOps); 2713 return opConverter.convertOperations(ops); 2714 } 2715 LogicalResult 2716 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 2717 const FrozenRewritePatternList &patterns, 2718 DenseSet<Operation *> &convertedOps) { 2719 return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, 2720 convertedOps); 2721 } 2722