1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 48 Value input2) { 49 return SmallVector<Value, 2>({input2, input1}); 50 } 51 52 // Test that natives calls are only called once during rewrites. 53 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 54 // This let us check the number of times OpM_Test was called by inspecting 55 // the returned value in the MLIR output. 56 static int64_t opMIncreasingValue = 314159265; 57 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 58 int64_t i = opMIncreasingValue++; 59 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 60 } 61 62 namespace { 63 #include "TestPatterns.inc" 64 } // end anonymous namespace 65 66 //===----------------------------------------------------------------------===// 67 // Test Reduce Pattern Interface 68 //===----------------------------------------------------------------------===// 69 70 void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) { 71 populateWithGenerated(patterns); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Canonicalizer Driver. 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 struct FoldingPattern : public RewritePattern { 80 public: 81 FoldingPattern(MLIRContext *context) 82 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 83 /*benefit=*/1, context) {} 84 85 LogicalResult matchAndRewrite(Operation *op, 86 PatternRewriter &rewriter) const override { 87 // Exercise OperationFolder API for a single-result operation that is folded 88 // upon construction. The operation being created through the folder has an 89 // in-place folder, and it should be still present in the output. 90 // Furthermore, the folder should not crash when attempting to recover the 91 // (unchanged) operation result. 92 OperationFolder folder(op->getContext()); 93 Value result = folder.create<TestOpInPlaceFold>( 94 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 95 rewriter.getI32IntegerAttr(0)); 96 assert(result); 97 rewriter.replaceOp(op, result); 98 return success(); 99 } 100 }; 101 102 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 103 StringRef getArgument() const final { return "test-patterns"; } 104 StringRef getDescription() const final { return "Run test dialect patterns"; } 105 void runOnFunction() override { 106 mlir::RewritePatternSet patterns(&getContext()); 107 populateWithGenerated(patterns); 108 109 // Verify named pattern is generated with expected name. 110 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 111 112 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 113 } 114 }; 115 } // end anonymous namespace 116 117 //===----------------------------------------------------------------------===// 118 // ReturnType Driver. 119 //===----------------------------------------------------------------------===// 120 121 namespace { 122 // Generate ops for each instance where the type can be successfully inferred. 123 template <typename OpTy> 124 static void invokeCreateWithInferredReturnType(Operation *op) { 125 auto *context = op->getContext(); 126 auto fop = op->getParentOfType<FuncOp>(); 127 auto location = UnknownLoc::get(context); 128 OpBuilder b(op); 129 b.setInsertionPointAfter(op); 130 131 // Use permutations of 2 args as operands. 132 assert(fop.getNumArguments() >= 2); 133 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 134 for (int j = 0; j < e; ++j) { 135 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 136 SmallVector<Type, 2> inferredReturnTypes; 137 if (succeeded(OpTy::inferReturnTypes( 138 context, llvm::None, values, op->getAttrDictionary(), 139 op->getRegions(), inferredReturnTypes))) { 140 OperationState state(location, OpTy::getOperationName()); 141 // TODO: Expand to regions. 142 OpTy::build(b, state, values, op->getAttrs()); 143 (void)b.createOperation(state); 144 } 145 } 146 } 147 } 148 149 static void reifyReturnShape(Operation *op) { 150 OpBuilder b(op); 151 152 // Use permutations of 2 args as operands. 153 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 154 SmallVector<Value, 2> shapes; 155 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 156 !llvm::hasSingleElement(shapes)) 157 return; 158 for (auto it : llvm::enumerate(shapes)) { 159 op->emitRemark() << "value " << it.index() << ": " 160 << it.value().getDefiningOp(); 161 } 162 } 163 164 struct TestReturnTypeDriver 165 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 166 void getDependentDialects(DialectRegistry ®istry) const override { 167 registry.insert<tensor::TensorDialect>(); 168 } 169 StringRef getArgument() const final { return "test-return-type"; } 170 StringRef getDescription() const final { return "Run return type functions"; } 171 172 void runOnFunction() override { 173 if (getFunction().getName() == "testCreateFunctions") { 174 std::vector<Operation *> ops; 175 // Collect ops to avoid triggering on inserted ops. 176 for (auto &op : getFunction().getBody().front()) 177 ops.push_back(&op); 178 // Generate test patterns for each, but skip terminator. 179 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 180 // Test create method of each of the Op classes below. The resultant 181 // output would be in reverse order underneath `op` from which 182 // the attributes and regions are used. 183 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 184 invokeCreateWithInferredReturnType< 185 OpWithShapedTypeInferTypeInterfaceOp>(op); 186 }; 187 return; 188 } 189 if (getFunction().getName() == "testReifyFunctions") { 190 std::vector<Operation *> ops; 191 // Collect ops to avoid triggering on inserted ops. 192 for (auto &op : getFunction().getBody().front()) 193 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 194 ops.push_back(&op); 195 // Generate test patterns for each, but skip terminator. 196 for (auto *op : ops) 197 reifyReturnShape(op); 198 } 199 } 200 }; 201 } // end anonymous namespace 202 203 namespace { 204 struct TestDerivedAttributeDriver 205 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 206 StringRef getArgument() const final { return "test-derived-attr"; } 207 StringRef getDescription() const final { 208 return "Run test derived attributes"; 209 } 210 void runOnFunction() override; 211 }; 212 } // end anonymous namespace 213 214 void TestDerivedAttributeDriver::runOnFunction() { 215 getFunction().walk([](DerivedAttributeOpInterface dOp) { 216 auto dAttr = dOp.materializeDerivedAttributes(); 217 if (!dAttr) 218 return; 219 for (auto d : dAttr) 220 dOp.emitRemark() << d.first << " = " << d.second; 221 }); 222 } 223 224 //===----------------------------------------------------------------------===// 225 // Legalization Driver. 226 //===----------------------------------------------------------------------===// 227 228 namespace { 229 //===----------------------------------------------------------------------===// 230 // Region-Block Rewrite Testing 231 232 /// This pattern is a simple pattern that inlines the first region of a given 233 /// operation into the parent region. 234 struct TestRegionRewriteBlockMovement : public ConversionPattern { 235 TestRegionRewriteBlockMovement(MLIRContext *ctx) 236 : ConversionPattern("test.region", 1, ctx) {} 237 238 LogicalResult 239 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 240 ConversionPatternRewriter &rewriter) const final { 241 // Inline this region into the parent region. 242 auto &parentRegion = *op->getParentRegion(); 243 auto &opRegion = op->getRegion(0); 244 if (op->getAttr("legalizer.should_clone")) 245 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 246 else 247 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 248 249 if (op->getAttr("legalizer.erase_old_blocks")) { 250 while (!opRegion.empty()) 251 rewriter.eraseBlock(&opRegion.front()); 252 } 253 254 // Drop this operation. 255 rewriter.eraseOp(op); 256 return success(); 257 } 258 }; 259 /// This pattern is a simple pattern that generates a region containing an 260 /// illegal operation. 261 struct TestRegionRewriteUndo : public RewritePattern { 262 TestRegionRewriteUndo(MLIRContext *ctx) 263 : RewritePattern("test.region_builder", 1, ctx) {} 264 265 LogicalResult matchAndRewrite(Operation *op, 266 PatternRewriter &rewriter) const final { 267 // Create the region operation with an entry block containing arguments. 268 OperationState newRegion(op->getLoc(), "test.region"); 269 newRegion.addRegion(); 270 auto *regionOp = rewriter.createOperation(newRegion); 271 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 272 entryBlock->addArgument(rewriter.getIntegerType(64)); 273 274 // Add an explicitly illegal operation to ensure the conversion fails. 275 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 276 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 277 278 // Drop this operation. 279 rewriter.eraseOp(op); 280 return success(); 281 } 282 }; 283 /// A simple pattern that creates a block at the end of the parent region of the 284 /// matched operation. 285 struct TestCreateBlock : public RewritePattern { 286 TestCreateBlock(MLIRContext *ctx) 287 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 288 289 LogicalResult matchAndRewrite(Operation *op, 290 PatternRewriter &rewriter) const final { 291 Region ®ion = *op->getParentRegion(); 292 Type i32Type = rewriter.getIntegerType(32); 293 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 294 rewriter.create<TerminatorOp>(op->getLoc()); 295 rewriter.replaceOp(op, {}); 296 return success(); 297 } 298 }; 299 300 /// A simple pattern that creates a block containing an invalid operation in 301 /// order to trigger the block creation undo mechanism. 302 struct TestCreateIllegalBlock : public RewritePattern { 303 TestCreateIllegalBlock(MLIRContext *ctx) 304 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 305 306 LogicalResult matchAndRewrite(Operation *op, 307 PatternRewriter &rewriter) const final { 308 Region ®ion = *op->getParentRegion(); 309 Type i32Type = rewriter.getIntegerType(32); 310 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 311 // Create an illegal op to ensure the conversion fails. 312 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 313 rewriter.create<TerminatorOp>(op->getLoc()); 314 rewriter.replaceOp(op, {}); 315 return success(); 316 } 317 }; 318 319 /// A simple pattern that tests the undo mechanism when replacing the uses of a 320 /// block argument. 321 struct TestUndoBlockArgReplace : public ConversionPattern { 322 TestUndoBlockArgReplace(MLIRContext *ctx) 323 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 324 325 LogicalResult 326 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 327 ConversionPatternRewriter &rewriter) const final { 328 auto illegalOp = 329 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 330 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 331 illegalOp); 332 rewriter.updateRootInPlace(op, [] {}); 333 return success(); 334 } 335 }; 336 337 /// A rewrite pattern that tests the undo mechanism when erasing a block. 338 struct TestUndoBlockErase : public ConversionPattern { 339 TestUndoBlockErase(MLIRContext *ctx) 340 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 341 342 LogicalResult 343 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 344 ConversionPatternRewriter &rewriter) const final { 345 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 346 rewriter.setInsertionPointToStart(secondBlock); 347 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 348 rewriter.eraseBlock(secondBlock); 349 rewriter.updateRootInPlace(op, [] {}); 350 return success(); 351 } 352 }; 353 354 //===----------------------------------------------------------------------===// 355 // Type-Conversion Rewrite Testing 356 357 /// This patterns erases a region operation that has had a type conversion. 358 struct TestDropOpSignatureConversion : public ConversionPattern { 359 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 360 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 361 LogicalResult 362 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 363 ConversionPatternRewriter &rewriter) const override { 364 Region ®ion = op->getRegion(0); 365 Block *entry = ®ion.front(); 366 367 // Convert the original entry arguments. 368 TypeConverter &converter = *getTypeConverter(); 369 TypeConverter::SignatureConversion result(entry->getNumArguments()); 370 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 371 result)) || 372 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 373 return failure(); 374 375 // Convert the region signature and just drop the operation. 376 rewriter.eraseOp(op); 377 return success(); 378 } 379 }; 380 /// This pattern simply updates the operands of the given operation. 381 struct TestPassthroughInvalidOp : public ConversionPattern { 382 TestPassthroughInvalidOp(MLIRContext *ctx) 383 : ConversionPattern("test.invalid", 1, ctx) {} 384 LogicalResult 385 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 386 ConversionPatternRewriter &rewriter) const final { 387 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 388 llvm::None); 389 return success(); 390 } 391 }; 392 /// This pattern handles the case of a split return value. 393 struct TestSplitReturnType : public ConversionPattern { 394 TestSplitReturnType(MLIRContext *ctx) 395 : ConversionPattern("test.return", 1, ctx) {} 396 LogicalResult 397 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 398 ConversionPatternRewriter &rewriter) const final { 399 // Check for a return of F32. 400 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 401 return failure(); 402 403 // Check if the first operation is a cast operation, if it is we use the 404 // results directly. 405 auto *defOp = operands[0].getDefiningOp(); 406 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 407 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 408 return success(); 409 } 410 411 // Otherwise, fail to match. 412 return failure(); 413 } 414 }; 415 416 //===----------------------------------------------------------------------===// 417 // Multi-Level Type-Conversion Rewrite Testing 418 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 419 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 420 : ConversionPattern("test.type_producer", 1, ctx) {} 421 LogicalResult 422 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423 ConversionPatternRewriter &rewriter) const final { 424 // If the type is I32, change the type to F32. 425 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 426 return failure(); 427 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 428 return success(); 429 } 430 }; 431 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 432 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 433 : ConversionPattern("test.type_producer", 1, ctx) {} 434 LogicalResult 435 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 436 ConversionPatternRewriter &rewriter) const final { 437 // If the type is F32, change the type to F64. 438 if (!Type(*op->result_type_begin()).isF32()) 439 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 440 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 441 return success(); 442 } 443 }; 444 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 445 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 446 : ConversionPattern("test.type_producer", 10, ctx) {} 447 LogicalResult 448 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 449 ConversionPatternRewriter &rewriter) const final { 450 // Always convert to B16, even though it is not a legal type. This tests 451 // that values are unmapped correctly. 452 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 453 return success(); 454 } 455 }; 456 struct TestUpdateConsumerType : public ConversionPattern { 457 TestUpdateConsumerType(MLIRContext *ctx) 458 : ConversionPattern("test.type_consumer", 1, ctx) {} 459 LogicalResult 460 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 461 ConversionPatternRewriter &rewriter) const final { 462 // Verify that the incoming operand has been successfully remapped to F64. 463 if (!operands[0].getType().isF64()) 464 return failure(); 465 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 466 return success(); 467 } 468 }; 469 470 //===----------------------------------------------------------------------===// 471 // Non-Root Replacement Rewrite Testing 472 /// This pattern generates an invalid operation, but replaces it before the 473 /// pattern is finished. This checks that we don't need to legalize the 474 /// temporary op. 475 struct TestNonRootReplacement : public RewritePattern { 476 TestNonRootReplacement(MLIRContext *ctx) 477 : RewritePattern("test.replace_non_root", 1, ctx) {} 478 479 LogicalResult matchAndRewrite(Operation *op, 480 PatternRewriter &rewriter) const final { 481 auto resultType = *op->result_type_begin(); 482 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 483 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 484 485 rewriter.replaceOp(illegalOp, {legalOp}); 486 rewriter.replaceOp(op, {illegalOp}); 487 return success(); 488 } 489 }; 490 491 //===----------------------------------------------------------------------===// 492 // Recursive Rewrite Testing 493 /// This pattern is applied to the same operation multiple times, but has a 494 /// bounded recursion. 495 struct TestBoundedRecursiveRewrite 496 : public OpRewritePattern<TestRecursiveRewriteOp> { 497 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 498 499 void initialize() { 500 // The conversion target handles bounding the recursion of this pattern. 501 setHasBoundedRewriteRecursion(); 502 } 503 504 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 505 PatternRewriter &rewriter) const final { 506 // Decrement the depth of the op in-place. 507 rewriter.updateRootInPlace(op, [&] { 508 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 509 }); 510 return success(); 511 } 512 }; 513 514 struct TestNestedOpCreationUndoRewrite 515 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 516 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 517 518 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 519 PatternRewriter &rewriter) const final { 520 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 521 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 522 return success(); 523 }; 524 }; 525 526 // This pattern matches `test.blackhole` and delete this op and its producer. 527 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 528 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 529 530 LogicalResult matchAndRewrite(BlackHoleOp op, 531 PatternRewriter &rewriter) const final { 532 Operation *producer = op.getOperand().getDefiningOp(); 533 // Always erase the user before the producer, the framework should handle 534 // this correctly. 535 rewriter.eraseOp(op); 536 rewriter.eraseOp(producer); 537 return success(); 538 }; 539 }; 540 } // namespace 541 542 namespace { 543 struct TestTypeConverter : public TypeConverter { 544 using TypeConverter::TypeConverter; 545 TestTypeConverter() { 546 addConversion(convertType); 547 addArgumentMaterialization(materializeCast); 548 addSourceMaterialization(materializeCast); 549 550 /// Materialize the cast for one-to-one conversion from i64 to f64. 551 const auto materializeOneToOneCast = 552 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 553 Location loc) -> Optional<Value> { 554 if (resultType.getWidth() == 42 && inputs.size() == 1) 555 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 556 return llvm::None; 557 }; 558 addArgumentMaterialization(materializeOneToOneCast); 559 } 560 561 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 562 // Drop I16 types. 563 if (t.isSignlessInteger(16)) 564 return success(); 565 566 // Convert I64 to F64. 567 if (t.isSignlessInteger(64)) { 568 results.push_back(FloatType::getF64(t.getContext())); 569 return success(); 570 } 571 572 // Convert I42 to I43. 573 if (t.isInteger(42)) { 574 results.push_back(IntegerType::get(t.getContext(), 43)); 575 return success(); 576 } 577 578 // Split F32 into F16,F16. 579 if (t.isF32()) { 580 results.assign(2, FloatType::getF16(t.getContext())); 581 return success(); 582 } 583 584 // Otherwise, convert the type directly. 585 results.push_back(t); 586 return success(); 587 } 588 589 /// Hook for materializing a conversion. This is necessary because we generate 590 /// 1->N type mappings. 591 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 592 ValueRange inputs, Location loc) { 593 if (inputs.size() == 1) 594 return inputs[0]; 595 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 596 } 597 }; 598 599 struct TestLegalizePatternDriver 600 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 601 StringRef getArgument() const final { return "test-legalize-patterns"; } 602 StringRef getDescription() const final { 603 return "Run test dialect legalization patterns"; 604 } 605 /// The mode of conversion to use with the driver. 606 enum class ConversionMode { Analysis, Full, Partial }; 607 608 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 609 610 void runOnOperation() override { 611 TestTypeConverter converter; 612 mlir::RewritePatternSet patterns(&getContext()); 613 populateWithGenerated(patterns); 614 patterns 615 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 616 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 617 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 618 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 619 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 620 TestNonRootReplacement, TestBoundedRecursiveRewrite, 621 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 622 &getContext()); 623 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 624 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 625 mlir::populateCallOpTypeConversionPattern(patterns, converter); 626 627 // Define the conversion target used for the test. 628 ConversionTarget target(getContext()); 629 target.addLegalOp<ModuleOp>(); 630 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 631 TerminatorOp>(); 632 target 633 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 634 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 635 // Don't allow F32 operands. 636 return llvm::none_of(op.getOperandTypes(), 637 [](Type type) { return type.isF32(); }); 638 }); 639 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 640 return converter.isSignatureLegal(op.getType()) && 641 converter.isLegal(&op.getBody()); 642 }); 643 644 // Expect the type_producer/type_consumer operations to only operate on f64. 645 target.addDynamicallyLegalOp<TestTypeProducerOp>( 646 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 647 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 648 return op.getOperand().getType().isF64(); 649 }); 650 651 // Check support for marking certain operations as recursively legal. 652 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 653 return static_cast<bool>( 654 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 655 }); 656 657 // Mark the bound recursion operation as dynamically legal. 658 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 659 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 660 661 // Handle a partial conversion. 662 if (mode == ConversionMode::Partial) { 663 DenseSet<Operation *> unlegalizedOps; 664 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 665 &unlegalizedOps); 666 // Emit remarks for each legalizable operation. 667 for (auto *op : unlegalizedOps) 668 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 669 return; 670 } 671 672 // Handle a full conversion. 673 if (mode == ConversionMode::Full) { 674 // Check support for marking unknown operations as dynamically legal. 675 target.markUnknownOpDynamicallyLegal([](Operation *op) { 676 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 677 }); 678 679 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 680 return; 681 } 682 683 // Otherwise, handle an analysis conversion. 684 assert(mode == ConversionMode::Analysis); 685 686 // Analyze the convertible operations. 687 DenseSet<Operation *> legalizedOps; 688 if (failed(applyAnalysisConversion(getOperation(), target, 689 std::move(patterns), legalizedOps))) 690 return signalPassFailure(); 691 692 // Emit remarks for each legalizable operation. 693 for (auto *op : legalizedOps) 694 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 695 } 696 697 /// The mode of conversion to use. 698 ConversionMode mode; 699 }; 700 } // end anonymous namespace 701 702 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 703 legalizerConversionMode( 704 "test-legalize-mode", 705 llvm::cl::desc("The legalization mode to use with the test driver"), 706 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 707 llvm::cl::values( 708 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 709 "analysis", "Perform an analysis conversion"), 710 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 711 "Perform a full conversion"), 712 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 713 "partial", "Perform a partial conversion"))); 714 715 //===----------------------------------------------------------------------===// 716 // ConversionPatternRewriter::getRemappedValue testing. This method is used 717 // to get the remapped value of an original value that was replaced using 718 // ConversionPatternRewriter. 719 namespace { 720 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 721 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 722 /// operand twice. 723 /// 724 /// Example: 725 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 726 /// is replaced with: 727 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 728 struct OneVResOneVOperandOp1Converter 729 : public OpConversionPattern<OneVResOneVOperandOp1> { 730 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 731 732 LogicalResult 733 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 734 ConversionPatternRewriter &rewriter) const override { 735 auto origOps = op.getOperands(); 736 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 737 "One operand expected"); 738 Value origOp = *origOps.begin(); 739 SmallVector<Value, 2> remappedOperands; 740 // Replicate the remapped original operand twice. Note that we don't used 741 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 742 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 743 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 744 745 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 746 remappedOperands); 747 return success(); 748 } 749 }; 750 751 struct TestRemappedValue 752 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 753 StringRef getArgument() const final { return "test-remapped-value"; } 754 StringRef getDescription() const final { 755 return "Test public remapped value mechanism in ConversionPatternRewriter"; 756 } 757 void runOnFunction() override { 758 mlir::RewritePatternSet patterns(&getContext()); 759 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 760 761 mlir::ConversionTarget target(getContext()); 762 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 763 // We make OneVResOneVOperandOp1 legal only when it has more that one 764 // operand. This will trigger the conversion that will replace one-operand 765 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 766 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 767 [](Operation *op) -> bool { 768 return std::distance(op->operand_begin(), op->operand_end()) > 1; 769 }); 770 771 if (failed(mlir::applyFullConversion(getFunction(), target, 772 std::move(patterns)))) { 773 signalPassFailure(); 774 } 775 } 776 }; 777 } // end anonymous namespace 778 779 //===----------------------------------------------------------------------===// 780 // Test patterns without a specific root operation kind 781 //===----------------------------------------------------------------------===// 782 783 namespace { 784 /// This pattern matches and removes any operation in the test dialect. 785 struct RemoveTestDialectOps : public RewritePattern { 786 RemoveTestDialectOps(MLIRContext *context) 787 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 788 789 LogicalResult matchAndRewrite(Operation *op, 790 PatternRewriter &rewriter) const override { 791 if (!isa<TestDialect>(op->getDialect())) 792 return failure(); 793 rewriter.eraseOp(op); 794 return success(); 795 } 796 }; 797 798 struct TestUnknownRootOpDriver 799 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 800 StringRef getArgument() const final { 801 return "test-legalize-unknown-root-patterns"; 802 } 803 StringRef getDescription() const final { 804 return "Test public remapped value mechanism in ConversionPatternRewriter"; 805 } 806 void runOnFunction() override { 807 mlir::RewritePatternSet patterns(&getContext()); 808 patterns.add<RemoveTestDialectOps>(&getContext()); 809 810 mlir::ConversionTarget target(getContext()); 811 target.addIllegalDialect<TestDialect>(); 812 if (failed( 813 applyPartialConversion(getFunction(), target, std::move(patterns)))) 814 signalPassFailure(); 815 } 816 }; 817 } // end anonymous namespace 818 819 //===----------------------------------------------------------------------===// 820 // Test type conversions 821 //===----------------------------------------------------------------------===// 822 823 namespace { 824 struct TestTypeConversionProducer 825 : public OpConversionPattern<TestTypeProducerOp> { 826 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 827 LogicalResult 828 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const final { 830 Type resultType = op.getType(); 831 if (resultType.isa<FloatType>()) 832 resultType = rewriter.getF64Type(); 833 else if (resultType.isInteger(16)) 834 resultType = rewriter.getIntegerType(64); 835 else 836 return failure(); 837 838 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 839 return success(); 840 } 841 }; 842 843 /// Call signature conversion and then fail the rewrite to trigger the undo 844 /// mechanism. 845 struct TestSignatureConversionUndo 846 : public OpConversionPattern<TestSignatureConversionUndoOp> { 847 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 848 849 LogicalResult 850 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 851 ConversionPatternRewriter &rewriter) const final { 852 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 853 return failure(); 854 } 855 }; 856 857 /// Just forward the operands to the root op. This is essentially a no-op 858 /// pattern that is used to trigger target materialization. 859 struct TestTypeConsumerForward 860 : public OpConversionPattern<TestTypeConsumerOp> { 861 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 862 863 LogicalResult 864 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 865 ConversionPatternRewriter &rewriter) const final { 866 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 867 return success(); 868 } 869 }; 870 871 struct TestTypeConversionAnotherProducer 872 : public OpRewritePattern<TestAnotherTypeProducerOp> { 873 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 874 875 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 876 PatternRewriter &rewriter) const final { 877 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 878 return success(); 879 } 880 }; 881 882 struct TestTypeConversionDriver 883 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 884 void getDependentDialects(DialectRegistry ®istry) const override { 885 registry.insert<TestDialect>(); 886 } 887 StringRef getArgument() const final { 888 return "test-legalize-type-conversion"; 889 } 890 StringRef getDescription() const final { 891 return "Test various type conversion functionalities in DialectConversion"; 892 } 893 894 void runOnOperation() override { 895 // Initialize the type converter. 896 TypeConverter converter; 897 898 /// Add the legal set of type conversions. 899 converter.addConversion([](Type type) -> Type { 900 // Treat F64 as legal. 901 if (type.isF64()) 902 return type; 903 // Allow converting BF16/F16/F32 to F64. 904 if (type.isBF16() || type.isF16() || type.isF32()) 905 return FloatType::getF64(type.getContext()); 906 // Otherwise, the type is illegal. 907 return nullptr; 908 }); 909 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 910 // Drop all integer types. 911 return success(); 912 }); 913 914 /// Add the legal set of type materializations. 915 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 916 ValueRange inputs, 917 Location loc) -> Value { 918 // Allow casting from F64 back to F32. 919 if (!resultType.isF16() && inputs.size() == 1 && 920 inputs[0].getType().isF64()) 921 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 922 // Allow producing an i32 or i64 from nothing. 923 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 924 inputs.empty()) 925 return builder.create<TestTypeProducerOp>(loc, resultType); 926 // Allow producing an i64 from an integer. 927 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 928 inputs[0].getType().isa<IntegerType>()) 929 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 930 // Otherwise, fail. 931 return nullptr; 932 }); 933 934 // Initialize the conversion target. 935 mlir::ConversionTarget target(getContext()); 936 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 937 return op.getType().isF64() || op.getType().isInteger(64); 938 }); 939 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 940 return converter.isSignatureLegal(op.getType()) && 941 converter.isLegal(&op.getBody()); 942 }); 943 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 944 // Allow casts from F64 to F32. 945 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 946 }); 947 948 // Initialize the set of rewrite patterns. 949 RewritePatternSet patterns(&getContext()); 950 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 951 TestSignatureConversionUndo>(converter, &getContext()); 952 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 953 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 954 955 if (failed(applyPartialConversion(getOperation(), target, 956 std::move(patterns)))) 957 signalPassFailure(); 958 } 959 }; 960 } // end anonymous namespace 961 962 //===----------------------------------------------------------------------===// 963 // Test Block Merging 964 //===----------------------------------------------------------------------===// 965 966 namespace { 967 /// A rewriter pattern that tests that blocks can be merged. 968 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 969 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 970 971 LogicalResult 972 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 973 ConversionPatternRewriter &rewriter) const final { 974 Block &firstBlock = op.body().front(); 975 Operation *branchOp = firstBlock.getTerminator(); 976 Block *secondBlock = &*(std::next(op.body().begin())); 977 auto succOperands = branchOp->getOperands(); 978 SmallVector<Value, 2> replacements(succOperands); 979 rewriter.eraseOp(branchOp); 980 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 981 rewriter.updateRootInPlace(op, [] {}); 982 return success(); 983 } 984 }; 985 986 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 987 struct TestUndoBlocksMerge : public ConversionPattern { 988 TestUndoBlocksMerge(MLIRContext *ctx) 989 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 990 LogicalResult 991 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 992 ConversionPatternRewriter &rewriter) const final { 993 Block &firstBlock = op->getRegion(0).front(); 994 Operation *branchOp = firstBlock.getTerminator(); 995 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 996 rewriter.setInsertionPointToStart(secondBlock); 997 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 998 auto succOperands = branchOp->getOperands(); 999 SmallVector<Value, 2> replacements(succOperands); 1000 rewriter.eraseOp(branchOp); 1001 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1002 rewriter.updateRootInPlace(op, [] {}); 1003 return success(); 1004 } 1005 }; 1006 1007 /// A rewrite mechanism to inline the body of the op into its parent, when both 1008 /// ops can have a single block. 1009 struct TestMergeSingleBlockOps 1010 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1011 using OpConversionPattern< 1012 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1013 1014 LogicalResult 1015 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 1016 ConversionPatternRewriter &rewriter) const final { 1017 SingleBlockImplicitTerminatorOp parentOp = 1018 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1019 if (!parentOp) 1020 return failure(); 1021 Block &innerBlock = op.region().front(); 1022 TerminatorOp innerTerminator = 1023 cast<TerminatorOp>(innerBlock.getTerminator()); 1024 rewriter.mergeBlockBefore(&innerBlock, op); 1025 rewriter.eraseOp(innerTerminator); 1026 rewriter.eraseOp(op); 1027 rewriter.updateRootInPlace(op, [] {}); 1028 return success(); 1029 } 1030 }; 1031 1032 struct TestMergeBlocksPatternDriver 1033 : public PassWrapper<TestMergeBlocksPatternDriver, 1034 OperationPass<ModuleOp>> { 1035 StringRef getArgument() const final { return "test-merge-blocks"; } 1036 StringRef getDescription() const final { 1037 return "Test Merging operation in ConversionPatternRewriter"; 1038 } 1039 void runOnOperation() override { 1040 MLIRContext *context = &getContext(); 1041 mlir::RewritePatternSet patterns(context); 1042 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1043 context); 1044 ConversionTarget target(*context); 1045 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1046 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1047 target.addIllegalOp<ILLegalOpF>(); 1048 1049 /// Expect the op to have a single block after legalization. 1050 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1051 [&](TestMergeBlocksOp op) -> bool { 1052 return llvm::hasSingleElement(op.body()); 1053 }); 1054 1055 /// Only allow `test.br` within test.merge_blocks op. 1056 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1057 return op->getParentOfType<TestMergeBlocksOp>(); 1058 }); 1059 1060 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1061 /// inlined. 1062 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1063 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1064 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1065 }); 1066 1067 DenseSet<Operation *> unlegalizedOps; 1068 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1069 &unlegalizedOps); 1070 for (auto *op : unlegalizedOps) 1071 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1072 } 1073 }; 1074 } // namespace 1075 1076 //===----------------------------------------------------------------------===// 1077 // Test Selective Replacement 1078 //===----------------------------------------------------------------------===// 1079 1080 namespace { 1081 /// A rewrite mechanism to inline the body of the op into its parent, when both 1082 /// ops can have a single block. 1083 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1084 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1085 1086 LogicalResult matchAndRewrite(TestCastOp op, 1087 PatternRewriter &rewriter) const final { 1088 if (op.getNumOperands() != 2) 1089 return failure(); 1090 OperandRange operands = op.getOperands(); 1091 1092 // Replace non-terminator uses with the first operand. 1093 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1094 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1095 }); 1096 // Replace everything else with the second operand if the operation isn't 1097 // dead. 1098 rewriter.replaceOp(op, op.getOperand(1)); 1099 return success(); 1100 } 1101 }; 1102 1103 struct TestSelectiveReplacementPatternDriver 1104 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1105 OperationPass<>> { 1106 StringRef getArgument() const final { 1107 return "test-pattern-selective-replacement"; 1108 } 1109 StringRef getDescription() const final { 1110 return "Test selective replacement in the PatternRewriter"; 1111 } 1112 void runOnOperation() override { 1113 MLIRContext *context = &getContext(); 1114 mlir::RewritePatternSet patterns(context); 1115 patterns.add<TestSelectiveOpReplacementPattern>(context); 1116 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1117 std::move(patterns)); 1118 } 1119 }; 1120 } // namespace 1121 1122 //===----------------------------------------------------------------------===// 1123 // PassRegistration 1124 //===----------------------------------------------------------------------===// 1125 1126 namespace mlir { 1127 namespace test { 1128 void registerPatternsTestPass() { 1129 PassRegistration<TestReturnTypeDriver>(); 1130 1131 PassRegistration<TestDerivedAttributeDriver>(); 1132 1133 PassRegistration<TestPatternDriver>(); 1134 1135 PassRegistration<TestLegalizePatternDriver>([] { 1136 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1137 }); 1138 1139 PassRegistration<TestRemappedValue>(); 1140 1141 PassRegistration<TestUnknownRootOpDriver>(); 1142 1143 PassRegistration<TestTypeConversionDriver>(); 1144 1145 PassRegistration<TestMergeBlocksPatternDriver>(); 1146 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1147 } 1148 } // namespace test 1149 } // namespace mlir 1150