1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 48 Value input2) { 49 return SmallVector<Value, 2>({input2, input1}); 50 } 51 52 // Test that natives calls are only called once during rewrites. 53 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 54 // This let us check the number of times OpM_Test was called by inspecting 55 // the returned value in the MLIR output. 56 static int64_t opMIncreasingValue = 314159265; 57 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 58 int64_t i = opMIncreasingValue++; 59 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 60 } 61 62 namespace { 63 #include "TestPatterns.inc" 64 } // end anonymous namespace 65 66 //===----------------------------------------------------------------------===// 67 // Test Reduce Pattern Interface 68 //===----------------------------------------------------------------------===// 69 70 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 71 populateWithGenerated(patterns); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Canonicalizer Driver. 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 struct FoldingPattern : public RewritePattern { 80 public: 81 FoldingPattern(MLIRContext *context) 82 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 83 /*benefit=*/1, context) {} 84 85 LogicalResult matchAndRewrite(Operation *op, 86 PatternRewriter &rewriter) const override { 87 // Exercise OperationFolder API for a single-result operation that is folded 88 // upon construction. The operation being created through the folder has an 89 // in-place folder, and it should be still present in the output. 90 // Furthermore, the folder should not crash when attempting to recover the 91 // (unchanged) operation result. 92 OperationFolder folder(op->getContext()); 93 Value result = folder.create<TestOpInPlaceFold>( 94 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 95 rewriter.getI32IntegerAttr(0)); 96 assert(result); 97 rewriter.replaceOp(op, result); 98 return success(); 99 } 100 }; 101 102 /// This pattern creates a foldable operation at the entry point of the block. 103 /// This tests the situation where the operation folder will need to replace an 104 /// operation with a previously created constant that does not initially 105 /// dominate the operation to replace. 106 struct FolderInsertBeforePreviouslyFoldedConstantPattern 107 : public OpRewritePattern<TestCastOp> { 108 public: 109 using OpRewritePattern<TestCastOp>::OpRewritePattern; 110 111 LogicalResult matchAndRewrite(TestCastOp op, 112 PatternRewriter &rewriter) const override { 113 if (!op->hasAttr("test_fold_before_previously_folded_op")) 114 return failure(); 115 rewriter.setInsertionPointToStart(op->getBlock()); 116 117 auto constOp = 118 rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true)); 119 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 120 Value(constOp)); 121 return success(); 122 } 123 }; 124 125 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 126 StringRef getArgument() const final { return "test-patterns"; } 127 StringRef getDescription() const final { return "Run test dialect patterns"; } 128 void runOnFunction() override { 129 mlir::RewritePatternSet patterns(&getContext()); 130 populateWithGenerated(patterns); 131 132 // Verify named pattern is generated with expected name. 133 patterns.add<FoldingPattern, TestNamedPatternRule, 134 FolderInsertBeforePreviouslyFoldedConstantPattern>( 135 &getContext()); 136 137 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 138 } 139 }; 140 } // end anonymous namespace 141 142 //===----------------------------------------------------------------------===// 143 // ReturnType Driver. 144 //===----------------------------------------------------------------------===// 145 146 namespace { 147 // Generate ops for each instance where the type can be successfully inferred. 148 template <typename OpTy> 149 static void invokeCreateWithInferredReturnType(Operation *op) { 150 auto *context = op->getContext(); 151 auto fop = op->getParentOfType<FuncOp>(); 152 auto location = UnknownLoc::get(context); 153 OpBuilder b(op); 154 b.setInsertionPointAfter(op); 155 156 // Use permutations of 2 args as operands. 157 assert(fop.getNumArguments() >= 2); 158 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 159 for (int j = 0; j < e; ++j) { 160 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 161 SmallVector<Type, 2> inferredReturnTypes; 162 if (succeeded(OpTy::inferReturnTypes( 163 context, llvm::None, values, op->getAttrDictionary(), 164 op->getRegions(), inferredReturnTypes))) { 165 OperationState state(location, OpTy::getOperationName()); 166 // TODO: Expand to regions. 167 OpTy::build(b, state, values, op->getAttrs()); 168 (void)b.createOperation(state); 169 } 170 } 171 } 172 } 173 174 static void reifyReturnShape(Operation *op) { 175 OpBuilder b(op); 176 177 // Use permutations of 2 args as operands. 178 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 179 SmallVector<Value, 2> shapes; 180 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 181 !llvm::hasSingleElement(shapes)) 182 return; 183 for (auto it : llvm::enumerate(shapes)) { 184 op->emitRemark() << "value " << it.index() << ": " 185 << it.value().getDefiningOp(); 186 } 187 } 188 189 struct TestReturnTypeDriver 190 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 191 void getDependentDialects(DialectRegistry ®istry) const override { 192 registry.insert<tensor::TensorDialect>(); 193 } 194 StringRef getArgument() const final { return "test-return-type"; } 195 StringRef getDescription() const final { return "Run return type functions"; } 196 197 void runOnFunction() override { 198 if (getFunction().getName() == "testCreateFunctions") { 199 std::vector<Operation *> ops; 200 // Collect ops to avoid triggering on inserted ops. 201 for (auto &op : getFunction().getBody().front()) 202 ops.push_back(&op); 203 // Generate test patterns for each, but skip terminator. 204 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 205 // Test create method of each of the Op classes below. The resultant 206 // output would be in reverse order underneath `op` from which 207 // the attributes and regions are used. 208 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 209 invokeCreateWithInferredReturnType< 210 OpWithShapedTypeInferTypeInterfaceOp>(op); 211 }; 212 return; 213 } 214 if (getFunction().getName() == "testReifyFunctions") { 215 std::vector<Operation *> ops; 216 // Collect ops to avoid triggering on inserted ops. 217 for (auto &op : getFunction().getBody().front()) 218 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 219 ops.push_back(&op); 220 // Generate test patterns for each, but skip terminator. 221 for (auto *op : ops) 222 reifyReturnShape(op); 223 } 224 } 225 }; 226 } // end anonymous namespace 227 228 namespace { 229 struct TestDerivedAttributeDriver 230 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 231 StringRef getArgument() const final { return "test-derived-attr"; } 232 StringRef getDescription() const final { 233 return "Run test derived attributes"; 234 } 235 void runOnFunction() override; 236 }; 237 } // end anonymous namespace 238 239 void TestDerivedAttributeDriver::runOnFunction() { 240 getFunction().walk([](DerivedAttributeOpInterface dOp) { 241 auto dAttr = dOp.materializeDerivedAttributes(); 242 if (!dAttr) 243 return; 244 for (auto d : dAttr) 245 dOp.emitRemark() << d.first << " = " << d.second; 246 }); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // Legalization Driver. 251 //===----------------------------------------------------------------------===// 252 253 namespace { 254 //===----------------------------------------------------------------------===// 255 // Region-Block Rewrite Testing 256 257 /// This pattern is a simple pattern that inlines the first region of a given 258 /// operation into the parent region. 259 struct TestRegionRewriteBlockMovement : public ConversionPattern { 260 TestRegionRewriteBlockMovement(MLIRContext *ctx) 261 : ConversionPattern("test.region", 1, ctx) {} 262 263 LogicalResult 264 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 265 ConversionPatternRewriter &rewriter) const final { 266 // Inline this region into the parent region. 267 auto &parentRegion = *op->getParentRegion(); 268 auto &opRegion = op->getRegion(0); 269 if (op->getAttr("legalizer.should_clone")) 270 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 271 else 272 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 273 274 if (op->getAttr("legalizer.erase_old_blocks")) { 275 while (!opRegion.empty()) 276 rewriter.eraseBlock(&opRegion.front()); 277 } 278 279 // Drop this operation. 280 rewriter.eraseOp(op); 281 return success(); 282 } 283 }; 284 /// This pattern is a simple pattern that generates a region containing an 285 /// illegal operation. 286 struct TestRegionRewriteUndo : public RewritePattern { 287 TestRegionRewriteUndo(MLIRContext *ctx) 288 : RewritePattern("test.region_builder", 1, ctx) {} 289 290 LogicalResult matchAndRewrite(Operation *op, 291 PatternRewriter &rewriter) const final { 292 // Create the region operation with an entry block containing arguments. 293 OperationState newRegion(op->getLoc(), "test.region"); 294 newRegion.addRegion(); 295 auto *regionOp = rewriter.createOperation(newRegion); 296 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 297 entryBlock->addArgument(rewriter.getIntegerType(64)); 298 299 // Add an explicitly illegal operation to ensure the conversion fails. 300 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 301 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 302 303 // Drop this operation. 304 rewriter.eraseOp(op); 305 return success(); 306 } 307 }; 308 /// A simple pattern that creates a block at the end of the parent region of the 309 /// matched operation. 310 struct TestCreateBlock : public RewritePattern { 311 TestCreateBlock(MLIRContext *ctx) 312 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 313 314 LogicalResult matchAndRewrite(Operation *op, 315 PatternRewriter &rewriter) const final { 316 Region ®ion = *op->getParentRegion(); 317 Type i32Type = rewriter.getIntegerType(32); 318 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 319 rewriter.create<TerminatorOp>(op->getLoc()); 320 rewriter.replaceOp(op, {}); 321 return success(); 322 } 323 }; 324 325 /// A simple pattern that creates a block containing an invalid operation in 326 /// order to trigger the block creation undo mechanism. 327 struct TestCreateIllegalBlock : public RewritePattern { 328 TestCreateIllegalBlock(MLIRContext *ctx) 329 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 330 331 LogicalResult matchAndRewrite(Operation *op, 332 PatternRewriter &rewriter) const final { 333 Region ®ion = *op->getParentRegion(); 334 Type i32Type = rewriter.getIntegerType(32); 335 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 336 // Create an illegal op to ensure the conversion fails. 337 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 338 rewriter.create<TerminatorOp>(op->getLoc()); 339 rewriter.replaceOp(op, {}); 340 return success(); 341 } 342 }; 343 344 /// A simple pattern that tests the undo mechanism when replacing the uses of a 345 /// block argument. 346 struct TestUndoBlockArgReplace : public ConversionPattern { 347 TestUndoBlockArgReplace(MLIRContext *ctx) 348 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 349 350 LogicalResult 351 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 352 ConversionPatternRewriter &rewriter) const final { 353 auto illegalOp = 354 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 355 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 356 illegalOp); 357 rewriter.updateRootInPlace(op, [] {}); 358 return success(); 359 } 360 }; 361 362 /// A rewrite pattern that tests the undo mechanism when erasing a block. 363 struct TestUndoBlockErase : public ConversionPattern { 364 TestUndoBlockErase(MLIRContext *ctx) 365 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 366 367 LogicalResult 368 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 369 ConversionPatternRewriter &rewriter) const final { 370 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 371 rewriter.setInsertionPointToStart(secondBlock); 372 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 373 rewriter.eraseBlock(secondBlock); 374 rewriter.updateRootInPlace(op, [] {}); 375 return success(); 376 } 377 }; 378 379 //===----------------------------------------------------------------------===// 380 // Type-Conversion Rewrite Testing 381 382 /// This patterns erases a region operation that has had a type conversion. 383 struct TestDropOpSignatureConversion : public ConversionPattern { 384 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 385 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 386 LogicalResult 387 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 388 ConversionPatternRewriter &rewriter) const override { 389 Region ®ion = op->getRegion(0); 390 Block *entry = ®ion.front(); 391 392 // Convert the original entry arguments. 393 TypeConverter &converter = *getTypeConverter(); 394 TypeConverter::SignatureConversion result(entry->getNumArguments()); 395 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 396 result)) || 397 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 398 return failure(); 399 400 // Convert the region signature and just drop the operation. 401 rewriter.eraseOp(op); 402 return success(); 403 } 404 }; 405 /// This pattern simply updates the operands of the given operation. 406 struct TestPassthroughInvalidOp : public ConversionPattern { 407 TestPassthroughInvalidOp(MLIRContext *ctx) 408 : ConversionPattern("test.invalid", 1, ctx) {} 409 LogicalResult 410 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 411 ConversionPatternRewriter &rewriter) const final { 412 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 413 llvm::None); 414 return success(); 415 } 416 }; 417 /// This pattern handles the case of a split return value. 418 struct TestSplitReturnType : public ConversionPattern { 419 TestSplitReturnType(MLIRContext *ctx) 420 : ConversionPattern("test.return", 1, ctx) {} 421 LogicalResult 422 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423 ConversionPatternRewriter &rewriter) const final { 424 // Check for a return of F32. 425 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 426 return failure(); 427 428 // Check if the first operation is a cast operation, if it is we use the 429 // results directly. 430 auto *defOp = operands[0].getDefiningOp(); 431 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 432 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 433 return success(); 434 } 435 436 // Otherwise, fail to match. 437 return failure(); 438 } 439 }; 440 441 //===----------------------------------------------------------------------===// 442 // Multi-Level Type-Conversion Rewrite Testing 443 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 444 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 445 : ConversionPattern("test.type_producer", 1, ctx) {} 446 LogicalResult 447 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 448 ConversionPatternRewriter &rewriter) const final { 449 // If the type is I32, change the type to F32. 450 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 451 return failure(); 452 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 453 return success(); 454 } 455 }; 456 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 457 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 458 : ConversionPattern("test.type_producer", 1, ctx) {} 459 LogicalResult 460 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 461 ConversionPatternRewriter &rewriter) const final { 462 // If the type is F32, change the type to F64. 463 if (!Type(*op->result_type_begin()).isF32()) 464 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 465 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 466 return success(); 467 } 468 }; 469 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 470 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 471 : ConversionPattern("test.type_producer", 10, ctx) {} 472 LogicalResult 473 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 474 ConversionPatternRewriter &rewriter) const final { 475 // Always convert to B16, even though it is not a legal type. This tests 476 // that values are unmapped correctly. 477 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 478 return success(); 479 } 480 }; 481 struct TestUpdateConsumerType : public ConversionPattern { 482 TestUpdateConsumerType(MLIRContext *ctx) 483 : ConversionPattern("test.type_consumer", 1, ctx) {} 484 LogicalResult 485 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 486 ConversionPatternRewriter &rewriter) const final { 487 // Verify that the incoming operand has been successfully remapped to F64. 488 if (!operands[0].getType().isF64()) 489 return failure(); 490 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 491 return success(); 492 } 493 }; 494 495 //===----------------------------------------------------------------------===// 496 // Non-Root Replacement Rewrite Testing 497 /// This pattern generates an invalid operation, but replaces it before the 498 /// pattern is finished. This checks that we don't need to legalize the 499 /// temporary op. 500 struct TestNonRootReplacement : public RewritePattern { 501 TestNonRootReplacement(MLIRContext *ctx) 502 : RewritePattern("test.replace_non_root", 1, ctx) {} 503 504 LogicalResult matchAndRewrite(Operation *op, 505 PatternRewriter &rewriter) const final { 506 auto resultType = *op->result_type_begin(); 507 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 508 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 509 510 rewriter.replaceOp(illegalOp, {legalOp}); 511 rewriter.replaceOp(op, {illegalOp}); 512 return success(); 513 } 514 }; 515 516 //===----------------------------------------------------------------------===// 517 // Recursive Rewrite Testing 518 /// This pattern is applied to the same operation multiple times, but has a 519 /// bounded recursion. 520 struct TestBoundedRecursiveRewrite 521 : public OpRewritePattern<TestRecursiveRewriteOp> { 522 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 523 524 void initialize() { 525 // The conversion target handles bounding the recursion of this pattern. 526 setHasBoundedRewriteRecursion(); 527 } 528 529 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 530 PatternRewriter &rewriter) const final { 531 // Decrement the depth of the op in-place. 532 rewriter.updateRootInPlace(op, [&] { 533 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 534 }); 535 return success(); 536 } 537 }; 538 539 struct TestNestedOpCreationUndoRewrite 540 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 541 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 542 543 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 544 PatternRewriter &rewriter) const final { 545 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 546 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 547 return success(); 548 }; 549 }; 550 551 // This pattern matches `test.blackhole` and delete this op and its producer. 552 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 553 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 554 555 LogicalResult matchAndRewrite(BlackHoleOp op, 556 PatternRewriter &rewriter) const final { 557 Operation *producer = op.getOperand().getDefiningOp(); 558 // Always erase the user before the producer, the framework should handle 559 // this correctly. 560 rewriter.eraseOp(op); 561 rewriter.eraseOp(producer); 562 return success(); 563 }; 564 }; 565 } // namespace 566 567 namespace { 568 struct TestTypeConverter : public TypeConverter { 569 using TypeConverter::TypeConverter; 570 TestTypeConverter() { 571 addConversion(convertType); 572 addArgumentMaterialization(materializeCast); 573 addSourceMaterialization(materializeCast); 574 575 /// Materialize the cast for one-to-one conversion from i64 to f64. 576 const auto materializeOneToOneCast = 577 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 578 Location loc) -> Optional<Value> { 579 if (resultType.getWidth() == 42 && inputs.size() == 1) 580 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 581 return llvm::None; 582 }; 583 addArgumentMaterialization(materializeOneToOneCast); 584 } 585 586 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 587 // Drop I16 types. 588 if (t.isSignlessInteger(16)) 589 return success(); 590 591 // Convert I64 to F64. 592 if (t.isSignlessInteger(64)) { 593 results.push_back(FloatType::getF64(t.getContext())); 594 return success(); 595 } 596 597 // Convert I42 to I43. 598 if (t.isInteger(42)) { 599 results.push_back(IntegerType::get(t.getContext(), 43)); 600 return success(); 601 } 602 603 // Split F32 into F16,F16. 604 if (t.isF32()) { 605 results.assign(2, FloatType::getF16(t.getContext())); 606 return success(); 607 } 608 609 // Otherwise, convert the type directly. 610 results.push_back(t); 611 return success(); 612 } 613 614 /// Hook for materializing a conversion. This is necessary because we generate 615 /// 1->N type mappings. 616 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 617 ValueRange inputs, Location loc) { 618 if (inputs.size() == 1) 619 return inputs[0]; 620 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 621 } 622 }; 623 624 struct TestLegalizePatternDriver 625 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 626 StringRef getArgument() const final { return "test-legalize-patterns"; } 627 StringRef getDescription() const final { 628 return "Run test dialect legalization patterns"; 629 } 630 /// The mode of conversion to use with the driver. 631 enum class ConversionMode { Analysis, Full, Partial }; 632 633 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 634 635 void runOnOperation() override { 636 TestTypeConverter converter; 637 mlir::RewritePatternSet patterns(&getContext()); 638 populateWithGenerated(patterns); 639 patterns 640 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 641 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 642 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 643 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 644 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 645 TestNonRootReplacement, TestBoundedRecursiveRewrite, 646 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 647 &getContext()); 648 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 649 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 650 mlir::populateCallOpTypeConversionPattern(patterns, converter); 651 652 // Define the conversion target used for the test. 653 ConversionTarget target(getContext()); 654 target.addLegalOp<ModuleOp>(); 655 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 656 TerminatorOp>(); 657 target 658 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 659 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 660 // Don't allow F32 operands. 661 return llvm::none_of(op.getOperandTypes(), 662 [](Type type) { return type.isF32(); }); 663 }); 664 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 665 return converter.isSignatureLegal(op.getType()) && 666 converter.isLegal(&op.getBody()); 667 }); 668 669 // Expect the type_producer/type_consumer operations to only operate on f64. 670 target.addDynamicallyLegalOp<TestTypeProducerOp>( 671 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 672 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 673 return op.getOperand().getType().isF64(); 674 }); 675 676 // Check support for marking certain operations as recursively legal. 677 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 678 return static_cast<bool>( 679 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 680 }); 681 682 // Mark the bound recursion operation as dynamically legal. 683 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 684 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 685 686 // Handle a partial conversion. 687 if (mode == ConversionMode::Partial) { 688 DenseSet<Operation *> unlegalizedOps; 689 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 690 &unlegalizedOps); 691 // Emit remarks for each legalizable operation. 692 for (auto *op : unlegalizedOps) 693 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 694 return; 695 } 696 697 // Handle a full conversion. 698 if (mode == ConversionMode::Full) { 699 // Check support for marking unknown operations as dynamically legal. 700 target.markUnknownOpDynamicallyLegal([](Operation *op) { 701 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 702 }); 703 704 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 705 return; 706 } 707 708 // Otherwise, handle an analysis conversion. 709 assert(mode == ConversionMode::Analysis); 710 711 // Analyze the convertible operations. 712 DenseSet<Operation *> legalizedOps; 713 if (failed(applyAnalysisConversion(getOperation(), target, 714 std::move(patterns), legalizedOps))) 715 return signalPassFailure(); 716 717 // Emit remarks for each legalizable operation. 718 for (auto *op : legalizedOps) 719 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 720 } 721 722 /// The mode of conversion to use. 723 ConversionMode mode; 724 }; 725 } // end anonymous namespace 726 727 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 728 legalizerConversionMode( 729 "test-legalize-mode", 730 llvm::cl::desc("The legalization mode to use with the test driver"), 731 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 732 llvm::cl::values( 733 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 734 "analysis", "Perform an analysis conversion"), 735 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 736 "Perform a full conversion"), 737 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 738 "partial", "Perform a partial conversion"))); 739 740 //===----------------------------------------------------------------------===// 741 // ConversionPatternRewriter::getRemappedValue testing. This method is used 742 // to get the remapped value of an original value that was replaced using 743 // ConversionPatternRewriter. 744 namespace { 745 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 746 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 747 /// operand twice. 748 /// 749 /// Example: 750 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 751 /// is replaced with: 752 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 753 struct OneVResOneVOperandOp1Converter 754 : public OpConversionPattern<OneVResOneVOperandOp1> { 755 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 756 757 LogicalResult 758 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 759 ConversionPatternRewriter &rewriter) const override { 760 auto origOps = op.getOperands(); 761 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 762 "One operand expected"); 763 Value origOp = *origOps.begin(); 764 SmallVector<Value, 2> remappedOperands; 765 // Replicate the remapped original operand twice. Note that we don't used 766 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 767 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 768 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 769 770 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 771 remappedOperands); 772 return success(); 773 } 774 }; 775 776 struct TestRemappedValue 777 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 778 StringRef getArgument() const final { return "test-remapped-value"; } 779 StringRef getDescription() const final { 780 return "Test public remapped value mechanism in ConversionPatternRewriter"; 781 } 782 void runOnFunction() override { 783 mlir::RewritePatternSet patterns(&getContext()); 784 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 785 786 mlir::ConversionTarget target(getContext()); 787 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 788 // We make OneVResOneVOperandOp1 legal only when it has more that one 789 // operand. This will trigger the conversion that will replace one-operand 790 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 791 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 792 [](Operation *op) -> bool { 793 return std::distance(op->operand_begin(), op->operand_end()) > 1; 794 }); 795 796 if (failed(mlir::applyFullConversion(getFunction(), target, 797 std::move(patterns)))) { 798 signalPassFailure(); 799 } 800 } 801 }; 802 } // end anonymous namespace 803 804 //===----------------------------------------------------------------------===// 805 // Test patterns without a specific root operation kind 806 //===----------------------------------------------------------------------===// 807 808 namespace { 809 /// This pattern matches and removes any operation in the test dialect. 810 struct RemoveTestDialectOps : public RewritePattern { 811 RemoveTestDialectOps(MLIRContext *context) 812 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 813 814 LogicalResult matchAndRewrite(Operation *op, 815 PatternRewriter &rewriter) const override { 816 if (!isa<TestDialect>(op->getDialect())) 817 return failure(); 818 rewriter.eraseOp(op); 819 return success(); 820 } 821 }; 822 823 struct TestUnknownRootOpDriver 824 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 825 StringRef getArgument() const final { 826 return "test-legalize-unknown-root-patterns"; 827 } 828 StringRef getDescription() const final { 829 return "Test public remapped value mechanism in ConversionPatternRewriter"; 830 } 831 void runOnFunction() override { 832 mlir::RewritePatternSet patterns(&getContext()); 833 patterns.add<RemoveTestDialectOps>(&getContext()); 834 835 mlir::ConversionTarget target(getContext()); 836 target.addIllegalDialect<TestDialect>(); 837 if (failed( 838 applyPartialConversion(getFunction(), target, std::move(patterns)))) 839 signalPassFailure(); 840 } 841 }; 842 } // end anonymous namespace 843 844 //===----------------------------------------------------------------------===// 845 // Test type conversions 846 //===----------------------------------------------------------------------===// 847 848 namespace { 849 struct TestTypeConversionProducer 850 : public OpConversionPattern<TestTypeProducerOp> { 851 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 852 LogicalResult 853 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 854 ConversionPatternRewriter &rewriter) const final { 855 Type resultType = op.getType(); 856 if (resultType.isa<FloatType>()) 857 resultType = rewriter.getF64Type(); 858 else if (resultType.isInteger(16)) 859 resultType = rewriter.getIntegerType(64); 860 else 861 return failure(); 862 863 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 864 return success(); 865 } 866 }; 867 868 /// Call signature conversion and then fail the rewrite to trigger the undo 869 /// mechanism. 870 struct TestSignatureConversionUndo 871 : public OpConversionPattern<TestSignatureConversionUndoOp> { 872 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 873 874 LogicalResult 875 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 876 ConversionPatternRewriter &rewriter) const final { 877 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 878 return failure(); 879 } 880 }; 881 882 /// Just forward the operands to the root op. This is essentially a no-op 883 /// pattern that is used to trigger target materialization. 884 struct TestTypeConsumerForward 885 : public OpConversionPattern<TestTypeConsumerOp> { 886 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 887 888 LogicalResult 889 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 890 ConversionPatternRewriter &rewriter) const final { 891 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 892 return success(); 893 } 894 }; 895 896 struct TestTypeConversionAnotherProducer 897 : public OpRewritePattern<TestAnotherTypeProducerOp> { 898 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 899 900 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 901 PatternRewriter &rewriter) const final { 902 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 903 return success(); 904 } 905 }; 906 907 struct TestTypeConversionDriver 908 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 909 void getDependentDialects(DialectRegistry ®istry) const override { 910 registry.insert<TestDialect>(); 911 } 912 StringRef getArgument() const final { 913 return "test-legalize-type-conversion"; 914 } 915 StringRef getDescription() const final { 916 return "Test various type conversion functionalities in DialectConversion"; 917 } 918 919 void runOnOperation() override { 920 // Initialize the type converter. 921 TypeConverter converter; 922 923 /// Add the legal set of type conversions. 924 converter.addConversion([](Type type) -> Type { 925 // Treat F64 as legal. 926 if (type.isF64()) 927 return type; 928 // Allow converting BF16/F16/F32 to F64. 929 if (type.isBF16() || type.isF16() || type.isF32()) 930 return FloatType::getF64(type.getContext()); 931 // Otherwise, the type is illegal. 932 return nullptr; 933 }); 934 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 935 // Drop all integer types. 936 return success(); 937 }); 938 939 /// Add the legal set of type materializations. 940 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 941 ValueRange inputs, 942 Location loc) -> Value { 943 // Allow casting from F64 back to F32. 944 if (!resultType.isF16() && inputs.size() == 1 && 945 inputs[0].getType().isF64()) 946 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 947 // Allow producing an i32 or i64 from nothing. 948 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 949 inputs.empty()) 950 return builder.create<TestTypeProducerOp>(loc, resultType); 951 // Allow producing an i64 from an integer. 952 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 953 inputs[0].getType().isa<IntegerType>()) 954 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 955 // Otherwise, fail. 956 return nullptr; 957 }); 958 959 // Initialize the conversion target. 960 mlir::ConversionTarget target(getContext()); 961 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 962 return op.getType().isF64() || op.getType().isInteger(64); 963 }); 964 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 965 return converter.isSignatureLegal(op.getType()) && 966 converter.isLegal(&op.getBody()); 967 }); 968 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 969 // Allow casts from F64 to F32. 970 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 971 }); 972 973 // Initialize the set of rewrite patterns. 974 RewritePatternSet patterns(&getContext()); 975 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 976 TestSignatureConversionUndo>(converter, &getContext()); 977 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 978 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 979 980 if (failed(applyPartialConversion(getOperation(), target, 981 std::move(patterns)))) 982 signalPassFailure(); 983 } 984 }; 985 } // end anonymous namespace 986 987 //===----------------------------------------------------------------------===// 988 // Test Block Merging 989 //===----------------------------------------------------------------------===// 990 991 namespace { 992 /// A rewriter pattern that tests that blocks can be merged. 993 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 994 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 995 996 LogicalResult 997 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 998 ConversionPatternRewriter &rewriter) const final { 999 Block &firstBlock = op.body().front(); 1000 Operation *branchOp = firstBlock.getTerminator(); 1001 Block *secondBlock = &*(std::next(op.body().begin())); 1002 auto succOperands = branchOp->getOperands(); 1003 SmallVector<Value, 2> replacements(succOperands); 1004 rewriter.eraseOp(branchOp); 1005 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1006 rewriter.updateRootInPlace(op, [] {}); 1007 return success(); 1008 } 1009 }; 1010 1011 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1012 struct TestUndoBlocksMerge : public ConversionPattern { 1013 TestUndoBlocksMerge(MLIRContext *ctx) 1014 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1015 LogicalResult 1016 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1017 ConversionPatternRewriter &rewriter) const final { 1018 Block &firstBlock = op->getRegion(0).front(); 1019 Operation *branchOp = firstBlock.getTerminator(); 1020 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1021 rewriter.setInsertionPointToStart(secondBlock); 1022 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1023 auto succOperands = branchOp->getOperands(); 1024 SmallVector<Value, 2> replacements(succOperands); 1025 rewriter.eraseOp(branchOp); 1026 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1027 rewriter.updateRootInPlace(op, [] {}); 1028 return success(); 1029 } 1030 }; 1031 1032 /// A rewrite mechanism to inline the body of the op into its parent, when both 1033 /// ops can have a single block. 1034 struct TestMergeSingleBlockOps 1035 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1036 using OpConversionPattern< 1037 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1038 1039 LogicalResult 1040 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 1041 ConversionPatternRewriter &rewriter) const final { 1042 SingleBlockImplicitTerminatorOp parentOp = 1043 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1044 if (!parentOp) 1045 return failure(); 1046 Block &innerBlock = op.region().front(); 1047 TerminatorOp innerTerminator = 1048 cast<TerminatorOp>(innerBlock.getTerminator()); 1049 rewriter.mergeBlockBefore(&innerBlock, op); 1050 rewriter.eraseOp(innerTerminator); 1051 rewriter.eraseOp(op); 1052 rewriter.updateRootInPlace(op, [] {}); 1053 return success(); 1054 } 1055 }; 1056 1057 struct TestMergeBlocksPatternDriver 1058 : public PassWrapper<TestMergeBlocksPatternDriver, 1059 OperationPass<ModuleOp>> { 1060 StringRef getArgument() const final { return "test-merge-blocks"; } 1061 StringRef getDescription() const final { 1062 return "Test Merging operation in ConversionPatternRewriter"; 1063 } 1064 void runOnOperation() override { 1065 MLIRContext *context = &getContext(); 1066 mlir::RewritePatternSet patterns(context); 1067 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1068 context); 1069 ConversionTarget target(*context); 1070 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1071 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1072 target.addIllegalOp<ILLegalOpF>(); 1073 1074 /// Expect the op to have a single block after legalization. 1075 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1076 [&](TestMergeBlocksOp op) -> bool { 1077 return llvm::hasSingleElement(op.body()); 1078 }); 1079 1080 /// Only allow `test.br` within test.merge_blocks op. 1081 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1082 return op->getParentOfType<TestMergeBlocksOp>(); 1083 }); 1084 1085 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1086 /// inlined. 1087 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1088 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1089 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1090 }); 1091 1092 DenseSet<Operation *> unlegalizedOps; 1093 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1094 &unlegalizedOps); 1095 for (auto *op : unlegalizedOps) 1096 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1097 } 1098 }; 1099 } // namespace 1100 1101 //===----------------------------------------------------------------------===// 1102 // Test Selective Replacement 1103 //===----------------------------------------------------------------------===// 1104 1105 namespace { 1106 /// A rewrite mechanism to inline the body of the op into its parent, when both 1107 /// ops can have a single block. 1108 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1109 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1110 1111 LogicalResult matchAndRewrite(TestCastOp op, 1112 PatternRewriter &rewriter) const final { 1113 if (op.getNumOperands() != 2) 1114 return failure(); 1115 OperandRange operands = op.getOperands(); 1116 1117 // Replace non-terminator uses with the first operand. 1118 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1119 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1120 }); 1121 // Replace everything else with the second operand if the operation isn't 1122 // dead. 1123 rewriter.replaceOp(op, op.getOperand(1)); 1124 return success(); 1125 } 1126 }; 1127 1128 struct TestSelectiveReplacementPatternDriver 1129 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1130 OperationPass<>> { 1131 StringRef getArgument() const final { 1132 return "test-pattern-selective-replacement"; 1133 } 1134 StringRef getDescription() const final { 1135 return "Test selective replacement in the PatternRewriter"; 1136 } 1137 void runOnOperation() override { 1138 MLIRContext *context = &getContext(); 1139 mlir::RewritePatternSet patterns(context); 1140 patterns.add<TestSelectiveOpReplacementPattern>(context); 1141 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1142 std::move(patterns)); 1143 } 1144 }; 1145 } // namespace 1146 1147 //===----------------------------------------------------------------------===// 1148 // PassRegistration 1149 //===----------------------------------------------------------------------===// 1150 1151 namespace mlir { 1152 namespace test { 1153 void registerPatternsTestPass() { 1154 PassRegistration<TestReturnTypeDriver>(); 1155 1156 PassRegistration<TestDerivedAttributeDriver>(); 1157 1158 PassRegistration<TestPatternDriver>(); 1159 1160 PassRegistration<TestLegalizePatternDriver>([] { 1161 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1162 }); 1163 1164 PassRegistration<TestRemappedValue>(); 1165 1166 PassRegistration<TestUnknownRootOpDriver>(); 1167 1168 PassRegistration<TestTypeConversionDriver>(); 1169 1170 PassRegistration<TestMergeBlocksPatternDriver>(); 1171 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1172 } 1173 } // namespace test 1174 } // namespace mlir 1175