1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "mlir/IR/Block.h" 11 #include "mlir/IR/BlockAndValueMapping.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/FunctionInterfaces.h" 15 #include "mlir/Rewrite/PatternApplicator.h" 16 #include "llvm/ADT/ScopeExit.h" 17 #include "llvm/ADT/SetVector.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/FormatVariadic.h" 21 #include "llvm/Support/SaveAndRestore.h" 22 #include "llvm/Support/ScopedPrinter.h" 23 24 using namespace mlir; 25 using namespace mlir::detail; 26 27 #define DEBUG_TYPE "dialect-conversion" 28 29 /// Recursively collect all of the operations to convert from within 'region'. 30 /// If 'target' is nonnull, operations that are recursively legal have their 31 /// regions pre-filtered to avoid considering them for legalization. 32 static LogicalResult 33 computeConversionSet(iterator_range<Region::iterator> region, 34 Location regionLoc, 35 SmallVectorImpl<Operation *> &toConvert, 36 ConversionTarget *target = nullptr) { 37 if (llvm::empty(region)) 38 return success(); 39 40 // Traverse starting from the entry block. 41 SmallVector<Block *, 16> worklist(1, &*region.begin()); 42 DenseSet<Block *> visitedBlocks; 43 visitedBlocks.insert(worklist.front()); 44 while (!worklist.empty()) { 45 Block *block = worklist.pop_back_val(); 46 47 // Compute the conversion set of each of the nested operations. 
48 for (Operation &op : *block) { 49 toConvert.emplace_back(&op); 50 51 // Don't check this operation's children for conversion if the operation 52 // is recursively legal. 53 auto legalityInfo = target ? target->isLegal(&op) 54 : Optional<ConversionTarget::LegalOpDetails>(); 55 if (legalityInfo && legalityInfo->isRecursivelyLegal) 56 continue; 57 for (auto ®ion : op.getRegions()) { 58 if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), 59 toConvert, target))) 60 return failure(); 61 } 62 } 63 64 // Recurse to children that haven't been visited. 65 for (Block *succ : block->getSuccessors()) 66 if (visitedBlocks.insert(succ).second) 67 worklist.push_back(succ); 68 } 69 70 // Check that all blocks in the region were visited. 71 if (llvm::any_of(llvm::drop_begin(region, 1), 72 [&](Block &block) { return !visitedBlocks.count(&block); })) 73 return emitError(regionLoc, "unreachable blocks were not converted"); 74 return success(); 75 } 76 77 /// A utility function to log a successful result for the given reason. 78 template <typename... Args> 79 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 80 LLVM_DEBUG({ 81 os.unindent(); 82 os.startLine() << "} -> SUCCESS"; 83 if (!fmt.empty()) 84 os.getOStream() << " : " 85 << llvm::formatv(fmt.data(), std::forward<Args>(args)...); 86 os.getOStream() << "\n"; 87 }); 88 } 89 90 /// A utility function to log a failure result for the given reason. 91 template <typename... Args> 92 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 93 LLVM_DEBUG({ 94 os.unindent(); 95 os.startLine() << "} -> FAILURE : " 96 << llvm::formatv(fmt.data(), std::forward<Args>(args)...) 
97 << "\n"; 98 }); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // ConversionValueMapping 103 //===----------------------------------------------------------------------===// 104 105 namespace { 106 /// This class wraps a BlockAndValueMapping to provide recursive lookup 107 /// functionality, i.e. we will traverse if the mapped value also has a mapping. 108 struct ConversionValueMapping { 109 /// Lookup a mapped value within the map. If a mapping for the provided value 110 /// does not exist then return the provided value. If `desiredType` is 111 /// non-null, returns the most recently mapped value with that type. If an 112 /// operand of that type does not exist, defaults to normal behavior. 113 Value lookupOrDefault(Value from, Type desiredType = nullptr) const; 114 115 /// Lookup a mapped value within the map, or return null if a mapping does not 116 /// exist. If a mapping exists, this follows the same behavior of 117 /// `lookupOrDefault`. 118 Value lookupOrNull(Value from, Type desiredType = nullptr) const; 119 120 /// Map a value to the one provided. 121 void map(Value oldVal, Value newVal) { 122 LLVM_DEBUG({ 123 for (Value it = newVal; it; it = mapping.lookupOrNull(it)) 124 assert(it != oldVal && "inserting cyclic mapping"); 125 }); 126 mapping.map(oldVal, newVal); 127 } 128 129 /// Try to map a value to the one provided. Returns false if a transitive 130 /// mapping from the new value to the old value already exists, true if the 131 /// map was updated. 132 bool tryMap(Value oldVal, Value newVal); 133 134 /// Drop the last mapping for the given value. 135 void erase(Value value) { mapping.erase(value); } 136 137 /// Returns the inverse raw value mapping (without recursive query support). 
138 DenseMap<Value, SmallVector<Value>> getInverse() const { 139 DenseMap<Value, SmallVector<Value>> inverse; 140 for (auto &it : mapping.getValueMap()) 141 inverse[it.second].push_back(it.first); 142 return inverse; 143 } 144 145 private: 146 /// Current value mappings. 147 BlockAndValueMapping mapping; 148 }; 149 } // namespace 150 151 Value ConversionValueMapping::lookupOrDefault(Value from, 152 Type desiredType) const { 153 // If there was no desired type, simply find the leaf value. 154 if (!desiredType) { 155 // If this value had a valid mapping, unmap that value as well in the case 156 // that it was also replaced. 157 while (auto mappedValue = mapping.lookupOrNull(from)) 158 from = mappedValue; 159 return from; 160 } 161 162 // Otherwise, try to find the deepest value that has the desired type. 163 Value desiredValue; 164 do { 165 if (from.getType() == desiredType) 166 desiredValue = from; 167 168 Value mappedValue = mapping.lookupOrNull(from); 169 if (!mappedValue) 170 break; 171 from = mappedValue; 172 } while (true); 173 174 // If the desired value was found use it, otherwise default to the leaf value. 175 return desiredValue ? desiredValue : from; 176 } 177 178 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { 179 Value result = lookupOrDefault(from, desiredType); 180 if (result == from || (desiredType && result.getType() != desiredType)) 181 return nullptr; 182 return result; 183 } 184 185 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { 186 for (Value it = newVal; it; it = mapping.lookupOrNull(it)) 187 if (it == oldVal) 188 return false; 189 map(oldVal, newVal); 190 return true; 191 } 192 193 //===----------------------------------------------------------------------===// 194 // Rewriter and Translation State 195 //===----------------------------------------------------------------------===// 196 namespace { 197 /// This class contains a snapshot of the current conversion rewriter state. 
/// This is useful when saving and undoing a set of rewrites.
///
/// Each field records how many entries of one kind of rewriter bookkeeping
/// existed at the moment the snapshot was taken.
struct RewriterState {
  RewriterState(unsigned createdOps, unsigned unresolvedMaterializations,
                unsigned replacements, unsigned argReplacements,
                unsigned blockActions, unsigned ignoredOperations,
                unsigned rootUpdates)
      : numCreatedOps{createdOps},
        numUnresolvedMaterializations{unresolvedMaterializations},
        numReplacements{replacements}, numArgReplacements{argReplacements},
        numBlockActions{blockActions},
        numIgnoredOperations{ignoredOperations}, numRootUpdates{rootUpdates} {}

  /// How many operations had been created.
  unsigned numCreatedOps;

  /// How many materializations were still unresolved.
  unsigned numUnresolvedMaterializations;

  /// How many replacements were queued.
  unsigned numReplacements;

  /// How many argument replacements were queued.
  unsigned numArgReplacements;

  /// How many block actions had been performed.
  unsigned numBlockActions;

  /// How many operations were being ignored.
  unsigned numIgnoredOperations;

  /// How many operations had been updated in place.
  unsigned numRootUpdates;
};

//===----------------------------------------------------------------------===//
// OperationTransactionState

/// The state of an operation that was updated by a pattern in-place. This
/// contains all of the necessary information to reconstruct an operation that
/// was updated in place.
240 class OperationTransactionState { 241 public: 242 OperationTransactionState() = default; 243 OperationTransactionState(Operation *op) 244 : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), 245 operands(op->operand_begin(), op->operand_end()), 246 successors(op->successor_begin(), op->successor_end()) {} 247 248 /// Discard the transaction state and reset the state of the original 249 /// operation. 250 void resetOperation() const { 251 op->setLoc(loc); 252 op->setAttrs(attrs); 253 op->setOperands(operands); 254 for (const auto &it : llvm::enumerate(successors)) 255 op->setSuccessor(it.value(), it.index()); 256 } 257 258 /// Return the original operation of this state. 259 Operation *getOperation() const { return op; } 260 261 private: 262 Operation *op; 263 LocationAttr loc; 264 DictionaryAttr attrs; 265 SmallVector<Value, 8> operands; 266 SmallVector<Block *, 2> successors; 267 }; 268 269 //===----------------------------------------------------------------------===// 270 // OpReplacement 271 272 /// This class represents one requested operation replacement via 'replaceOp' or 273 /// 'eraseOp`. 274 struct OpReplacement { 275 OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {} 276 277 /// An optional type converter that can be used to materialize conversions 278 /// between the new and old values if necessary. 279 TypeConverter *converter; 280 }; 281 282 //===----------------------------------------------------------------------===// 283 // BlockAction 284 285 /// The kind of the block action performed during the rewrite. Actions can be 286 /// undone if the conversion fails. 287 enum class BlockActionKind { 288 Create, 289 Erase, 290 Merge, 291 Move, 292 Split, 293 TypeConversion 294 }; 295 296 /// Original position of the given block in its parent region. During undo 297 /// actions, the block needs to be placed after `insertAfterBlock`. 
298 struct BlockPosition { 299 Region *region; 300 Block *insertAfterBlock; 301 }; 302 303 /// Information needed to undo the merge actions. 304 /// - the source block, and 305 /// - the Operation that was the last operation in the dest block before the 306 /// merge (could be null if the dest block was empty). 307 struct MergeInfo { 308 Block *sourceBlock; 309 Operation *destBlockLastInst; 310 }; 311 312 /// The storage class for an undoable block action (one of BlockActionKind), 313 /// contains the information necessary to undo this action. 314 struct BlockAction { 315 static BlockAction getCreate(Block *block) { 316 return {BlockActionKind::Create, block, {}}; 317 } 318 static BlockAction getErase(Block *block, BlockPosition originalPosition) { 319 return {BlockActionKind::Erase, block, {originalPosition}}; 320 } 321 static BlockAction getMerge(Block *block, Block *sourceBlock) { 322 BlockAction action{BlockActionKind::Merge, block, {}}; 323 action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()}; 324 return action; 325 } 326 static BlockAction getMove(Block *block, BlockPosition originalPosition) { 327 return {BlockActionKind::Move, block, {originalPosition}}; 328 } 329 static BlockAction getSplit(Block *block, Block *originalBlock) { 330 BlockAction action{BlockActionKind::Split, block, {}}; 331 action.originalBlock = originalBlock; 332 return action; 333 } 334 static BlockAction getTypeConversion(Block *block) { 335 return BlockAction{BlockActionKind::TypeConversion, block, {}}; 336 } 337 338 // The action kind. 339 BlockActionKind kind; 340 341 // A pointer to the block that was created by the action. 342 Block *block; 343 344 union { 345 // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and 346 // contains a pointer to the region that originally contained the block as 347 // well as the position of the block in that region. 
348 BlockPosition originalPosition; 349 // In use if kind == BlockActionKind::Split and contains a pointer to the 350 // block that was split into two parts. 351 Block *originalBlock; 352 // In use if kind == BlockActionKind::Merge, and contains the information 353 // needed to undo the merge. 354 MergeInfo mergeInfo; 355 }; 356 }; 357 358 //===----------------------------------------------------------------------===// 359 // UnresolvedMaterialization 360 361 /// This class represents an unresolved materialization, i.e. a materialization 362 /// that was inserted during conversion that needs to be legalized at the end of 363 /// the conversion process. 364 class UnresolvedMaterialization { 365 public: 366 /// The type of materialization. 367 enum Kind { 368 /// This materialization materializes a conversion for an illegal block 369 /// argument type, to a legal one. 370 Argument, 371 372 /// This materialization materializes a conversion from an illegal type to a 373 /// legal one. 374 Target 375 }; 376 377 UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, 378 TypeConverter *converter = nullptr, 379 Kind kind = Target, Type origOutputType = nullptr) 380 : op(op), converterAndKind(converter, kind), 381 origOutputType(origOutputType) {} 382 383 /// Return the temporary conversion operation inserted for this 384 /// materialization. 385 UnrealizedConversionCastOp getOp() const { return op; } 386 387 /// Return the type converter of this materialization (which may be null). 388 TypeConverter *getConverter() const { return converterAndKind.getPointer(); } 389 390 /// Return the kind of this materialization. 391 Kind getKind() const { return converterAndKind.getInt(); } 392 393 /// Set the kind of this materialization. 394 void setKind(Kind kind) { converterAndKind.setInt(kind); } 395 396 /// Return the original illegal output type of the input values. 
397 Type getOrigOutputType() const { return origOutputType; } 398 399 private: 400 /// The unresolved materialization operation created during conversion. 401 UnrealizedConversionCastOp op; 402 403 /// The corresponding type converter to use when resolving this 404 /// materialization, and the kind of this materialization. 405 llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind; 406 407 /// The original output type. This is only used for argument conversions. 408 Type origOutputType; 409 }; 410 } // namespace 411 412 /// Build an unresolved materialization operation given an output type and set 413 /// of input operands. 414 static Value buildUnresolvedMaterialization( 415 UnresolvedMaterialization::Kind kind, Block *insertBlock, 416 Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, 417 Type origOutputType, TypeConverter *converter, 418 SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 419 // Avoid materializing an unnecessary cast. 420 if (inputs.size() == 1 && inputs.front().getType() == outputType) 421 return inputs.front(); 422 423 // Create an unresolved materialization. We use a new OpBuilder to avoid 424 // tracking the materialization like we do for other operations. 
425 OpBuilder builder(insertBlock, insertPt); 426 auto convertOp = 427 builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); 428 unresolvedMaterializations.emplace_back(convertOp, converter, kind, 429 origOutputType); 430 return convertOp.getResult(0); 431 } 432 static Value buildUnresolvedArgumentMaterialization( 433 PatternRewriter &rewriter, Location loc, ValueRange inputs, 434 Type origOutputType, Type outputType, TypeConverter *converter, 435 SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 436 return buildUnresolvedMaterialization( 437 UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), 438 rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, 439 converter, unresolvedMaterializations); 440 } 441 static Value buildUnresolvedTargetMaterialization( 442 Location loc, Value input, Type outputType, TypeConverter *converter, 443 SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 444 Block *insertBlock = input.getParentBlock(); 445 Block::iterator insertPt = insertBlock->begin(); 446 if (OpResult inputRes = input.dyn_cast<OpResult>()) 447 insertPt = ++inputRes.getOwner()->getIterator(); 448 449 return buildUnresolvedMaterialization( 450 UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input, 451 outputType, outputType, converter, unresolvedMaterializations); 452 } 453 454 //===----------------------------------------------------------------------===// 455 // ArgConverter 456 //===----------------------------------------------------------------------===// 457 namespace { 458 /// This class provides a simple interface for converting the types of block 459 /// arguments. This is done by creating a new block that contains the new legal 460 /// types and extracting the block that contains the old illegal types to allow 461 /// for undoing pending rewrites in the case of failure. 
462 struct ArgConverter { 463 ArgConverter( 464 PatternRewriter &rewriter, 465 SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) 466 : rewriter(rewriter), 467 unresolvedMaterializations(unresolvedMaterializations) {} 468 469 /// This structure contains the information pertaining to an argument that has 470 /// been converted. 471 struct ConvertedArgInfo { 472 ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, 473 Value castValue = nullptr) 474 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} 475 476 /// The start index of in the new argument list that contains arguments that 477 /// replace the original. 478 unsigned newArgIdx; 479 480 /// The number of arguments that replaced the original argument. 481 unsigned newArgSize; 482 483 /// The cast value that was created to cast from the new arguments to the 484 /// old. This only used if 'newArgSize' > 1. 485 Value castValue; 486 }; 487 488 /// This structure contains information pertaining to a block that has had its 489 /// signature converted. 490 struct ConvertedBlockInfo { 491 ConvertedBlockInfo(Block *origBlock, TypeConverter *converter) 492 : origBlock(origBlock), converter(converter) {} 493 494 /// The original block that was requested to have its signature converted. 495 Block *origBlock; 496 497 /// The conversion information for each of the arguments. The information is 498 /// None if the argument was dropped during conversion. 499 SmallVector<Optional<ConvertedArgInfo>, 1> argInfo; 500 501 /// The type converter used to convert the arguments. 502 TypeConverter *converter; 503 }; 504 505 /// Return if the signature of the given block has already been converted. 506 bool hasBeenConverted(Block *block) const { 507 return conversionInfo.count(block) || convertedBlocks.count(block); 508 } 509 510 /// Set the type converter to use for the given region. 
511 void setConverter(Region *region, TypeConverter *typeConverter) { 512 assert(typeConverter && "expected valid type converter"); 513 regionToConverter[region] = typeConverter; 514 } 515 516 /// Return the type converter to use for the given region, or null if there 517 /// isn't one. 518 TypeConverter *getConverter(Region *region) { 519 return regionToConverter.lookup(region); 520 } 521 522 //===--------------------------------------------------------------------===// 523 // Rewrite Application 524 //===--------------------------------------------------------------------===// 525 526 /// Erase any rewrites registered for the blocks within the given operation 527 /// which is about to be removed. This merely drops the rewrites without 528 /// undoing them. 529 void notifyOpRemoved(Operation *op); 530 531 /// Cleanup and undo any generated conversions for the arguments of block. 532 /// This method replaces the new block with the original, reverting the IR to 533 /// its original state. 534 void discardRewrites(Block *block); 535 536 /// Fully replace uses of the old arguments with the new. 537 void applyRewrites(ConversionValueMapping &mapping); 538 539 /// Materialize any necessary conversions for converted arguments that have 540 /// live users, using the provided `findLiveUser` to search for a user that 541 /// survives the conversion process. 542 LogicalResult 543 materializeLiveConversions(ConversionValueMapping &mapping, 544 OpBuilder &builder, 545 function_ref<Operation *(Value)> findLiveUser); 546 547 //===--------------------------------------------------------------------===// 548 // Conversion 549 //===--------------------------------------------------------------------===// 550 551 /// Attempt to convert the signature of the given block, if successful a new 552 /// block is returned containing the new arguments. Returns `block` if it did 553 /// not require conversion. 
554 FailureOr<Block *> 555 convertSignature(Block *block, TypeConverter *converter, 556 ConversionValueMapping &mapping, 557 SmallVectorImpl<BlockArgument> &argReplacements); 558 559 /// Apply the given signature conversion on the given block. The new block 560 /// containing the updated signature is returned. If no conversions were 561 /// necessary, e.g. if the block has no arguments, `block` is returned. 562 /// `converter` is used to generate any necessary cast operations that 563 /// translate between the origin argument types and those specified in the 564 /// signature conversion. 565 Block *applySignatureConversion( 566 Block *block, TypeConverter *converter, 567 TypeConverter::SignatureConversion &signatureConversion, 568 ConversionValueMapping &mapping, 569 SmallVectorImpl<BlockArgument> &argReplacements); 570 571 /// Insert a new conversion into the cache. 572 void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); 573 574 /// A collection of blocks that have had their arguments converted. This is a 575 /// map from the new replacement block, back to the original block. 576 llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; 577 578 /// The set of original blocks that were converted. 579 DenseSet<Block *> convertedBlocks; 580 581 /// A mapping from valid regions, to those containing the original blocks of a 582 /// conversion. 583 DenseMap<Region *, std::unique_ptr<Region>> regionMapping; 584 585 /// A mapping of regions to type converters that should be used when 586 /// converting the arguments of blocks within that region. 587 DenseMap<Region *, TypeConverter *> regionToConverter; 588 589 /// The pattern rewriter to use when materializing conversions. 590 PatternRewriter &rewriter; 591 592 /// An ordered set of unresolved materializations during conversion. 
593 SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations; 594 }; 595 } // namespace 596 597 //===----------------------------------------------------------------------===// 598 // Rewrite Application 599 600 void ArgConverter::notifyOpRemoved(Operation *op) { 601 if (conversionInfo.empty()) 602 return; 603 604 for (Region ®ion : op->getRegions()) { 605 for (Block &block : region) { 606 // Drop any rewrites from within. 607 for (Operation &nestedOp : block) 608 if (nestedOp.getNumRegions()) 609 notifyOpRemoved(&nestedOp); 610 611 // Check if this block was converted. 612 auto it = conversionInfo.find(&block); 613 if (it == conversionInfo.end()) 614 continue; 615 616 // Drop all uses of the original arguments and delete the original block. 617 Block *origBlock = it->second.origBlock; 618 for (BlockArgument arg : origBlock->getArguments()) 619 arg.dropAllUses(); 620 conversionInfo.erase(it); 621 } 622 } 623 } 624 625 void ArgConverter::discardRewrites(Block *block) { 626 auto it = conversionInfo.find(block); 627 if (it == conversionInfo.end()) 628 return; 629 Block *origBlock = it->second.origBlock; 630 631 // Drop all uses of the new block arguments and replace uses of the new block. 632 for (int i = block->getNumArguments() - 1; i >= 0; --i) 633 block->getArgument(i).dropAllUses(); 634 block->replaceAllUsesWith(origBlock); 635 636 // Move the operations back the original block and the delete the new block. 637 origBlock->getOperations().splice(origBlock->end(), block->getOperations()); 638 origBlock->moveBefore(block); 639 block->erase(); 640 641 convertedBlocks.erase(origBlock); 642 conversionInfo.erase(it); 643 } 644 645 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { 646 for (auto &info : conversionInfo) { 647 ConvertedBlockInfo &blockInfo = info.second; 648 Block *origBlock = blockInfo.origBlock; 649 650 // Process the remapping for each of the original arguments. 
651 for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 652 Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i]; 653 BlockArgument origArg = origBlock->getArgument(i); 654 655 // Handle the case of a 1->0 value mapping. 656 if (!argInfo) { 657 if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType())) 658 origArg.replaceAllUsesWith(newArg); 659 continue; 660 } 661 662 // Otherwise this is a 1->1+ value mapping. 663 Value castValue = argInfo->castValue; 664 assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); 665 666 // If the argument is still used, replace it with the generated cast. 667 if (!origArg.use_empty()) { 668 origArg.replaceAllUsesWith( 669 mapping.lookupOrDefault(castValue, origArg.getType())); 670 } 671 } 672 } 673 } 674 675 LogicalResult ArgConverter::materializeLiveConversions( 676 ConversionValueMapping &mapping, OpBuilder &builder, 677 function_ref<Operation *(Value)> findLiveUser) { 678 for (auto &info : conversionInfo) { 679 Block *newBlock = info.first; 680 ConvertedBlockInfo &blockInfo = info.second; 681 Block *origBlock = blockInfo.origBlock; 682 683 // Process the remapping for each of the original arguments. 684 for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 685 // If the type of this argument changed and the argument is still live, we 686 // need to materialize a conversion. 
687 BlockArgument origArg = origBlock->getArgument(i); 688 if (mapping.lookupOrNull(origArg, origArg.getType())) 689 continue; 690 Operation *liveUser = findLiveUser(origArg); 691 if (!liveUser) 692 continue; 693 694 Value replacementValue = mapping.lookupOrDefault(origArg); 695 bool isDroppedArg = replacementValue == origArg; 696 if (isDroppedArg) 697 rewriter.setInsertionPointToStart(newBlock); 698 else 699 rewriter.setInsertionPointAfterValue(replacementValue); 700 Value newArg; 701 if (blockInfo.converter) { 702 newArg = blockInfo.converter->materializeSourceConversion( 703 rewriter, origArg.getLoc(), origArg.getType(), 704 isDroppedArg ? ValueRange() : ValueRange(replacementValue)); 705 assert((!newArg || newArg.getType() == origArg.getType()) && 706 "materialization hook did not provide a value of the expected " 707 "type"); 708 } 709 if (!newArg) { 710 InFlightDiagnostic diag = 711 emitError(origArg.getLoc()) 712 << "failed to materialize conversion for block argument #" << i 713 << " that remained live after conversion, type was " 714 << origArg.getType(); 715 if (!isDroppedArg) 716 diag << ", with target type " << replacementValue.getType(); 717 diag.attachNote(liveUser->getLoc()) 718 << "see existing live user here: " << *liveUser; 719 return failure(); 720 } 721 mapping.map(origArg, newArg); 722 } 723 } 724 return success(); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // Conversion 729 730 FailureOr<Block *> ArgConverter::convertSignature( 731 Block *block, TypeConverter *converter, ConversionValueMapping &mapping, 732 SmallVectorImpl<BlockArgument> &argReplacements) { 733 // Check if the block was already converted. If the block is detached, 734 // conservatively assume it is going to be deleted. 735 if (hasBeenConverted(block) || !block->getParent()) 736 return block; 737 // If a converter wasn't provided, and the block wasn't already converted, 738 // there is nothing we can do. 
739 if (!converter) 740 return failure(); 741 742 // Try to convert the signature for the block with the provided converter. 743 if (auto conversion = converter->convertBlockSignature(block)) 744 return applySignatureConversion(block, converter, *conversion, mapping, 745 argReplacements); 746 return failure(); 747 } 748 749 Block *ArgConverter::applySignatureConversion( 750 Block *block, TypeConverter *converter, 751 TypeConverter::SignatureConversion &signatureConversion, 752 ConversionValueMapping &mapping, 753 SmallVectorImpl<BlockArgument> &argReplacements) { 754 // If no arguments are being changed or added, there is nothing to do. 755 unsigned origArgCount = block->getNumArguments(); 756 auto convertedTypes = signatureConversion.getConvertedTypes(); 757 if (origArgCount == 0 && convertedTypes.empty()) 758 return block; 759 760 // Split the block at the beginning to get a new block to use for the updated 761 // signature. 762 Block *newBlock = block->splitBlock(block->begin()); 763 block->replaceAllUsesWith(newBlock); 764 765 // FIXME: We should map the new arguments to proper locations. 766 SmallVector<Location> newLocs(convertedTypes.size(), 767 rewriter.getUnknownLoc()); 768 SmallVector<Value, 4> newArgRange( 769 newBlock->addArguments(convertedTypes, newLocs)); 770 ArrayRef<Value> newArgs(newArgRange); 771 772 // Remap each of the original arguments as determined by the signature 773 // conversion. 774 ConvertedBlockInfo info(block, converter); 775 info.argInfo.resize(origArgCount); 776 777 OpBuilder::InsertionGuard guard(rewriter); 778 rewriter.setInsertionPointToStart(newBlock); 779 for (unsigned i = 0; i != origArgCount; ++i) { 780 auto inputMap = signatureConversion.getInputMapping(i); 781 if (!inputMap) 782 continue; 783 BlockArgument origArg = block->getArgument(i); 784 785 // If inputMap->replacementValue is not nullptr, then the argument is 786 // dropped and a replacement value is provided to be the remappedValue. 
787 if (inputMap->replacementValue) { 788 assert(inputMap->size == 0 && 789 "invalid to provide a replacement value when the argument isn't " 790 "dropped"); 791 mapping.map(origArg, inputMap->replacementValue); 792 argReplacements.push_back(origArg); 793 continue; 794 } 795 796 // Otherwise, this is a 1->1+ mapping. 797 auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); 798 Value newArg; 799 800 // If this is a 1->1 mapping and the types of new and replacement arguments 801 // match (i.e. it's an identity map), then the argument is mapped to its 802 // original type. 803 // FIXME: We simply pass through the replacement argument if there wasn't a 804 // converter, which isn't great as it allows implicit type conversions to 805 // appear. We should properly restructure this code to handle cases where a 806 // converter isn't provided and also to properly handle the case where an 807 // argument materialization is actually a temporary source materialization 808 // (e.g. in the case of 1->N). 809 if (replArgs.size() == 1 && 810 (!converter || replArgs[0].getType() == origArg.getType())) { 811 newArg = replArgs.front(); 812 } else { 813 Type origOutputType = origArg.getType(); 814 815 // Legalize the argument output type. 816 Type outputType = origOutputType; 817 if (Type legalOutputType = converter->convertType(outputType)) 818 outputType = legalOutputType; 819 820 newArg = buildUnresolvedArgumentMaterialization( 821 rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, 822 converter, unresolvedMaterializations); 823 } 824 825 mapping.map(origArg, newArg); 826 argReplacements.push_back(origArg); 827 info.argInfo[i] = 828 ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); 829 } 830 831 // Remove the original block from the region and return the new one. 
832 insertConversion(newBlock, std::move(info)); 833 return newBlock; 834 } 835 836 void ArgConverter::insertConversion(Block *newBlock, 837 ConvertedBlockInfo &&info) { 838 // Get a region to insert the old block. 839 Region *region = newBlock->getParent(); 840 std::unique_ptr<Region> &mappedRegion = regionMapping[region]; 841 if (!mappedRegion) 842 mappedRegion = std::make_unique<Region>(region->getParentOp()); 843 844 // Move the original block to the mapped region and emplace the conversion. 845 mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), 846 info.origBlock->getIterator()); 847 convertedBlocks.insert(info.origBlock); 848 conversionInfo.insert({newBlock, std::move(info)}); 849 } 850 851 //===----------------------------------------------------------------------===// 852 // ConversionPatternRewriterImpl 853 //===----------------------------------------------------------------------===// 854 namespace mlir { 855 namespace detail { 856 struct ConversionPatternRewriterImpl { 857 explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) 858 : argConverter(rewriter, unresolvedMaterializations), 859 notifyCallback(nullptr) {} 860 861 /// Cleanup and destroy any generated rewrite operations. This method is 862 /// invoked when the conversion process fails. 863 void discardRewrites(); 864 865 /// Apply all requested operation rewrites. This method is invoked when the 866 /// conversion process succeeds. 867 void applyRewrites(); 868 869 //===--------------------------------------------------------------------===// 870 // State Management 871 //===--------------------------------------------------------------------===// 872 873 /// Return the current state of the rewriter. 874 RewriterState getCurrentState(); 875 876 /// Reset the state of the rewriter to a previously saved point. 877 void resetState(RewriterState state); 878 879 /// Erase any blocks that were unlinked from their regions and stored in block 880 /// actions. 
881 void eraseDanglingBlocks(); 882 883 /// Undo the block actions (motions, splits) one by one in reverse order until 884 /// "numActionsToKeep" actions remains. 885 void undoBlockActions(unsigned numActionsToKeep = 0); 886 887 /// Remap the given values to those with potentially different types. Returns 888 /// success if the values could be remapped, failure otherwise. `valueDiagTag` 889 /// is the tag used when describing a value within a diagnostic, e.g. 890 /// "operand". 891 LogicalResult remapValues(StringRef valueDiagTag, Optional<Location> inputLoc, 892 PatternRewriter &rewriter, ValueRange values, 893 SmallVectorImpl<Value> &remapped); 894 895 /// Returns true if the given operation is ignored, and does not need to be 896 /// converted. 897 bool isOpIgnored(Operation *op) const; 898 899 /// Recursively marks the nested operations under 'op' as ignored. This 900 /// removes them from being considered for legalization. 901 void markNestedOpsIgnored(Operation *op); 902 903 //===--------------------------------------------------------------------===// 904 // Type Conversion 905 //===--------------------------------------------------------------------===// 906 907 /// Convert the signature of the given block. 908 FailureOr<Block *> convertBlockSignature( 909 Block *block, TypeConverter *converter, 910 TypeConverter::SignatureConversion *conversion = nullptr); 911 912 /// Apply a signature conversion on the given region, using `converter` for 913 /// materializations if not null. 914 Block * 915 applySignatureConversion(Region *region, 916 TypeConverter::SignatureConversion &conversion, 917 TypeConverter *converter); 918 919 /// Convert the types of block arguments within the given region. 920 FailureOr<Block *> 921 convertRegionTypes(Region *region, TypeConverter &converter, 922 TypeConverter::SignatureConversion *entryConversion); 923 924 /// Convert the types of non-entry block arguments within the given region. 
925 LogicalResult convertNonEntryRegionTypes( 926 Region *region, TypeConverter &converter, 927 ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); 928 929 //===--------------------------------------------------------------------===// 930 // Rewriter Notification Hooks 931 //===--------------------------------------------------------------------===// 932 933 /// PatternRewriter hook for replacing the results of an operation. 934 void notifyOpReplaced(Operation *op, ValueRange newValues); 935 936 /// Notifies that a block is about to be erased. 937 void notifyBlockIsBeingErased(Block *block); 938 939 /// Notifies that a block was created. 940 void notifyCreatedBlock(Block *block); 941 942 /// Notifies that a block was split. 943 void notifySplitBlock(Block *block, Block *continuation); 944 945 /// Notifies that `block` is being merged with `srcBlock`. 946 void notifyBlocksBeingMerged(Block *block, Block *srcBlock); 947 948 /// Notifies that the blocks of a region are about to be moved. 949 void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, 950 Region::iterator before); 951 952 /// Notifies that the blocks of a region were cloned into another. 953 void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks, 954 Location origRegionLoc); 955 956 /// Notifies that a pattern match failed for the given reason. 957 LogicalResult 958 notifyMatchFailure(Location loc, 959 function_ref<void(Diagnostic &)> reasonCallback); 960 961 //===--------------------------------------------------------------------===// 962 // State 963 //===--------------------------------------------------------------------===// 964 965 // Mapping between replaced values that differ in type. This happens when 966 // replacing a value with one of a different type. 967 ConversionValueMapping mapping; 968 969 /// Utility used to convert block arguments. 
970 ArgConverter argConverter; 971 972 /// Ordered vector of all of the newly created operations during conversion. 973 SmallVector<Operation *> createdOps; 974 975 /// Ordered vector of all unresolved type conversion materializations during 976 /// conversion. 977 SmallVector<UnresolvedMaterialization> unresolvedMaterializations; 978 979 /// Ordered map of requested operation replacements. 980 llvm::MapVector<Operation *, OpReplacement> replacements; 981 982 /// Ordered vector of any requested block argument replacements. 983 SmallVector<BlockArgument, 4> argReplacements; 984 985 /// Ordered list of block operations (creations, splits, motions). 986 SmallVector<BlockAction, 4> blockActions; 987 988 /// A set of operations that should no longer be considered for legalization, 989 /// but were not directly replace/erased/etc. by a pattern. These are 990 /// generally child operations of other operations who were 991 /// replaced/erased/etc. This is not meant to be an exhaustive list of all 992 /// operations, but the minimal set that can be used to detect if a given 993 /// operation should be `ignored`. For example, we may add the operations that 994 /// define non-empty regions to the set, but not any of the others. This 995 /// simplifies the amount of memory needed as we can query if the parent 996 /// operation was ignored. 997 SetVector<Operation *> ignoredOps; 998 999 /// A transaction state for each of operations that were updated in-place. 1000 SmallVector<OperationTransactionState, 4> rootUpdates; 1001 1002 /// A vector of indices into `replacements` of operations that were replaced 1003 /// with values with different result types than the original operation, e.g. 1004 /// 1->N conversion of some kind. 1005 SmallVector<unsigned, 4> operationsWithChangedResults; 1006 1007 /// The current type converter, or nullptr if no type converter is currently 1008 /// active. 
1009 TypeConverter *currentTypeConverter = nullptr; 1010 1011 /// This allows the user to collect the match failure message. 1012 function_ref<void(Diagnostic &)> notifyCallback; 1013 1014 #ifndef NDEBUG 1015 /// A set of operations that have pending updates. This tracking isn't 1016 /// strictly necessary, and is thus only active during debug builds for extra 1017 /// verification. 1018 SmallPtrSet<Operation *, 1> pendingRootUpdates; 1019 1020 /// A logger used to emit diagnostics during the conversion process. 1021 llvm::ScopedPrinter logger{llvm::dbgs()}; 1022 #endif 1023 }; 1024 } // namespace detail 1025 } // namespace mlir 1026 1027 /// Detach any operations nested in the given operation from their parent 1028 /// blocks, and erase the given operation. This can be used when the nested 1029 /// operations are scheduled for erasure themselves, so deleting the regions of 1030 /// the given operation together with their content would result in double-free. 1031 /// This happens, for example, when rolling back op creation in the reverse 1032 /// order and if the nested ops were created before the parent op. This function 1033 /// does not need to collect nested ops recursively because it is expected to 1034 /// also be called for each nested op when it is about to be deleted. 1035 static void detachNestedAndErase(Operation *op) { 1036 for (Region ®ion : op->getRegions()) { 1037 for (Block &block : region.getBlocks()) { 1038 while (!block.getOperations().empty()) 1039 block.getOperations().remove(block.getOperations().begin()); 1040 block.dropAllDefinedValueUses(); 1041 } 1042 } 1043 op->dropAllUses(); 1044 op->erase(); 1045 } 1046 1047 void ConversionPatternRewriterImpl::discardRewrites() { 1048 // Reset any operations that were updated in place. 1049 for (auto &state : rootUpdates) 1050 state.resetOperation(); 1051 1052 undoBlockActions(); 1053 1054 // Remove any newly created ops. 
// Erase unresolved materializations, then created ops in reverse creation order.
  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
    detachNestedAndErase(materialization.getOp());
  for (auto *op : llvm::reverse(createdOps))
    detachNestedAndErase(op);
}

/// Commit the conversion. In order, this:
///   1. redirects uses of each replaced op's results to their mapped values
///      (and tells the argument converter about replaced ops with regions),
///   2. redirects uses of replaced block arguments to their mapped values,
///   3. erases the unresolved materialization ops,
///   4. erases the replaced ops, newest replacement first,
///   5. lets the argument converter apply its own rewrites, and
///   6. deletes blocks that were unlinked by block erasure.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the rewrites replacements requested during conversion.
  for (auto &repl : replacements) {
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result, result.getType()))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    Value repl = mapping.lookupOrNull(arg, arg.getType());
    if (!repl)
      continue;

    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation, we check to make sure that we
    // don't replace uses that are within the parent operation of the
    // replacement value.
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    // Only replace uses that are in another block or that come after the
    // replacement's defining op in the same block.
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // Drop all of the unresolved materialization operations created during
  // conversion.
  for (auto &mat : unresolvedMaterializations) {
    mat.getOp()->dropAllUses();
    mat.getOp()->erase();
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed. Because we process in reverse order, producers may be deleted
  // before their users (a pattern deleting a producer and then the consumer)
  // so we first drop all uses explicitly.
  for (auto &repl : llvm::reverse(replacements)) {
    repl.first->dropAllUses();
    repl.first->erase();
  }

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}

//===----------------------------------------------------------------------===//
// State Management

/// Snapshot the current length of every undo log. `resetState` uses these
/// lengths to roll back everything recorded after the snapshot.
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                       replacements.size(), argReplacements.size(),
                       blockActions.size(), ignoredOps.size(),
                       rootUpdates.size());
}

/// Roll back everything recorded after `state` was taken: in-place updates,
/// argument replacements (and their mappings), block actions, op replacements
/// (and their result mappings), unresolved materializations, created ops,
/// ignored ops, and changed-result indices.
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Reset any operations that were updated in place.
  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
    rootUpdates[i].resetOperation();
  rootUpdates.resize(state.numRootUpdates);

  // Reset any replaced arguments.
  for (BlockArgument replacedArg :
       llvm::drop_begin(argReplacements, state.numArgReplacements))
    mapping.erase(replacedArg);
  argReplacements.resize(state.numArgReplacements);

  // Undo any block actions.
  undoBlockActions(state.numBlockActions);

  // Reset any replaced operations and undo any saved mappings.
  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
    for (auto result : repl.first->getResults())
      mapping.erase(result);
  while (replacements.size() != state.numReplacements)
    replacements.pop_back();

  // Pop all of the newly inserted materializations.
  while (unresolvedMaterializations.size() !=
         state.numUnresolvedMaterializations) {
    UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
    UnrealizedConversionCastOp op = mat.getOp();

    // If this was a target materialization, drop the mapping that was inserted.
    if (mat.getKind() == UnresolvedMaterialization::Target) {
      for (Value input : op->getOperands())
        mapping.erase(input);
    }
    detachNestedAndErase(op);
  }

  // Pop all of the newly created operations.
  while (createdOps.size() != state.numCreatedOps) {
    detachNestedAndErase(createdOps.back());
    createdOps.pop_back();
  }

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Reset operations with changed results.
  while (!operationsWithChangedResults.empty() &&
         operationsWithChangedResults.back() >= state.numReplacements)
    operationsWithChangedResults.pop_back();
}

/// Delete the blocks recorded by Erase actions. Such blocks were unlinked from
/// their regions but kept alive (owned by the action) so the erasure could be
/// undone.
void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
  for (auto &action : blockActions)
    if (action.kind == BlockActionKind::Erase)
      delete action.block;
}

/// Revert, newest first, every block action after the first
/// `numActionsToKeep`, then truncate the action log to that length.
void ConversionPatternRewriterImpl::undoBlockActions(
    unsigned numActionsToKeep) {
  for (auto &action :
       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
    switch (action.kind) {
    // Delete the created block.
    case BlockActionKind::Create: {
      // Unlink all of the operations within this block, they will be deleted
      // separately.
      auto &blockOps = action.block->getOperations();
      while (!blockOps.empty())
        blockOps.remove(blockOps.begin());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Put the block (owned by action) back into its original position.
    // A null insertAfterBlock means the block was first in its region.
    case BlockActionKind::Erase: {
      auto &blockList = action.originalPosition.region->getBlocks();
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      blockList.insert((insertAfterBlock
                            ? std::next(Region::iterator(insertAfterBlock))
                            : blockList.begin()),
                       action.block);
      break;
    }
    // Split the block at the position which was originally the end of the
    // destination block (owned by action), and put the instructions back into
    // the block used before the merge.
    case BlockActionKind::Merge: {
      Block *sourceBlock = action.mergeInfo.sourceBlock;
      Block::iterator splitPoint =
          (action.mergeInfo.destBlockLastInst
               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
               : action.block->begin());
      sourceBlock->getOperations().splice(sourceBlock->begin(),
                                          action.block->getOperations(),
                                          splitPoint, action.block->end());
      break;
    }
    // Move the block back to its original position.
    // NOTE(review): unlike the Erase case, a null insertAfterBlock appends at
    // end() here rather than begin(). This presumably relies on the original
    // region being empty at this point during reverse replay; confirm.
    case BlockActionKind::Move: {
      Region *originalRegion = action.originalPosition.region;
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      originalRegion->getBlocks().splice(
          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
                            : originalRegion->end()),
          action.block->getParent()->getBlocks(), action.block);
      break;
    }
    // Merge back the block that was split out.
    case BlockActionKind::Split: {
      action.originalBlock->getOperations().splice(
          action.originalBlock->end(), action.block->getOperations());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Undo the type conversion.
    case BlockActionKind::TypeConversion: {
      argConverter.discardRewrites(action.block);
      break;
    }
    }
  }
  blockActions.resize(numActionsToKeep);
}

/// Append to `remapped` the current mapped value of each entry of `values`.
/// When a type converter is active and converts a value's original type to
/// exactly one type, that type is requested from the mapping. If the mapped
/// value still has a different type, an unresolved target materialization is
/// inserted and used instead. Returns failure, via notifyMatchFailure, if the
/// active converter cannot convert some original type. `inputLoc`, when set,
/// overrides the location used for diagnostics and materializations.
/// `rewriter` is not used by this body.
LogicalResult ConversionPatternRewriterImpl::remapValues(
    StringRef valueDiagTag, Optional<Location> inputLoc,
    PatternRewriter &rewriter, ValueRange values,
    SmallVectorImpl<Value> &remapped) {
  remapped.reserve(llvm::size(values));

  SmallVector<Type, 1> legalTypes;
  for (const auto &it : llvm::enumerate(values)) {
    Value operand = it.value();
    Type origType = operand.getType();

    // If a converter was provided, get the desired legal types for this
    // operand.
    Type desiredType;
    if (currentTypeConverter) {
      // If there is no legal conversion, fail to match this pattern.
      legalTypes.clear();
      if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
        Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
        return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
          diag << "unable to convert type for " << valueDiagTag << " #"
               << it.index() << ", type was " << origType;
        });
      }
      // TODO: There currently isn't any mechanism to do 1->N type conversion
      // via the PatternRewriter replacement API, so for now we just ignore it.
      if (legalTypes.size() == 1)
        desiredType = legalTypes.front();
    } else {
      // TODO: What we should do here is just set `desiredType` to `origType`
      // and then handle the necessary type conversions after the conversion
      // process has finished. Unfortunately a lot of patterns currently rely on
      // receiving the new operands even if the types change, so we keep the
      // original behavior here for now until all of the patterns relying on
      // this get updated.
    }
    Value newOperand = mapping.lookupOrDefault(operand, desiredType);

    // Handle the case where the conversion was 1->1 and the new operand type
    // isn't legal.
    Type newOperandType = newOperand.getType();
    if (currentTypeConverter && desiredType && newOperandType != desiredType) {
      Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
      Value castValue = buildUnresolvedTargetMaterialization(
          operandLoc, newOperand, desiredType, currentTypeConverter,
          unresolvedMaterializations);
      // Map the current end of newOperand's mapping chain to the cast, so
      // later lookups of this value find the materialized value.
      mapping.map(mapping.lookupOrDefault(newOperand), castValue);
      newOperand = castValue;
    }
    remapped.push_back(newOperand);
  }
  return success();
}

/// Only `op` itself and its direct parent are checked. This is sufficient
/// because markNestedOpsIgnored records every nested op that owns a non-empty
/// region, which includes the parent of any deeper nested op.
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
  // Check to see if this operation was replaced or its parent ignored.
  return replacements.count(op) || ignoredOps.count(op->getParentOp());
}

void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
  // Walk this operation and collect nested operations that define non-empty
  // regions. We mark such operations as 'ignored' so that we know we don't have
  // to convert them, or their nested ops.
  if (op->getNumRegions() == 0)
    return;
  op->walk([&](Operation *op) {
    if (llvm::any_of(op->getRegions(),
                     [](Region &region) { return !region.empty(); }))
      ignoredOps.insert(op);
  });
}

//===----------------------------------------------------------------------===//
// Type Conversion

/// Convert the signature of `block`. If `conversion` is given it is applied
/// directly; otherwise the signature is derived from `converter`. When a new
/// block replaces `block`, a TypeConversion action is recorded so that
/// undoBlockActions can revert it.
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
    Block *block, TypeConverter *converter,
    TypeConverter::SignatureConversion *conversion) {
  FailureOr<Block *> result =
      conversion ? argConverter.applySignatureConversion(
                       block, converter, *conversion, mapping, argReplacements)
                 : argConverter.convertSignature(block, converter, mapping,
                                                 argReplacements);
  if (failed(result))
    return failure();
  if (Block *newBlock = *result) {
    if (newBlock != block)
      blockActions.push_back(BlockAction::getTypeConversion(newBlock));
  }
  return result;
}

/// Apply `conversion` to the entry block of `region`, or return null when the
/// region is empty. The FailureOr is dereferenced unchecked because the
/// explicit-conversion path (ArgConverter::applySignatureConversion) returns a
/// plain Block * and therefore never reports failure.
Block *ConversionPatternRewriterImpl::applySignatureConversion(
    Region *region, TypeConverter::SignatureConversion &conversion,
    TypeConverter *converter) {
  if (!region->empty())
    return *convertBlockSignature(&region->front(), converter, &conversion);
  return nullptr;
}

/// Register `converter` for `region`, convert all non-entry blocks, then the
/// entry block (using `entryConversion` if provided). Returns null for an
/// empty region and the (possibly new) entry block otherwise.
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
    Region *region, TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  argConverter.setConverter(region, &converter);
  if (region->empty())
    return nullptr;

  if (failed(convertNonEntryRegionTypes(region, converter)))
    return failure();

  FailureOr<Block *> newEntry =
      convertBlockSignature(&region->front(), &converter, entryConversion);
  return newEntry;
}

/// Convert the argument types of every non-entry block of `region`. If
/// `blockConversions` is non-empty it must contain one conversion per
/// non-entry block, in region order; otherwise `converter` derives each
/// signature.
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
    Region *region, TypeConverter &converter,
    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
  argConverter.setConverter(region, &converter);
  if (region->empty())
    return success();

  // Convert the arguments of each block within the region.
  int blockIdx = 0;
  assert((blockConversions.empty() ||
          blockConversions.size() == region->getBlocks().size() - 1) &&
         "expected either to provide no SignatureConversions at all or to "
         "provide a SignatureConversion for each non-entry block");

  // Early-increment iteration is required: converting a block moves the
  // original block out of this region.
  for (Block &block :
       llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
    // NOTE(review): const_cast hands out a mutable pointer into an ArrayRef;
    // presumably the conversion is only read by the callee. Confirm.
    TypeConverter::SignatureConversion *blockConversion =
        blockConversions.empty()
            ? nullptr
            : const_cast<TypeConverter::SignatureConversion *>(
                  &blockConversions[blockIdx++]);

    if (failed(convertBlockSignature(&block, &converter, blockConversion)))
      return failure();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks

/// Record that `op` is replaced by `newValues`, where a null entry means the
/// corresponding result was erased without a replacement. Each non-null value
/// is mapped from its result. If any result was dropped or changed type, the
/// replacement's index is remembered. The op's nested region-holding ops are
/// then marked as ignored.
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                     ValueRange newValues) {
  assert(newValues.size() == op->getNumResults());
  assert(!replacements.count(op) && "operation was already replaced");

  // Track if any of the results changed, e.g. erased and replaced with null.
  bool resultChanged = false;

  // Create mappings for each of the new result values.
  Value newValue, result;
  for (auto it : llvm::zip(newValues, op->getResults())) {
    std::tie(newValue, result) = it;
    if (!newValue) {
      resultChanged = true;
      continue;
    }
    // Remap, and check for any result type changes.
    mapping.map(result, newValue);
    resultChanged |= (newValue.getType() != result.getType());
  }
  if (resultChanged)
    operationsWithChangedResults.push_back(replacements.size());

  // Record the requested operation replacement.
  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));

  // Mark this operation as recursively ignored so that we don't need to
  // convert any nested operations.
1429 markNestedOpsIgnored(op); 1430 } 1431 1432 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { 1433 Region *region = block->getParent(); 1434 Block *origPrevBlock = block->getPrevNode(); 1435 blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock})); 1436 } 1437 1438 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) { 1439 blockActions.push_back(BlockAction::getCreate(block)); 1440 } 1441 1442 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, 1443 Block *continuation) { 1444 blockActions.push_back(BlockAction::getSplit(continuation, block)); 1445 } 1446 1447 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block, 1448 Block *srcBlock) { 1449 blockActions.push_back(BlockAction::getMerge(block, srcBlock)); 1450 } 1451 1452 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( 1453 Region ®ion, Region &parent, Region::iterator before) { 1454 if (region.empty()) 1455 return; 1456 Block *laterBlock = ®ion.back(); 1457 for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) { 1458 blockActions.push_back( 1459 BlockAction::getMove(laterBlock, {®ion, &earlierBlock})); 1460 laterBlock = &earlierBlock; 1461 } 1462 blockActions.push_back(BlockAction::getMove(laterBlock, {®ion, nullptr})); 1463 } 1464 1465 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( 1466 iterator_range<Region::iterator> &blocks, Location origRegionLoc) { 1467 for (Block &block : blocks) 1468 blockActions.push_back(BlockAction::getCreate(&block)); 1469 1470 // Compute the conversion set for the inlined region. 1471 auto result = computeConversionSet(blocks, origRegionLoc, createdOps); 1472 1473 // This original region has already had its conversion set computed, so there 1474 // shouldn't be any new failures. 
1475 (void)result; 1476 assert(succeeded(result) && "expected region to have no unreachable blocks"); 1477 } 1478 1479 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure( 1480 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1481 LLVM_DEBUG({ 1482 Diagnostic diag(loc, DiagnosticSeverity::Remark); 1483 reasonCallback(diag); 1484 logger.startLine() << "** Failure : " << diag.str() << "\n"; 1485 if (notifyCallback) 1486 notifyCallback(diag); 1487 }); 1488 return failure(); 1489 } 1490 1491 //===----------------------------------------------------------------------===// 1492 // ConversionPatternRewriter 1493 //===----------------------------------------------------------------------===// 1494 1495 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) 1496 : PatternRewriter(ctx), 1497 impl(new detail::ConversionPatternRewriterImpl(*this)) {} 1498 ConversionPatternRewriter::~ConversionPatternRewriter() = default; 1499 1500 void ConversionPatternRewriter::replaceOpWithIf( 1501 Operation *op, ValueRange newValues, bool *allUsesReplaced, 1502 llvm::unique_function<bool(OpOperand &) const> functor) { 1503 // TODO: To support this we will need to rework a bit of how replacements are 1504 // tracked, given that this isn't guranteed to replace all of the uses of an 1505 // operation. The main change is that now an operation can be replaced 1506 // multiple times, in parts. The current "set" based tracking is mainly useful 1507 // for tracking if a replaced operation should be ignored, i.e. if all of the 1508 // uses will be replaced. 
1509 llvm_unreachable( 1510 "replaceOpWithIf is currently not supported by DialectConversion"); 1511 } 1512 1513 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 1514 LLVM_DEBUG({ 1515 impl->logger.startLine() 1516 << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1517 }); 1518 impl->notifyOpReplaced(op, newValues); 1519 } 1520 1521 void ConversionPatternRewriter::eraseOp(Operation *op) { 1522 LLVM_DEBUG({ 1523 impl->logger.startLine() 1524 << "** Erase : '" << op->getName() << "'(" << op << ")\n"; 1525 }); 1526 SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); 1527 impl->notifyOpReplaced(op, nullRepls); 1528 } 1529 1530 void ConversionPatternRewriter::eraseBlock(Block *block) { 1531 impl->notifyBlockIsBeingErased(block); 1532 1533 // Mark all ops for erasure. 1534 for (Operation &op : *block) 1535 eraseOp(&op); 1536 1537 // Unlink the block from its parent region. The block is kept in the block 1538 // action and will be actually destroyed when rewrites are applied. This 1539 // allows us to keep the operations in the block live and undo the removal by 1540 // re-inserting the block. 
1541 block->getParent()->getBlocks().remove(block); 1542 } 1543 1544 Block *ConversionPatternRewriter::applySignatureConversion( 1545 Region *region, TypeConverter::SignatureConversion &conversion, 1546 TypeConverter *converter) { 1547 return impl->applySignatureConversion(region, conversion, converter); 1548 } 1549 1550 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( 1551 Region *region, TypeConverter &converter, 1552 TypeConverter::SignatureConversion *entryConversion) { 1553 return impl->convertRegionTypes(region, converter, entryConversion); 1554 } 1555 1556 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( 1557 Region *region, TypeConverter &converter, 1558 ArrayRef<TypeConverter::SignatureConversion> blockConversions) { 1559 return impl->convertNonEntryRegionTypes(region, converter, blockConversions); 1560 } 1561 1562 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, 1563 Value to) { 1564 LLVM_DEBUG({ 1565 Operation *parentOp = from.getOwner()->getParentOp(); 1566 impl->logger.startLine() << "** Replace Argument : '" << from 1567 << "'(in region of '" << parentOp->getName() 1568 << "'(" << from.getOwner()->getParentOp() << ")\n"; 1569 }); 1570 impl->argReplacements.push_back(from); 1571 impl->mapping.map(impl->mapping.lookupOrDefault(from), to); 1572 } 1573 1574 Value ConversionPatternRewriter::getRemappedValue(Value key) { 1575 SmallVector<Value> remappedValues; 1576 if (failed(impl->remapValues("value", /*inputLoc=*/llvm::None, *this, key, 1577 remappedValues))) 1578 return nullptr; 1579 return remappedValues.front(); 1580 } 1581 1582 LogicalResult 1583 ConversionPatternRewriter::getRemappedValues(ValueRange keys, 1584 SmallVectorImpl<Value> &results) { 1585 if (keys.empty()) 1586 return success(); 1587 return impl->remapValues("value", /*inputLoc=*/llvm::None, *this, keys, 1588 results); 1589 } 1590 1591 void ConversionPatternRewriter::notifyBlockCreated(Block *block) { 1592 
impl->notifyCreatedBlock(block); 1593 } 1594 1595 Block *ConversionPatternRewriter::splitBlock(Block *block, 1596 Block::iterator before) { 1597 auto *continuation = PatternRewriter::splitBlock(block, before); 1598 impl->notifySplitBlock(block, continuation); 1599 return continuation; 1600 } 1601 1602 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, 1603 ValueRange argValues) { 1604 impl->notifyBlocksBeingMerged(dest, source); 1605 assert(llvm::all_of(source->getPredecessors(), 1606 [dest](Block *succ) { return succ == dest; }) && 1607 "expected 'source' to have no predecessors or only 'dest'"); 1608 assert(argValues.size() == source->getNumArguments() && 1609 "incorrect # of argument replacement values"); 1610 for (auto it : llvm::zip(source->getArguments(), argValues)) 1611 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); 1612 dest->getOperations().splice(dest->end(), source->getOperations()); 1613 eraseBlock(source); 1614 } 1615 1616 void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, 1617 Region &parent, 1618 Region::iterator before) { 1619 impl->notifyRegionIsBeingInlinedBefore(region, parent, before); 1620 PatternRewriter::inlineRegionBefore(region, parent, before); 1621 } 1622 1623 void ConversionPatternRewriter::cloneRegionBefore( 1624 Region ®ion, Region &parent, Region::iterator before, 1625 BlockAndValueMapping &mapping) { 1626 if (region.empty()) 1627 return; 1628 PatternRewriter::cloneRegionBefore(region, parent, before, mapping); 1629 1630 // Collect the range of the cloned blocks. 
1631 auto clonedBeginIt = mapping.lookup(®ion.front())->getIterator(); 1632 auto clonedBlocks = llvm::make_range(clonedBeginIt, before); 1633 impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc()); 1634 } 1635 1636 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) { 1637 LLVM_DEBUG({ 1638 impl->logger.startLine() 1639 << "** Insert : '" << op->getName() << "'(" << op << ")\n"; 1640 }); 1641 impl->createdOps.push_back(op); 1642 } 1643 1644 void ConversionPatternRewriter::startRootUpdate(Operation *op) { 1645 #ifndef NDEBUG 1646 impl->pendingRootUpdates.insert(op); 1647 #endif 1648 impl->rootUpdates.emplace_back(op); 1649 } 1650 1651 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { 1652 // There is nothing to do here, we only need to track the operation at the 1653 // start of the update. 1654 #ifndef NDEBUG 1655 assert(impl->pendingRootUpdates.erase(op) && 1656 "operation did not have a pending in-place update"); 1657 #endif 1658 } 1659 1660 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { 1661 #ifndef NDEBUG 1662 assert(impl->pendingRootUpdates.erase(op) && 1663 "operation did not have a pending in-place update"); 1664 #endif 1665 // Erase the last update for this operation. 
1666 auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; 1667 auto &rootUpdates = impl->rootUpdates; 1668 auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); 1669 assert(it != rootUpdates.rend() && "no root update started on op"); 1670 (*it).resetOperation(); 1671 int updateIdx = std::prev(rootUpdates.rend()) - it; 1672 rootUpdates.erase(rootUpdates.begin() + updateIdx); 1673 } 1674 1675 LogicalResult ConversionPatternRewriter::notifyMatchFailure( 1676 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1677 return impl->notifyMatchFailure(loc, reasonCallback); 1678 } 1679 1680 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 1681 return *impl; 1682 } 1683 1684 //===----------------------------------------------------------------------===// 1685 // ConversionPattern 1686 //===----------------------------------------------------------------------===// 1687 1688 LogicalResult 1689 ConversionPattern::matchAndRewrite(Operation *op, 1690 PatternRewriter &rewriter) const { 1691 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 1692 auto &rewriterImpl = dialectRewriter.getImpl(); 1693 1694 // Track the current conversion pattern type converter in the rewriter. 1695 llvm::SaveAndRestore<TypeConverter *> currentConverterGuard( 1696 rewriterImpl.currentTypeConverter, getTypeConverter()); 1697 1698 // Remap the operands of the operation. 
1699 SmallVector<Value, 4> operands; 1700 if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, 1701 op->getOperands(), operands))) { 1702 return failure(); 1703 } 1704 return matchAndRewrite(op, operands, dialectRewriter); 1705 } 1706 1707 //===----------------------------------------------------------------------===// 1708 // OperationLegalizer 1709 //===----------------------------------------------------------------------===// 1710 1711 namespace { 1712 /// A set of rewrite patterns that can be used to legalize a given operation. 1713 using LegalizationPatterns = SmallVector<const Pattern *, 1>; 1714 1715 /// This class defines a recursive operation legalizer. 1716 class OperationLegalizer { 1717 public: 1718 using LegalizationAction = ConversionTarget::LegalizationAction; 1719 1720 OperationLegalizer(ConversionTarget &targetInfo, 1721 const FrozenRewritePatternSet &patterns); 1722 1723 /// Returns true if the given operation is known to be illegal on the target. 1724 bool isIllegal(Operation *op) const; 1725 1726 /// Attempt to legalize the given operation. Returns success if the operation 1727 /// was legalized, failure otherwise. 1728 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 1729 1730 /// Returns the conversion target in use by the legalizer. 1731 ConversionTarget &getTarget() { return target; } 1732 1733 private: 1734 /// Attempt to legalize the given operation by folding it. 1735 LogicalResult legalizeWithFold(Operation *op, 1736 ConversionPatternRewriter &rewriter); 1737 1738 /// Attempt to legalize the given operation by applying a pattern. Returns 1739 /// success if the operation was legalized, failure otherwise. 1740 LogicalResult legalizeWithPattern(Operation *op, 1741 ConversionPatternRewriter &rewriter); 1742 1743 /// Return true if the given pattern may be applied to the given operation, 1744 /// false otherwise. 
1745 bool canApplyPattern(Operation *op, const Pattern &pattern, 1746 ConversionPatternRewriter &rewriter); 1747 1748 /// Legalize the resultant IR after successfully applying the given pattern. 1749 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, 1750 ConversionPatternRewriter &rewriter, 1751 RewriterState &curState); 1752 1753 /// Legalizes the actions registered during the execution of a pattern. 1754 LogicalResult legalizePatternBlockActions(Operation *op, 1755 ConversionPatternRewriter &rewriter, 1756 ConversionPatternRewriterImpl &impl, 1757 RewriterState &state, 1758 RewriterState &newState); 1759 LogicalResult legalizePatternCreatedOperations( 1760 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 1761 RewriterState &state, RewriterState &newState); 1762 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, 1763 ConversionPatternRewriterImpl &impl, 1764 RewriterState &state, 1765 RewriterState &newState); 1766 1767 //===--------------------------------------------------------------------===// 1768 // Cost Model 1769 //===--------------------------------------------------------------------===// 1770 1771 /// Build an optimistic legalization graph given the provided patterns. This 1772 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with 1773 /// patterns for operations that are not directly legal, but may be 1774 /// transitively legal for the current target given the provided patterns. 1775 void buildLegalizationGraph( 1776 LegalizationPatterns &anyOpLegalizerPatterns, 1777 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1778 1779 /// Compute the benefit of each node within the computed legalization graph. 1780 /// This orders the patterns within 'legalizerPatterns' based upon two 1781 /// criteria: 1782 /// 1) Prefer patterns that have the lowest legalization depth, i.e. 1783 /// represent the more direct mapping to the target. 
1784 /// 2) When comparing patterns with the same legalization depth, prefer the 1785 /// pattern with the highest PatternBenefit. This allows for users to 1786 /// prefer specific legalizations over others. 1787 void computeLegalizationGraphBenefit( 1788 LegalizationPatterns &anyOpLegalizerPatterns, 1789 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1790 1791 /// Compute the legalization depth when legalizing an operation of the given 1792 /// type. 1793 unsigned computeOpLegalizationDepth( 1794 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1795 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1796 1797 /// Apply the conversion cost model to the given set of patterns, and return 1798 /// the smallest legalization depth of any of the patterns. See 1799 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. 1800 unsigned applyCostModelToPatterns( 1801 LegalizationPatterns &patterns, 1802 DenseMap<OperationName, unsigned> &minOpPatternDepth, 1803 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1804 1805 /// The current set of patterns that have been applied. 1806 SmallPtrSet<const Pattern *, 8> appliedPatterns; 1807 1808 /// The legalization information provided by the target. 1809 ConversionTarget ⌖ 1810 1811 /// The pattern applicator to use for conversions. 1812 PatternApplicator applicator; 1813 }; 1814 } // namespace 1815 1816 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, 1817 const FrozenRewritePatternSet &patterns) 1818 : target(targetInfo), applicator(patterns) { 1819 // The set of patterns that can be applied to illegal operations to transform 1820 // them into legal ones. 
1821 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 1822 LegalizationPatterns anyOpLegalizerPatterns; 1823 1824 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); 1825 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); 1826 } 1827 1828 bool OperationLegalizer::isIllegal(Operation *op) const { 1829 return target.isIllegal(op); 1830 } 1831 1832 LogicalResult 1833 OperationLegalizer::legalize(Operation *op, 1834 ConversionPatternRewriter &rewriter) { 1835 #ifndef NDEBUG 1836 const char *logLineComment = 1837 "//===-------------------------------------------===//\n"; 1838 1839 auto &logger = rewriter.getImpl().logger; 1840 #endif 1841 LLVM_DEBUG({ 1842 logger.getOStream() << "\n"; 1843 logger.startLine() << logLineComment; 1844 logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" 1845 << op << ") {\n"; 1846 logger.indent(); 1847 1848 // If the operation has no regions, just print it here. 1849 if (op->getNumRegions() == 0) { 1850 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); 1851 logger.getOStream() << "\n\n"; 1852 } 1853 }); 1854 1855 // Check if this operation is legal on the target. 1856 if (auto legalityInfo = target.isLegal(op)) { 1857 LLVM_DEBUG({ 1858 logSuccess( 1859 logger, "operation marked legal by the target{0}", 1860 legalityInfo->isRecursivelyLegal 1861 ? "; NOTE: operation is recursively legal; skipping internals" 1862 : ""); 1863 logger.startLine() << logLineComment; 1864 }); 1865 1866 // If this operation is recursively legal, mark its children as ignored so 1867 // that we don't consider them for legalization. 1868 if (legalityInfo->isRecursivelyLegal) 1869 rewriter.getImpl().markNestedOpsIgnored(op); 1870 return success(); 1871 } 1872 1873 // Check to see if the operation is ignored and doesn't need to be converted. 
1874 if (rewriter.getImpl().isOpIgnored(op)) { 1875 LLVM_DEBUG({ 1876 logSuccess(logger, "operation marked 'ignored' during conversion"); 1877 logger.startLine() << logLineComment; 1878 }); 1879 return success(); 1880 } 1881 1882 // If the operation isn't legal, try to fold it in-place. 1883 // TODO: Should we always try to do this, even if the op is 1884 // already legal? 1885 if (succeeded(legalizeWithFold(op, rewriter))) { 1886 LLVM_DEBUG({ 1887 logSuccess(logger, "operation was folded"); 1888 logger.startLine() << logLineComment; 1889 }); 1890 return success(); 1891 } 1892 1893 // Otherwise, we need to apply a legalization pattern to this operation. 1894 if (succeeded(legalizeWithPattern(op, rewriter))) { 1895 LLVM_DEBUG({ 1896 logSuccess(logger, ""); 1897 logger.startLine() << logLineComment; 1898 }); 1899 return success(); 1900 } 1901 1902 LLVM_DEBUG({ 1903 logFailure(logger, "no matched legalization pattern"); 1904 logger.startLine() << logLineComment; 1905 }); 1906 return failure(); 1907 } 1908 1909 LogicalResult 1910 OperationLegalizer::legalizeWithFold(Operation *op, 1911 ConversionPatternRewriter &rewriter) { 1912 auto &rewriterImpl = rewriter.getImpl(); 1913 RewriterState curState = rewriterImpl.getCurrentState(); 1914 1915 LLVM_DEBUG({ 1916 rewriterImpl.logger.startLine() << "* Fold {\n"; 1917 rewriterImpl.logger.indent(); 1918 }); 1919 1920 // Try to fold the operation. 1921 SmallVector<Value, 2> replacementValues; 1922 rewriter.setInsertionPoint(op); 1923 if (failed(rewriter.tryFold(op, replacementValues))) { 1924 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); 1925 return failure(); 1926 } 1927 1928 // Insert a replacement for 'op' with the folded replacement values. 1929 rewriter.replaceOp(op, replacementValues); 1930 1931 // Recursively legalize any new constant operations. 
1932 for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); 1933 i != e; ++i) { 1934 Operation *cstOp = rewriterImpl.createdOps[i]; 1935 if (failed(legalize(cstOp, rewriter))) { 1936 LLVM_DEBUG(logFailure(rewriterImpl.logger, 1937 "failed to legalize generated constant '{0}'", 1938 cstOp->getName())); 1939 rewriterImpl.resetState(curState); 1940 return failure(); 1941 } 1942 } 1943 1944 LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); 1945 return success(); 1946 } 1947 1948 LogicalResult 1949 OperationLegalizer::legalizeWithPattern(Operation *op, 1950 ConversionPatternRewriter &rewriter) { 1951 auto &rewriterImpl = rewriter.getImpl(); 1952 1953 // Functor that returns if the given pattern may be applied. 1954 auto canApply = [&](const Pattern &pattern) { 1955 return canApplyPattern(op, pattern, rewriter); 1956 }; 1957 1958 // Functor that cleans up the rewriter state after a pattern failed to match. 1959 RewriterState curState = rewriterImpl.getCurrentState(); 1960 auto onFailure = [&](const Pattern &pattern) { 1961 LLVM_DEBUG({ 1962 logFailure(rewriterImpl.logger, "pattern failed to match"); 1963 if (rewriterImpl.notifyCallback) { 1964 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); 1965 diag << "Failed to apply pattern \"" << pattern.getDebugName() 1966 << "\" on op:\n" 1967 << *op; 1968 rewriterImpl.notifyCallback(diag); 1969 } 1970 }); 1971 rewriterImpl.resetState(curState); 1972 appliedPatterns.erase(&pattern); 1973 }; 1974 1975 // Functor that performs additional legalization when a pattern is 1976 // successfully applied. 1977 auto onSuccess = [&](const Pattern &pattern) { 1978 auto result = legalizePatternResult(op, pattern, rewriter, curState); 1979 appliedPatterns.erase(&pattern); 1980 if (failed(result)) 1981 rewriterImpl.resetState(curState); 1982 return result; 1983 }; 1984 1985 // Try to match and rewrite a pattern on this operation. 
1986 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, 1987 onSuccess); 1988 } 1989 1990 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, 1991 ConversionPatternRewriter &rewriter) { 1992 LLVM_DEBUG({ 1993 auto &os = rewriter.getImpl().logger; 1994 os.getOStream() << "\n"; 1995 os.startLine() << "* Pattern : '" << op->getName() << " -> ("; 1996 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); 1997 os.getOStream() << ")' {\n"; 1998 os.indent(); 1999 }); 2000 2001 // Ensure that we don't cycle by not allowing the same pattern to be 2002 // applied twice in the same recursion stack if it is not known to be safe. 2003 if (!pattern.hasBoundedRewriteRecursion() && 2004 !appliedPatterns.insert(&pattern).second) { 2005 LLVM_DEBUG( 2006 logFailure(rewriter.getImpl().logger, "pattern was already applied")); 2007 return false; 2008 } 2009 return true; 2010 } 2011 2012 LogicalResult 2013 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, 2014 ConversionPatternRewriter &rewriter, 2015 RewriterState &curState) { 2016 auto &impl = rewriter.getImpl(); 2017 2018 #ifndef NDEBUG 2019 assert(impl.pendingRootUpdates.empty() && "dangling root updates"); 2020 #endif 2021 2022 // Check that the root was either replaced or updated in place. 2023 auto replacedRoot = [&] { 2024 return llvm::any_of( 2025 llvm::drop_begin(impl.replacements, curState.numReplacements), 2026 [op](auto &it) { return it.first == op; }); 2027 }; 2028 auto updatedRootInPlace = [&] { 2029 return llvm::any_of( 2030 llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), 2031 [op](auto &state) { return state.getOperation() == op; }); 2032 }; 2033 (void)replacedRoot; 2034 (void)updatedRootInPlace; 2035 assert((replacedRoot() || updatedRootInPlace()) && 2036 "expected pattern to replace the root operation"); 2037 2038 // Legalize each of the actions registered during application. 
2039 RewriterState newState = impl.getCurrentState(); 2040 if (failed(legalizePatternBlockActions(op, rewriter, impl, curState, 2041 newState)) || 2042 failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || 2043 failed(legalizePatternCreatedOperations(rewriter, impl, curState, 2044 newState))) { 2045 return failure(); 2046 } 2047 2048 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); 2049 return success(); 2050 } 2051 2052 LogicalResult OperationLegalizer::legalizePatternBlockActions( 2053 Operation *op, ConversionPatternRewriter &rewriter, 2054 ConversionPatternRewriterImpl &impl, RewriterState &state, 2055 RewriterState &newState) { 2056 SmallPtrSet<Operation *, 16> operationsToIgnore; 2057 2058 // If the pattern moved or created any blocks, make sure the types of block 2059 // arguments get legalized. 2060 for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; 2061 ++i) { 2062 auto &action = impl.blockActions[i]; 2063 if (action.kind == BlockActionKind::TypeConversion || 2064 action.kind == BlockActionKind::Erase) 2065 continue; 2066 // Only check blocks outside of the current operation. 2067 Operation *parentOp = action.block->getParentOp(); 2068 if (!parentOp || parentOp == op || action.block->getNumArguments() == 0) 2069 continue; 2070 2071 // If the region of the block has a type converter, try to convert the block 2072 // directly. 2073 if (auto *converter = 2074 impl.argConverter.getConverter(action.block->getParent())) { 2075 if (failed(impl.convertBlockSignature(action.block, converter))) { 2076 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " 2077 "block")); 2078 return failure(); 2079 } 2080 continue; 2081 } 2082 2083 // Otherwise, check that this operation isn't one generated by this pattern. 2084 // This is because we will attempt to legalize the parent operation, and 2085 // blocks in regions created by this pattern will already be legalized later 2086 // on. 
If we haven't built the set yet, build it now. 2087 if (operationsToIgnore.empty()) { 2088 auto createdOps = ArrayRef<Operation *>(impl.createdOps) 2089 .drop_front(state.numCreatedOps); 2090 operationsToIgnore.insert(createdOps.begin(), createdOps.end()); 2091 } 2092 2093 // If this operation should be considered for re-legalization, try it. 2094 if (operationsToIgnore.insert(parentOp).second && 2095 failed(legalize(parentOp, rewriter))) { 2096 LLVM_DEBUG(logFailure( 2097 impl.logger, "operation '{0}'({1}) became illegal after block action", 2098 parentOp->getName(), parentOp)); 2099 return failure(); 2100 } 2101 } 2102 return success(); 2103 } 2104 2105 LogicalResult OperationLegalizer::legalizePatternCreatedOperations( 2106 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2107 RewriterState &state, RewriterState &newState) { 2108 for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { 2109 Operation *op = impl.createdOps[i]; 2110 if (failed(legalize(op, rewriter))) { 2111 LLVM_DEBUG(logFailure(impl.logger, 2112 "failed to legalize generated operation '{0}'({1})", 2113 op->getName(), op)); 2114 return failure(); 2115 } 2116 } 2117 return success(); 2118 } 2119 2120 LogicalResult OperationLegalizer::legalizePatternRootUpdates( 2121 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2122 RewriterState &state, RewriterState &newState) { 2123 for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { 2124 Operation *op = impl.rootUpdates[i].getOperation(); 2125 if (failed(legalize(op, rewriter))) { 2126 LLVM_DEBUG(logFailure( 2127 impl.logger, "failed to legalize operation updated in-place '{0}'", 2128 op->getName())); 2129 return failure(); 2130 } 2131 } 2132 return success(); 2133 } 2134 2135 //===----------------------------------------------------------------------===// 2136 // Cost Model 2137 2138 void OperationLegalizer::buildLegalizationGraph( 2139 LegalizationPatterns 
&anyOpLegalizerPatterns, 2140 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2141 // A mapping between an operation and a set of operations that can be used to 2142 // generate it. 2143 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; 2144 // A mapping between an operation and any currently invalid patterns it has. 2145 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns; 2146 // A worklist of patterns to consider for legality. 2147 SetVector<const Pattern *> patternWorklist; 2148 2149 // Build the mapping from operations to the parent ops that may generate them. 2150 applicator.walkAllPatterns([&](const Pattern &pattern) { 2151 Optional<OperationName> root = pattern.getRootKind(); 2152 2153 // If the pattern has no specific root, we can't analyze the relationship 2154 // between the root op and generated operations. Given that, add all such 2155 // patterns to the legalization set. 2156 if (!root) { 2157 anyOpLegalizerPatterns.push_back(&pattern); 2158 return; 2159 } 2160 2161 // Skip operations that are always known to be legal. 2162 if (target.getOpAction(*root) == LegalizationAction::Legal) 2163 return; 2164 2165 // Add this pattern to the invalid set for the root op and record this root 2166 // as a parent for any generated operations. 2167 invalidPatterns[*root].insert(&pattern); 2168 for (auto op : pattern.getGeneratedOps()) 2169 parentOps[op].insert(*root); 2170 2171 // Add this pattern to the worklist. 2172 patternWorklist.insert(&pattern); 2173 }); 2174 2175 // If there are any patterns that don't have a specific root kind, we can't 2176 // make direct assumptions about what operations will never be legalized. 2177 // Note: Technically we could, but it would require an analysis that may 2178 // recurse into itself. It would be better to perform this kind of filtering 2179 // at a higher level than here anyways. 
2180 if (!anyOpLegalizerPatterns.empty()) { 2181 for (const Pattern *pattern : patternWorklist) 2182 legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2183 return; 2184 } 2185 2186 while (!patternWorklist.empty()) { 2187 auto *pattern = patternWorklist.pop_back_val(); 2188 2189 // Check to see if any of the generated operations are invalid. 2190 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { 2191 Optional<LegalizationAction> action = target.getOpAction(op); 2192 return !legalizerPatterns.count(op) && 2193 (!action || action == LegalizationAction::Illegal); 2194 })) 2195 continue; 2196 2197 // Otherwise, if all of the generated operation are valid, this op is now 2198 // legal so add all of the child patterns to the worklist. 2199 legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2200 invalidPatterns[*pattern->getRootKind()].erase(pattern); 2201 2202 // Add any invalid patterns of the parent operations to see if they have now 2203 // become legal. 2204 for (auto op : parentOps[*pattern->getRootKind()]) 2205 patternWorklist.set_union(invalidPatterns[op]); 2206 } 2207 } 2208 2209 void OperationLegalizer::computeLegalizationGraphBenefit( 2210 LegalizationPatterns &anyOpLegalizerPatterns, 2211 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2212 // The smallest pattern depth, when legalizing an operation. 2213 DenseMap<OperationName, unsigned> minOpPatternDepth; 2214 2215 // For each operation that is transitively legal, compute a cost for it. 2216 for (auto &opIt : legalizerPatterns) 2217 if (!minOpPatternDepth.count(opIt.first)) 2218 computeOpLegalizationDepth(opIt.first, minOpPatternDepth, 2219 legalizerPatterns); 2220 2221 // Apply the cost model to the patterns that can match any operation. Those 2222 // with a specific operation type are already resolved when computing the op 2223 // legalization depth. 
2224 if (!anyOpLegalizerPatterns.empty()) 2225 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, 2226 legalizerPatterns); 2227 2228 // Apply a cost model to the pattern applicator. We order patterns first by 2229 // depth then benefit. `legalizerPatterns` contains per-op patterns by 2230 // decreasing benefit. 2231 applicator.applyCostModel([&](const Pattern &pattern) { 2232 ArrayRef<const Pattern *> orderedPatternList; 2233 if (Optional<OperationName> rootName = pattern.getRootKind()) 2234 orderedPatternList = legalizerPatterns[*rootName]; 2235 else 2236 orderedPatternList = anyOpLegalizerPatterns; 2237 2238 // If the pattern is not found, then it was removed and cannot be matched. 2239 auto *it = llvm::find(orderedPatternList, &pattern); 2240 if (it == orderedPatternList.end()) 2241 return PatternBenefit::impossibleToMatch(); 2242 2243 // Patterns found earlier in the list have higher benefit. 2244 return PatternBenefit(std::distance(it, orderedPatternList.end())); 2245 }); 2246 } 2247 2248 unsigned OperationLegalizer::computeOpLegalizationDepth( 2249 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 2250 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2251 // Check for existing depth. 2252 auto depthIt = minOpPatternDepth.find(op); 2253 if (depthIt != minOpPatternDepth.end()) 2254 return depthIt->second; 2255 2256 // If a mapping for this operation does not exist, then this operation 2257 // is always legal. Return 0 as the depth for a directly legal operation. 2258 auto opPatternsIt = legalizerPatterns.find(op); 2259 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) 2260 return 0u; 2261 2262 // Record this initial depth in case we encounter this op again when 2263 // recursively computing the depth. 2264 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max()); 2265 2266 // Apply the cost model to the operation patterns, and update the minimum 2267 // depth. 
2268 unsigned minDepth = applyCostModelToPatterns( 2269 opPatternsIt->second, minOpPatternDepth, legalizerPatterns); 2270 minOpPatternDepth[op] = minDepth; 2271 return minDepth; 2272 } 2273 2274 unsigned OperationLegalizer::applyCostModelToPatterns( 2275 LegalizationPatterns &patterns, 2276 DenseMap<OperationName, unsigned> &minOpPatternDepth, 2277 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2278 unsigned minDepth = std::numeric_limits<unsigned>::max(); 2279 2280 // Compute the depth for each pattern within the set. 2281 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; 2282 patternsByDepth.reserve(patterns.size()); 2283 for (const Pattern *pattern : patterns) { 2284 unsigned depth = 1; 2285 for (auto generatedOp : pattern->getGeneratedOps()) { 2286 unsigned generatedOpDepth = computeOpLegalizationDepth( 2287 generatedOp, minOpPatternDepth, legalizerPatterns); 2288 depth = std::max(depth, generatedOpDepth + 1); 2289 } 2290 patternsByDepth.emplace_back(pattern, depth); 2291 2292 // Update the minimum depth of the pattern list. 2293 minDepth = std::min(minDepth, depth); 2294 } 2295 2296 // If the operation only has one legalization pattern, there is no need to 2297 // sort them. 2298 if (patternsByDepth.size() == 1) 2299 return minDepth; 2300 2301 // Sort the patterns by those likely to be the most beneficial. 2302 llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(), 2303 [](const std::pair<const Pattern *, unsigned> *lhs, 2304 const std::pair<const Pattern *, unsigned> *rhs) { 2305 // First sort by the smaller pattern legalization 2306 // depth. 2307 if (lhs->second != rhs->second) 2308 return llvm::array_pod_sort_comparator<unsigned>( 2309 &lhs->second, &rhs->second); 2310 2311 // Then sort by the larger pattern benefit. 
2312 auto lhsBenefit = lhs->first->getBenefit(); 2313 auto rhsBenefit = rhs->first->getBenefit(); 2314 return llvm::array_pod_sort_comparator<PatternBenefit>( 2315 &rhsBenefit, &lhsBenefit); 2316 }); 2317 2318 // Update the legalization pattern to use the new sorted list. 2319 patterns.clear(); 2320 for (auto &patternIt : patternsByDepth) 2321 patterns.push_back(patternIt.first); 2322 return minDepth; 2323 } 2324 2325 //===----------------------------------------------------------------------===// 2326 // OperationConverter 2327 //===----------------------------------------------------------------------===// 2328 namespace { 2329 enum OpConversionMode { 2330 /// In this mode, the conversion will ignore failed conversions to allow 2331 /// illegal operations to co-exist in the IR. 2332 Partial, 2333 2334 /// In this mode, all operations must be legal for the given target for the 2335 /// conversion to succeed. 2336 Full, 2337 2338 /// In this mode, operations are analyzed for legality. No actual rewrites are 2339 /// applied to the operations on success. 2340 Analysis, 2341 }; 2342 2343 // This class converts operations to a given conversion target via a set of 2344 // rewrite patterns. The conversion behaves differently depending on the 2345 // conversion mode. 2346 struct OperationConverter { 2347 explicit OperationConverter(ConversionTarget &target, 2348 const FrozenRewritePatternSet &patterns, 2349 OpConversionMode mode, 2350 DenseSet<Operation *> *trackedOps = nullptr) 2351 : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} 2352 2353 /// Converts the given operations to the conversion target. 2354 LogicalResult 2355 convertOperations(ArrayRef<Operation *> ops, 2356 function_ref<void(Diagnostic &)> notifyCallback = nullptr); 2357 2358 private: 2359 /// Converts an operation with the given rewriter. 
2360 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 2361 2362 /// This method is called after the conversion process to legalize any 2363 /// remaining artifacts and complete the conversion. 2364 LogicalResult finalize(ConversionPatternRewriter &rewriter); 2365 2366 /// Legalize the types of converted block arguments. 2367 LogicalResult 2368 legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, 2369 ConversionPatternRewriterImpl &rewriterImpl); 2370 2371 /// Legalize any unresolved type materializations. 2372 LogicalResult legalizeUnresolvedMaterializations( 2373 ConversionPatternRewriter &rewriter, 2374 ConversionPatternRewriterImpl &rewriterImpl, 2375 Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping); 2376 2377 /// Legalize an operation result that was marked as "erased". 2378 LogicalResult 2379 legalizeErasedResult(Operation *op, OpResult result, 2380 ConversionPatternRewriterImpl &rewriterImpl); 2381 2382 /// Legalize an operation result that was replaced with a value of a different 2383 /// type. 2384 LogicalResult legalizeChangedResultType( 2385 Operation *op, OpResult result, Value newValue, 2386 TypeConverter *replConverter, ConversionPatternRewriter &rewriter, 2387 ConversionPatternRewriterImpl &rewriterImpl, 2388 const DenseMap<Value, SmallVector<Value>> &inverseMapping); 2389 2390 /// The legalizer to use when converting operations. 2391 OperationLegalizer opLegalizer; 2392 2393 /// The conversion mode to use when legalizing operations. 2394 OpConversionMode mode; 2395 2396 /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, 2397 /// this is populated with ops found to be legalizable to the target. 2398 /// When mode == OpConversionMode::Partial, this is populated with ops found 2399 /// *not* to be legalizable to the target. 
2400 DenseSet<Operation *> *trackedOps; 2401 }; 2402 } // namespace 2403 2404 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 2405 Operation *op) { 2406 // Legalize the given operation. 2407 if (failed(opLegalizer.legalize(op, rewriter))) { 2408 // Handle the case of a failed conversion for each of the different modes. 2409 // Full conversions expect all operations to be converted. 2410 if (mode == OpConversionMode::Full) 2411 return op->emitError() 2412 << "failed to legalize operation '" << op->getName() << "'"; 2413 // Partial conversions allow conversions to fail iff the operation was not 2414 // explicitly marked as illegal. If the user provided a nonlegalizableOps 2415 // set, non-legalizable ops are included. 2416 if (mode == OpConversionMode::Partial) { 2417 if (opLegalizer.isIllegal(op)) 2418 return op->emitError() 2419 << "failed to legalize operation '" << op->getName() 2420 << "' that was explicitly marked illegal"; 2421 if (trackedOps) 2422 trackedOps->insert(op); 2423 } 2424 } else if (mode == OpConversionMode::Analysis) { 2425 // Analysis conversions don't fail if any operations fail to legalize, 2426 // they are only interested in the operations that were successfully 2427 // legalized. 2428 trackedOps->insert(op); 2429 } 2430 return success(); 2431 } 2432 2433 LogicalResult OperationConverter::convertOperations( 2434 ArrayRef<Operation *> ops, 2435 function_ref<void(Diagnostic &)> notifyCallback) { 2436 if (ops.empty()) 2437 return success(); 2438 ConversionTarget &target = opLegalizer.getTarget(); 2439 2440 // Compute the set of operations and blocks to convert. 2441 SmallVector<Operation *> toConvert; 2442 for (auto *op : ops) { 2443 toConvert.emplace_back(op); 2444 for (auto ®ion : op->getRegions()) 2445 if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), 2446 toConvert, &target))) 2447 return failure(); 2448 } 2449 2450 // Convert each operation and discard rewrites on failure. 
2451 ConversionPatternRewriter rewriter(ops.front()->getContext()); 2452 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2453 rewriterImpl.notifyCallback = notifyCallback; 2454 2455 for (auto *op : toConvert) 2456 if (failed(convert(rewriter, op))) 2457 return rewriterImpl.discardRewrites(), failure(); 2458 2459 // Now that all of the operations have been converted, finalize the conversion 2460 // process to ensure any lingering conversion artifacts are cleaned up and 2461 // legalized. 2462 if (failed(finalize(rewriter))) 2463 return rewriterImpl.discardRewrites(), failure(); 2464 2465 // After a successful conversion, apply rewrites if this is not an analysis 2466 // conversion. 2467 if (mode == OpConversionMode::Analysis) { 2468 rewriterImpl.discardRewrites(); 2469 } else { 2470 rewriterImpl.applyRewrites(); 2471 2472 // It is possible for a later pattern to erase an op that was originally 2473 // identified as illegal and added to the trackedOps, remove it now after 2474 // replacements have been computed. 2475 if (trackedOps) 2476 for (auto &repl : rewriterImpl.replacements) 2477 trackedOps->erase(repl.first); 2478 } 2479 return success(); 2480 } 2481 2482 LogicalResult 2483 OperationConverter::finalize(ConversionPatternRewriter &rewriter) { 2484 Optional<DenseMap<Value, SmallVector<Value>>> inverseMapping; 2485 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2486 if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, 2487 inverseMapping)) || 2488 failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) 2489 return failure(); 2490 2491 if (rewriterImpl.operationsWithChangedResults.empty()) 2492 return success(); 2493 2494 // Process requested operation replacements. 
2495 for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); 2496 i != e; ++i) { 2497 unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; 2498 auto &repl = *(rewriterImpl.replacements.begin() + replIdx); 2499 for (OpResult result : repl.first->getResults()) { 2500 Value newValue = rewriterImpl.mapping.lookupOrNull(result); 2501 2502 // If the operation result was replaced with null, all of the uses of this 2503 // value should be replaced. 2504 if (!newValue) { 2505 if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) 2506 return failure(); 2507 continue; 2508 } 2509 2510 // Otherwise, check to see if the type of the result changed. 2511 if (result.getType() == newValue.getType()) 2512 continue; 2513 2514 // Compute the inverse mapping only if it is really needed. 2515 if (!inverseMapping) 2516 inverseMapping = rewriterImpl.mapping.getInverse(); 2517 2518 // Legalize this result. 2519 rewriter.setInsertionPoint(repl.first); 2520 if (failed(legalizeChangedResultType(repl.first, result, newValue, 2521 repl.second.converter, rewriter, 2522 rewriterImpl, *inverseMapping))) 2523 return failure(); 2524 2525 // Update the end iterator for this loop in the case it was updated 2526 // when legalizing generated conversion operations. 2527 e = rewriterImpl.operationsWithChangedResults.size(); 2528 } 2529 } 2530 return success(); 2531 } 2532 2533 LogicalResult OperationConverter::legalizeConvertedArgumentTypes( 2534 ConversionPatternRewriter &rewriter, 2535 ConversionPatternRewriterImpl &rewriterImpl) { 2536 // Functor used to check if all users of a value will be dead after 2537 // conversion. 2538 auto findLiveUser = [&](Value val) { 2539 auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { 2540 return rewriterImpl.isOpIgnored(user); 2541 }); 2542 return liveUserIt == val.user_end() ? 
nullptr : *liveUserIt; 2543 }; 2544 return rewriterImpl.argConverter.materializeLiveConversions( 2545 rewriterImpl.mapping, rewriter, findLiveUser); 2546 } 2547 2548 /// Replace the results of a materialization operation with the given values. 2549 static void 2550 replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, 2551 ResultRange matResults, ValueRange values, 2552 DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2553 matResults.replaceAllUsesWith(values); 2554 2555 // For each of the materialization results, update the inverse mappings to 2556 // point to the replacement values. 2557 for (auto it : llvm::zip(matResults, values)) { 2558 Value matResult, newValue; 2559 std::tie(matResult, newValue) = it; 2560 auto inverseMapIt = inverseMapping.find(matResult); 2561 if (inverseMapIt == inverseMapping.end()) 2562 continue; 2563 2564 // Update the reverse mapping, or remove the mapping if we couldn't update 2565 // it. Not being able to update signals that the mapping would have become 2566 // circular (i.e. %foo -> newValue -> %foo), which may occur as values are 2567 // propagated through temporary materializations. We simply drop the 2568 // mapping, and let the post-conversion replacement logic handle updating 2569 // uses. 2570 for (Value inverseMapVal : inverseMapIt->second) 2571 if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue)) 2572 rewriterImpl.mapping.erase(inverseMapVal); 2573 } 2574 } 2575 2576 /// Compute all of the unresolved materializations that will persist beyond the 2577 /// conversion process, and require inserting a proper user materialization for. 
2578 static void computeNecessaryMaterializations( 2579 DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, 2580 ConversionPatternRewriter &rewriter, 2581 ConversionPatternRewriterImpl &rewriterImpl, 2582 DenseMap<Value, SmallVector<Value>> &inverseMapping, 2583 SetVector<UnresolvedMaterialization *> &necessaryMaterializations) { 2584 auto isLive = [&](Value value) { 2585 auto findFn = [&](Operation *user) { 2586 auto matIt = materializationOps.find(user); 2587 if (matIt != materializationOps.end()) 2588 return !necessaryMaterializations.count(matIt->second); 2589 return rewriterImpl.isOpIgnored(user); 2590 }; 2591 // This value may be replacing another value that has a live user. 2592 for (Value inv : inverseMapping.lookup(value)) 2593 if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end()) 2594 return true; 2595 // Or have live users itself. 2596 return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); 2597 }; 2598 2599 llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue = 2600 [&](Value invalidRoot, Value value, Type type) { 2601 // Check to see if the input operation was remapped to a variant of the 2602 // output. 2603 Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); 2604 if (remappedValue.getType() == type && remappedValue != invalidRoot) 2605 return remappedValue; 2606 2607 // Check to see if the input is a materialization operation that 2608 // provides an inverse conversion. We just check blindly for 2609 // UnrealizedConversionCastOp here, but it has no effect on correctness. 
2610 auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>(); 2611 if (inputCastOp && inputCastOp->getNumOperands() == 1) 2612 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0), 2613 type); 2614 2615 return Value(); 2616 }; 2617 2618 SetVector<UnresolvedMaterialization *> worklist; 2619 for (auto &mat : rewriterImpl.unresolvedMaterializations) { 2620 materializationOps.try_emplace(mat.getOp(), &mat); 2621 worklist.insert(&mat); 2622 } 2623 while (!worklist.empty()) { 2624 UnresolvedMaterialization *mat = worklist.pop_back_val(); 2625 UnrealizedConversionCastOp op = mat->getOp(); 2626 2627 // We currently only handle target materializations here. 2628 assert(op->getNumResults() == 1 && "unexpected materialization type"); 2629 OpResult opResult = op->getOpResult(0); 2630 Type outputType = opResult.getType(); 2631 Operation::operand_range inputOperands = op.getOperands(); 2632 2633 // Try to forward propagate operands for user conversion casts that result 2634 // in the input types of the current cast. 2635 for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) { 2636 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user); 2637 if (!castOp) 2638 continue; 2639 if (castOp->getResultTypes() == inputOperands.getTypes()) { 2640 replaceMaterialization(rewriterImpl, opResult, inputOperands, 2641 inverseMapping); 2642 necessaryMaterializations.remove(materializationOps.lookup(user)); 2643 } 2644 } 2645 2646 // Try to avoid materializing a resolved materialization if possible. 2647 // Handle the case of a 1-1 materialization. 2648 if (inputOperands.size() == 1) { 2649 // Check to see if the input operation was remapped to a variant of the 2650 // output. 
2651 Value remappedValue = 2652 lookupRemappedValue(opResult, inputOperands[0], outputType); 2653 if (remappedValue && remappedValue != opResult) { 2654 replaceMaterialization(rewriterImpl, opResult, remappedValue, 2655 inverseMapping); 2656 necessaryMaterializations.remove(mat); 2657 continue; 2658 } 2659 } else { 2660 // TODO: Avoid materializing other types of conversions here. 2661 } 2662 2663 // Check to see if this is an argument materialization. 2664 auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); }; 2665 if (llvm::any_of(op->getOperands(), isBlockArg) || 2666 llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { 2667 mat->setKind(UnresolvedMaterialization::Argument); 2668 } 2669 2670 // If the materialization does not have any live users, we don't need to 2671 // generate a user materialization for it. 2672 // FIXME: For argument materializations, we currently need to check if any 2673 // of the inverse mapped values are used because some patterns expect blind 2674 // value replacement even if the types differ in some cases. When those 2675 // patterns are fixed, we can drop the argument special case here. 2676 bool isMaterializationLive = isLive(opResult); 2677 if (mat->getKind() == UnresolvedMaterialization::Argument) 2678 isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); 2679 if (!isMaterializationLive) 2680 continue; 2681 if (!necessaryMaterializations.insert(mat)) 2682 continue; 2683 2684 // Reprocess input materializations to see if they have an updated status. 2685 for (Value input : inputOperands) { 2686 if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) { 2687 if (auto *mat = materializationOps.lookup(parentOp)) 2688 worklist.insert(mat); 2689 } 2690 } 2691 } 2692 } 2693 2694 /// Legalize the given unresolved materialization. Returns success if the 2695 /// materialization was legalized, failure otherise. 
2696 static LogicalResult legalizeUnresolvedMaterialization( 2697 UnresolvedMaterialization &mat, 2698 DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, 2699 ConversionPatternRewriter &rewriter, 2700 ConversionPatternRewriterImpl &rewriterImpl, 2701 DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2702 auto findLiveUser = [&](auto &&users) { 2703 auto liveUserIt = llvm::find_if_not( 2704 users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); }); 2705 return liveUserIt == users.end() ? nullptr : *liveUserIt; 2706 }; 2707 2708 llvm::unique_function<Value(Value, Type)> lookupRemappedValue = 2709 [&](Value value, Type type) { 2710 // Check to see if the input operation was remapped to a variant of the 2711 // output. 2712 Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); 2713 if (remappedValue.getType() == type) 2714 return remappedValue; 2715 return Value(); 2716 }; 2717 2718 UnrealizedConversionCastOp op = mat.getOp(); 2719 if (!rewriterImpl.ignoredOps.insert(op)) 2720 return success(); 2721 2722 // We currently only handle target materializations here. 2723 OpResult opResult = op->getOpResult(0); 2724 Operation::operand_range inputOperands = op.getOperands(); 2725 Type outputType = opResult.getType(); 2726 2727 // If any input to this materialization is another materialization, resolve 2728 // the input first. 2729 for (Value value : op->getOperands()) { 2730 auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>(); 2731 if (!valueCast) 2732 continue; 2733 2734 auto matIt = materializationOps.find(valueCast); 2735 if (matIt != materializationOps.end()) 2736 if (failed(legalizeUnresolvedMaterialization( 2737 *matIt->second, materializationOps, rewriter, rewriterImpl, 2738 inverseMapping))) 2739 return failure(); 2740 } 2741 2742 // Perform a last ditch attempt to avoid materializing a resolved 2743 // materialization if possible. 2744 // Handle the case of a 1-1 materialization. 
2745 if (inputOperands.size() == 1) { 2746 // Check to see if the input operation was remapped to a variant of the 2747 // output. 2748 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType); 2749 if (remappedValue && remappedValue != opResult) { 2750 replaceMaterialization(rewriterImpl, opResult, remappedValue, 2751 inverseMapping); 2752 return success(); 2753 } 2754 } else { 2755 // TODO: Avoid materializing other types of conversions here. 2756 } 2757 2758 // Try to materialize the conversion. 2759 if (TypeConverter *converter = mat.getConverter()) { 2760 // FIXME: Determine a suitable insertion location when there are multiple 2761 // inputs. 2762 if (inputOperands.size() == 1) 2763 rewriter.setInsertionPointAfterValue(inputOperands.front()); 2764 else 2765 rewriter.setInsertionPoint(op); 2766 2767 Value newMaterialization; 2768 switch (mat.getKind()) { 2769 case UnresolvedMaterialization::Argument: 2770 // Try to materialize an argument conversion. 2771 // FIXME: The current argument materialization hook expects the original 2772 // output type, even though it doesn't use that as the actual output type 2773 // of the generated IR. The output type is just used as an indicator of 2774 // the type of materialization to do. This behavior is really awkward in 2775 // that it diverges from the behavior of the other hooks, and can be 2776 // easily misunderstood. We should clean up the argument hooks to better 2777 // represent the desired invariants we actually care about. 2778 newMaterialization = converter->materializeArgumentConversion( 2779 rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); 2780 if (newMaterialization) 2781 break; 2782 2783 // If an argument materialization failed, fallback to trying a target 2784 // materialization. 
2785 LLVM_FALLTHROUGH; 2786 case UnresolvedMaterialization::Target: 2787 newMaterialization = converter->materializeTargetConversion( 2788 rewriter, op->getLoc(), outputType, inputOperands); 2789 break; 2790 } 2791 if (newMaterialization) { 2792 replaceMaterialization(rewriterImpl, opResult, newMaterialization, 2793 inverseMapping); 2794 return success(); 2795 } 2796 } 2797 2798 InFlightDiagnostic diag = op->emitError() 2799 << "failed to legalize unresolved materialization " 2800 "from " 2801 << inputOperands.getTypes() << " to " << outputType 2802 << " that remained live after conversion"; 2803 if (Operation *liveUser = findLiveUser(op->getUsers())) { 2804 diag.attachNote(liveUser->getLoc()) 2805 << "see existing live user here: " << *liveUser; 2806 } 2807 return failure(); 2808 } 2809 2810 LogicalResult OperationConverter::legalizeUnresolvedMaterializations( 2811 ConversionPatternRewriter &rewriter, 2812 ConversionPatternRewriterImpl &rewriterImpl, 2813 Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) { 2814 if (rewriterImpl.unresolvedMaterializations.empty()) 2815 return success(); 2816 inverseMapping = rewriterImpl.mapping.getInverse(); 2817 2818 // As an initial step, compute all of the inserted materializations that we 2819 // expect to persist beyond the conversion process. 2820 DenseMap<Operation *, UnresolvedMaterialization *> materializationOps; 2821 SetVector<UnresolvedMaterialization *> necessaryMaterializations; 2822 computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, 2823 *inverseMapping, necessaryMaterializations); 2824 2825 // Once computed, legalize any necessary materializations. 
2826 for (auto *mat : necessaryMaterializations) { 2827 if (failed(legalizeUnresolvedMaterialization( 2828 *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping))) 2829 return failure(); 2830 } 2831 return success(); 2832 } 2833 2834 LogicalResult OperationConverter::legalizeErasedResult( 2835 Operation *op, OpResult result, 2836 ConversionPatternRewriterImpl &rewriterImpl) { 2837 // If the operation result was replaced with null, all of the uses of this 2838 // value should be replaced. 2839 auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { 2840 return rewriterImpl.isOpIgnored(user); 2841 }); 2842 if (liveUserIt != result.user_end()) { 2843 InFlightDiagnostic diag = op->emitError("failed to legalize operation '") 2844 << op->getName() << "' marked as erased"; 2845 diag.attachNote(liveUserIt->getLoc()) 2846 << "found live user of result #" << result.getResultNumber() << ": " 2847 << *liveUserIt; 2848 return failure(); 2849 } 2850 return success(); 2851 } 2852 2853 /// Finds a user of the given value, or of any other value that the given value 2854 /// replaced, that was not replaced in the conversion process. 2855 static Operation *findLiveUserOfReplaced( 2856 Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, 2857 const DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2858 SmallVector<Value> worklist(1, initialValue); 2859 while (!worklist.empty()) { 2860 Value value = worklist.pop_back_val(); 2861 2862 // Walk the users of this value to see if there are any live users that 2863 // weren't replaced during conversion. 
2864 auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { 2865 return rewriterImpl.isOpIgnored(user); 2866 }); 2867 if (liveUserIt != value.user_end()) 2868 return *liveUserIt; 2869 auto mapIt = inverseMapping.find(value); 2870 if (mapIt != inverseMapping.end()) 2871 worklist.append(mapIt->second); 2872 } 2873 return nullptr; 2874 } 2875 2876 LogicalResult OperationConverter::legalizeChangedResultType( 2877 Operation *op, OpResult result, Value newValue, 2878 TypeConverter *replConverter, ConversionPatternRewriter &rewriter, 2879 ConversionPatternRewriterImpl &rewriterImpl, 2880 const DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2881 Operation *liveUser = 2882 findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); 2883 if (!liveUser) 2884 return success(); 2885 2886 // Functor used to emit a conversion error for a failed materialization. 2887 auto emitConversionError = [&] { 2888 InFlightDiagnostic diag = op->emitError() 2889 << "failed to materialize conversion for result #" 2890 << result.getResultNumber() << " of operation '" 2891 << op->getName() 2892 << "' that remained live after conversion"; 2893 diag.attachNote(liveUser->getLoc()) 2894 << "see existing live user here: " << *liveUser; 2895 return failure(); 2896 }; 2897 2898 // If the replacement has a type converter, attempt to materialize a 2899 // conversion back to the original type. 2900 if (!replConverter) 2901 return emitConversionError(); 2902 2903 // Materialize a conversion for this live result value. 
2904 Type resultType = result.getType(); 2905 Value convertedValue = replConverter->materializeSourceConversion( 2906 rewriter, op->getLoc(), resultType, newValue); 2907 if (!convertedValue) 2908 return emitConversionError(); 2909 2910 rewriterImpl.mapping.map(result, convertedValue); 2911 return success(); 2912 } 2913 2914 //===----------------------------------------------------------------------===// 2915 // Type Conversion 2916 //===----------------------------------------------------------------------===// 2917 2918 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 2919 ArrayRef<Type> types) { 2920 assert(!types.empty() && "expected valid types"); 2921 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 2922 addInputs(types); 2923 } 2924 2925 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 2926 assert(!types.empty() && 2927 "1->0 type remappings don't need to be added explicitly"); 2928 argTypes.append(types.begin(), types.end()); 2929 } 2930 2931 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2932 unsigned newInputNo, 2933 unsigned newInputCount) { 2934 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2935 assert(newInputCount != 0 && "expected valid input count"); 2936 remappedInputs[origInputNo] = 2937 InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; 2938 } 2939 2940 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2941 Value replacementValue) { 2942 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2943 remappedInputs[origInputNo] = 2944 InputMapping{origInputNo, /*size=*/0, replacementValue}; 2945 } 2946 2947 LogicalResult TypeConverter::convertType(Type t, 2948 SmallVectorImpl<Type> &results) { 2949 auto existingIt = cachedDirectConversions.find(t); 2950 if (existingIt != cachedDirectConversions.end()) { 2951 if (existingIt->second) 2952 
results.push_back(existingIt->second); 2953 return success(existingIt->second != nullptr); 2954 } 2955 auto multiIt = cachedMultiConversions.find(t); 2956 if (multiIt != cachedMultiConversions.end()) { 2957 results.append(multiIt->second.begin(), multiIt->second.end()); 2958 return success(); 2959 } 2960 2961 // Walk the added converters in reverse order to apply the most recently 2962 // registered first. 2963 size_t currentCount = results.size(); 2964 conversionCallStack.push_back(t); 2965 auto popConversionCallStack = 2966 llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); }); 2967 for (ConversionCallbackFn &converter : llvm::reverse(conversions)) { 2968 if (Optional<LogicalResult> result = 2969 converter(t, results, conversionCallStack)) { 2970 if (!succeeded(*result)) { 2971 cachedDirectConversions.try_emplace(t, nullptr); 2972 return failure(); 2973 } 2974 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); 2975 if (newTypes.size() == 1) 2976 cachedDirectConversions.try_emplace(t, newTypes.front()); 2977 else 2978 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); 2979 return success(); 2980 } 2981 } 2982 return failure(); 2983 } 2984 2985 Type TypeConverter::convertType(Type t) { 2986 // Use the multi-type result version to convert the type. 2987 SmallVector<Type, 1> results; 2988 if (failed(convertType(t, results))) 2989 return nullptr; 2990 2991 // Check to ensure that only one type was produced. 2992 return results.size() == 1 ? 
results.front() : nullptr; 2993 } 2994 2995 LogicalResult TypeConverter::convertTypes(TypeRange types, 2996 SmallVectorImpl<Type> &results) { 2997 for (Type type : types) 2998 if (failed(convertType(type, results))) 2999 return failure(); 3000 return success(); 3001 } 3002 3003 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; } 3004 bool TypeConverter::isLegal(Operation *op) { 3005 return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); 3006 } 3007 3008 bool TypeConverter::isLegal(Region *region) { 3009 return llvm::all_of(*region, [this](Block &block) { 3010 return isLegal(block.getArgumentTypes()); 3011 }); 3012 } 3013 3014 bool TypeConverter::isSignatureLegal(FunctionType ty) { 3015 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); 3016 } 3017 3018 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 3019 SignatureConversion &result) { 3020 // Try to convert the given input type. 3021 SmallVector<Type, 1> convertedTypes; 3022 if (failed(convertType(type, convertedTypes))) 3023 return failure(); 3024 3025 // If this argument is being dropped, there is nothing left to do. 3026 if (convertedTypes.empty()) 3027 return success(); 3028 3029 // Otherwise, add the new inputs. 
3030 result.addInputs(inputNo, convertedTypes); 3031 return success(); 3032 } 3033 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types, 3034 SignatureConversion &result, 3035 unsigned origInputOffset) { 3036 for (unsigned i = 0, e = types.size(); i != e; ++i) 3037 if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) 3038 return failure(); 3039 return success(); 3040 } 3041 3042 Value TypeConverter::materializeConversion( 3043 MutableArrayRef<MaterializationCallbackFn> materializations, 3044 OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) { 3045 for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) 3046 if (Optional<Value> result = fn(builder, resultType, inputs, loc)) 3047 return *result; 3048 return nullptr; 3049 } 3050 3051 auto TypeConverter::convertBlockSignature(Block *block) 3052 -> Optional<SignatureConversion> { 3053 SignatureConversion conversion(block->getNumArguments()); 3054 if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) 3055 return llvm::None; 3056 return conversion; 3057 } 3058 3059 //===----------------------------------------------------------------------===// 3060 // FunctionOpInterfaceSignatureConversion 3061 //===----------------------------------------------------------------------===// 3062 3063 /// Create a default conversion pattern that rewrites the type signature of a 3064 /// FunctionOpInterface op. This only supports ops which use FunctionType to 3065 /// represent their type. 
3066 namespace { 3067 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { 3068 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, 3069 MLIRContext *ctx, 3070 TypeConverter &converter) 3071 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} 3072 3073 LogicalResult 3074 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 3075 ConversionPatternRewriter &rewriter) const override { 3076 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 3077 FunctionType type = funcOp.getFunctionType().cast<FunctionType>(); 3078 3079 // Convert the original function types. 3080 TypeConverter::SignatureConversion result(type.getNumInputs()); 3081 SmallVector<Type, 1> newResults; 3082 if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || 3083 failed(typeConverter->convertTypes(type.getResults(), newResults)) || 3084 failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter, 3085 &result))) 3086 return failure(); 3087 3088 // Update the function signature in-place. 
3089 auto newType = FunctionType::get(rewriter.getContext(), 3090 result.getConvertedTypes(), newResults); 3091 3092 rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); }); 3093 3094 return success(); 3095 } 3096 }; 3097 } // namespace 3098 3099 void mlir::populateFunctionOpInterfaceTypeConversionPattern( 3100 StringRef functionLikeOpName, RewritePatternSet &patterns, 3101 TypeConverter &converter) { 3102 patterns.add<FunctionOpInterfaceSignatureConversion>( 3103 functionLikeOpName, patterns.getContext(), converter); 3104 } 3105 3106 //===----------------------------------------------------------------------===// 3107 // ConversionTarget 3108 //===----------------------------------------------------------------------===// 3109 3110 void ConversionTarget::setOpAction(OperationName op, 3111 LegalizationAction action) { 3112 legalOperations[op].action = action; 3113 } 3114 3115 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 3116 LegalizationAction action) { 3117 for (StringRef dialect : dialectNames) 3118 legalDialects[dialect] = action; 3119 } 3120 3121 auto ConversionTarget::getOpAction(OperationName op) const 3122 -> Optional<LegalizationAction> { 3123 Optional<LegalizationInfo> info = getOpInfo(op); 3124 return info ? info->action : Optional<LegalizationAction>(); 3125 } 3126 3127 auto ConversionTarget::isLegal(Operation *op) const 3128 -> Optional<LegalOpDetails> { 3129 Optional<LegalizationInfo> info = getOpInfo(op->getName()); 3130 if (!info) 3131 return llvm::None; 3132 3133 // Returns true if this operation instance is known to be legal. 3134 auto isOpLegal = [&] { 3135 // Handle dynamic legality either with the provided legality function. 3136 if (info->action == LegalizationAction::Dynamic) { 3137 Optional<bool> result = info->legalityFn(op); 3138 if (result) 3139 return *result; 3140 } 3141 3142 // Otherwise, the operation is only legal if it was marked 'Legal'. 
3143 return info->action == LegalizationAction::Legal; 3144 }; 3145 if (!isOpLegal()) 3146 return llvm::None; 3147 3148 // This operation is legal, compute any additional legality information. 3149 LegalOpDetails legalityDetails; 3150 if (info->isRecursivelyLegal) { 3151 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); 3152 if (legalityFnIt != opRecursiveLegalityFns.end()) { 3153 legalityDetails.isRecursivelyLegal = 3154 legalityFnIt->second(op).value_or(true); 3155 } else { 3156 legalityDetails.isRecursivelyLegal = true; 3157 } 3158 } 3159 return legalityDetails; 3160 } 3161 3162 bool ConversionTarget::isIllegal(Operation *op) const { 3163 Optional<LegalizationInfo> info = getOpInfo(op->getName()); 3164 if (!info) 3165 return false; 3166 3167 if (info->action == LegalizationAction::Dynamic) { 3168 Optional<bool> result = info->legalityFn(op); 3169 if (!result) 3170 return false; 3171 3172 return !(*result); 3173 } 3174 3175 return info->action == LegalizationAction::Illegal; 3176 } 3177 3178 static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( 3179 ConversionTarget::DynamicLegalityCallbackFn oldCallback, 3180 ConversionTarget::DynamicLegalityCallbackFn newCallback) { 3181 if (!oldCallback) 3182 return newCallback; 3183 3184 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)]( 3185 Operation *op) -> Optional<bool> { 3186 if (Optional<bool> result = newCl(op)) 3187 return *result; 3188 3189 return oldCl(op); 3190 }; 3191 return chain; 3192 } 3193 3194 void ConversionTarget::setLegalityCallback( 3195 OperationName name, const DynamicLegalityCallbackFn &callback) { 3196 assert(callback && "expected valid legality callback"); 3197 auto infoIt = legalOperations.find(name); 3198 assert(infoIt != legalOperations.end() && 3199 infoIt->second.action == LegalizationAction::Dynamic && 3200 "expected operation to already be marked as dynamically legal"); 3201 infoIt->second.legalityFn = 3202 
composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback); 3203 } 3204 3205 void ConversionTarget::markOpRecursivelyLegal( 3206 OperationName name, const DynamicLegalityCallbackFn &callback) { 3207 auto infoIt = legalOperations.find(name); 3208 assert(infoIt != legalOperations.end() && 3209 infoIt->second.action != LegalizationAction::Illegal && 3210 "expected operation to already be marked as legal"); 3211 infoIt->second.isRecursivelyLegal = true; 3212 if (callback) 3213 opRecursiveLegalityFns[name] = composeLegalityCallbacks( 3214 std::move(opRecursiveLegalityFns[name]), callback); 3215 else 3216 opRecursiveLegalityFns.erase(name); 3217 } 3218 3219 void ConversionTarget::setLegalityCallback( 3220 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 3221 assert(callback && "expected valid legality callback"); 3222 for (StringRef dialect : dialects) 3223 dialectLegalityFns[dialect] = composeLegalityCallbacks( 3224 std::move(dialectLegalityFns[dialect]), callback); 3225 } 3226 3227 void ConversionTarget::setLegalityCallback( 3228 const DynamicLegalityCallbackFn &callback) { 3229 assert(callback && "expected valid legality callback"); 3230 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback); 3231 } 3232 3233 auto ConversionTarget::getOpInfo(OperationName op) const 3234 -> Optional<LegalizationInfo> { 3235 // Check for info for this specific operation. 3236 auto it = legalOperations.find(op); 3237 if (it != legalOperations.end()) 3238 return it->second; 3239 // Check for info for the parent dialect. 
3240 auto dialectIt = legalDialects.find(op.getDialectNamespace()); 3241 if (dialectIt != legalDialects.end()) { 3242 DynamicLegalityCallbackFn callback; 3243 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); 3244 if (dialectFn != dialectLegalityFns.end()) 3245 callback = dialectFn->second; 3246 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, 3247 callback}; 3248 } 3249 // Otherwise, check if we mark unknown operations as dynamic. 3250 if (unknownLegalityFn) 3251 return LegalizationInfo{LegalizationAction::Dynamic, 3252 /*isRecursivelyLegal=*/false, unknownLegalityFn}; 3253 return llvm::None; 3254 } 3255 3256 //===----------------------------------------------------------------------===// 3257 // Op Conversion Entry Points 3258 //===----------------------------------------------------------------------===// 3259 3260 //===----------------------------------------------------------------------===// 3261 // Partial Conversion 3262 3263 LogicalResult 3264 mlir::applyPartialConversion(ArrayRef<Operation *> ops, 3265 ConversionTarget &target, 3266 const FrozenRewritePatternSet &patterns, 3267 DenseSet<Operation *> *unconvertedOps) { 3268 OperationConverter opConverter(target, patterns, OpConversionMode::Partial, 3269 unconvertedOps); 3270 return opConverter.convertOperations(ops); 3271 } 3272 LogicalResult 3273 mlir::applyPartialConversion(Operation *op, ConversionTarget &target, 3274 const FrozenRewritePatternSet &patterns, 3275 DenseSet<Operation *> *unconvertedOps) { 3276 return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, 3277 unconvertedOps); 3278 } 3279 3280 //===----------------------------------------------------------------------===// 3281 // Full Conversion 3282 3283 LogicalResult 3284 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 3285 const FrozenRewritePatternSet &patterns) { 3286 OperationConverter opConverter(target, patterns, OpConversionMode::Full); 3287 return 
opConverter.convertOperations(ops); 3288 } 3289 LogicalResult 3290 mlir::applyFullConversion(Operation *op, ConversionTarget &target, 3291 const FrozenRewritePatternSet &patterns) { 3292 return applyFullConversion(llvm::makeArrayRef(op), target, patterns); 3293 } 3294 3295 //===----------------------------------------------------------------------===// 3296 // Analysis Conversion 3297 3298 LogicalResult 3299 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops, 3300 ConversionTarget &target, 3301 const FrozenRewritePatternSet &patterns, 3302 DenseSet<Operation *> &convertedOps, 3303 function_ref<void(Diagnostic &)> notifyCallback) { 3304 OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, 3305 &convertedOps); 3306 return opConverter.convertOperations(ops, notifyCallback); 3307 } 3308 LogicalResult 3309 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 3310 const FrozenRewritePatternSet &patterns, 3311 DenseSet<Operation *> &convertedOps, 3312 function_ref<void(Diagnostic &)> notifyCallback) { 3313 return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, 3314 convertedOps, notifyCallback); 3315 } 3316