1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 // Test that natives calls are only called once during rewrites. 48 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 49 // This let us check the number of times OpM_Test was called by inspecting 50 // the returned value in the MLIR output. 51 static int64_t opMIncreasingValue = 314159265; 52 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 53 int64_t i = opMIncreasingValue++; 54 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 55 } 56 57 namespace { 58 #include "TestPatterns.inc" 59 } // end anonymous namespace 60 61 //===----------------------------------------------------------------------===// 62 // Test Reduce Pattern Interface 63 //===----------------------------------------------------------------------===// 64 65 void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) { 66 populateWithGenerated(patterns); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Canonicalizer Driver. 71 //===----------------------------------------------------------------------===// 72 73 namespace { 74 struct FoldingPattern : public RewritePattern { 75 public: 76 FoldingPattern(MLIRContext *context) 77 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 78 /*benefit=*/1, context) {} 79 80 LogicalResult matchAndRewrite(Operation *op, 81 PatternRewriter &rewriter) const override { 82 // Exercise OperationFolder API for a single-result operation that is folded 83 // upon construction. The operation being created through the folder has an 84 // in-place folder, and it should be still present in the output. 85 // Furthermore, the folder should not crash when attempting to recover the 86 // (unchanged) operation result. 87 OperationFolder folder(op->getContext()); 88 Value result = folder.create<TestOpInPlaceFold>( 89 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 90 rewriter.getI32IntegerAttr(0)); 91 assert(result); 92 rewriter.replaceOp(op, result); 93 return success(); 94 } 95 }; 96 97 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 98 StringRef getArgument() const final { return "test-patterns"; } 99 StringRef getDescription() const final { return "Run test dialect patterns"; } 100 void runOnFunction() override { 101 mlir::RewritePatternSet patterns(&getContext()); 102 populateWithGenerated(patterns); 103 104 // Verify named pattern is generated with expected name. 105 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 106 107 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 108 } 109 }; 110 } // end anonymous namespace 111 112 //===----------------------------------------------------------------------===// 113 // ReturnType Driver. 114 //===----------------------------------------------------------------------===// 115 116 namespace { 117 // Generate ops for each instance where the type can be successfully inferred. 118 template <typename OpTy> 119 static void invokeCreateWithInferredReturnType(Operation *op) { 120 auto *context = op->getContext(); 121 auto fop = op->getParentOfType<FuncOp>(); 122 auto location = UnknownLoc::get(context); 123 OpBuilder b(op); 124 b.setInsertionPointAfter(op); 125 126 // Use permutations of 2 args as operands. 127 assert(fop.getNumArguments() >= 2); 128 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 129 for (int j = 0; j < e; ++j) { 130 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 131 SmallVector<Type, 2> inferredReturnTypes; 132 if (succeeded(OpTy::inferReturnTypes( 133 context, llvm::None, values, op->getAttrDictionary(), 134 op->getRegions(), inferredReturnTypes))) { 135 OperationState state(location, OpTy::getOperationName()); 136 // TODO: Expand to regions. 137 OpTy::build(b, state, values, op->getAttrs()); 138 (void)b.createOperation(state); 139 } 140 } 141 } 142 } 143 144 static void reifyReturnShape(Operation *op) { 145 OpBuilder b(op); 146 147 // Use permutations of 2 args as operands. 148 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 149 SmallVector<Value, 2> shapes; 150 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 151 !llvm::hasSingleElement(shapes)) 152 return; 153 for (auto it : llvm::enumerate(shapes)) { 154 op->emitRemark() << "value " << it.index() << ": " 155 << it.value().getDefiningOp(); 156 } 157 } 158 159 struct TestReturnTypeDriver 160 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 161 void getDependentDialects(DialectRegistry ®istry) const override { 162 registry.insert<memref::MemRefDialect>(); 163 } 164 StringRef getArgument() const final { return "test-return-type"; } 165 StringRef getDescription() const final { return "Run return type functions"; } 166 167 void runOnFunction() override { 168 if (getFunction().getName() == "testCreateFunctions") { 169 std::vector<Operation *> ops; 170 // Collect ops to avoid triggering on inserted ops. 171 for (auto &op : getFunction().getBody().front()) 172 ops.push_back(&op); 173 // Generate test patterns for each, but skip terminator. 174 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 175 // Test create method of each of the Op classes below. The resultant 176 // output would be in reverse order underneath `op` from which 177 // the attributes and regions are used. 178 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 179 invokeCreateWithInferredReturnType< 180 OpWithShapedTypeInferTypeInterfaceOp>(op); 181 }; 182 return; 183 } 184 if (getFunction().getName() == "testReifyFunctions") { 185 std::vector<Operation *> ops; 186 // Collect ops to avoid triggering on inserted ops. 187 for (auto &op : getFunction().getBody().front()) 188 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 189 ops.push_back(&op); 190 // Generate test patterns for each, but skip terminator. 191 for (auto *op : ops) 192 reifyReturnShape(op); 193 } 194 } 195 }; 196 } // end anonymous namespace 197 198 namespace { 199 struct TestDerivedAttributeDriver 200 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 201 StringRef getArgument() const final { return "test-derived-attr"; } 202 StringRef getDescription() const final { 203 return "Run test derived attributes"; 204 } 205 void runOnFunction() override; 206 }; 207 } // end anonymous namespace 208 209 void TestDerivedAttributeDriver::runOnFunction() { 210 getFunction().walk([](DerivedAttributeOpInterface dOp) { 211 auto dAttr = dOp.materializeDerivedAttributes(); 212 if (!dAttr) 213 return; 214 for (auto d : dAttr) 215 dOp.emitRemark() << d.first << " = " << d.second; 216 }); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Legalization Driver. 221 //===----------------------------------------------------------------------===// 222 223 namespace { 224 //===----------------------------------------------------------------------===// 225 // Region-Block Rewrite Testing 226 227 /// This pattern is a simple pattern that inlines the first region of a given 228 /// operation into the parent region. 229 struct TestRegionRewriteBlockMovement : public ConversionPattern { 230 TestRegionRewriteBlockMovement(MLIRContext *ctx) 231 : ConversionPattern("test.region", 1, ctx) {} 232 233 LogicalResult 234 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 235 ConversionPatternRewriter &rewriter) const final { 236 // Inline this region into the parent region. 237 auto &parentRegion = *op->getParentRegion(); 238 auto &opRegion = op->getRegion(0); 239 if (op->getAttr("legalizer.should_clone")) 240 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 241 else 242 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 243 244 if (op->getAttr("legalizer.erase_old_blocks")) { 245 while (!opRegion.empty()) 246 rewriter.eraseBlock(&opRegion.front()); 247 } 248 249 // Drop this operation. 250 rewriter.eraseOp(op); 251 return success(); 252 } 253 }; 254 /// This pattern is a simple pattern that generates a region containing an 255 /// illegal operation. 256 struct TestRegionRewriteUndo : public RewritePattern { 257 TestRegionRewriteUndo(MLIRContext *ctx) 258 : RewritePattern("test.region_builder", 1, ctx) {} 259 260 LogicalResult matchAndRewrite(Operation *op, 261 PatternRewriter &rewriter) const final { 262 // Create the region operation with an entry block containing arguments. 263 OperationState newRegion(op->getLoc(), "test.region"); 264 newRegion.addRegion(); 265 auto *regionOp = rewriter.createOperation(newRegion); 266 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 267 entryBlock->addArgument(rewriter.getIntegerType(64)); 268 269 // Add an explicitly illegal operation to ensure the conversion fails. 270 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 271 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 272 273 // Drop this operation. 274 rewriter.eraseOp(op); 275 return success(); 276 } 277 }; 278 /// A simple pattern that creates a block at the end of the parent region of the 279 /// matched operation. 280 struct TestCreateBlock : public RewritePattern { 281 TestCreateBlock(MLIRContext *ctx) 282 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 283 284 LogicalResult matchAndRewrite(Operation *op, 285 PatternRewriter &rewriter) const final { 286 Region ®ion = *op->getParentRegion(); 287 Type i32Type = rewriter.getIntegerType(32); 288 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 289 rewriter.create<TerminatorOp>(op->getLoc()); 290 rewriter.replaceOp(op, {}); 291 return success(); 292 } 293 }; 294 295 /// A simple pattern that creates a block containing an invalid operation in 296 /// order to trigger the block creation undo mechanism. 297 struct TestCreateIllegalBlock : public RewritePattern { 298 TestCreateIllegalBlock(MLIRContext *ctx) 299 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 300 301 LogicalResult matchAndRewrite(Operation *op, 302 PatternRewriter &rewriter) const final { 303 Region ®ion = *op->getParentRegion(); 304 Type i32Type = rewriter.getIntegerType(32); 305 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 306 // Create an illegal op to ensure the conversion fails. 307 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 308 rewriter.create<TerminatorOp>(op->getLoc()); 309 rewriter.replaceOp(op, {}); 310 return success(); 311 } 312 }; 313 314 /// A simple pattern that tests the undo mechanism when replacing the uses of a 315 /// block argument. 316 struct TestUndoBlockArgReplace : public ConversionPattern { 317 TestUndoBlockArgReplace(MLIRContext *ctx) 318 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 319 320 LogicalResult 321 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 322 ConversionPatternRewriter &rewriter) const final { 323 auto illegalOp = 324 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 325 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 326 illegalOp); 327 rewriter.updateRootInPlace(op, [] {}); 328 return success(); 329 } 330 }; 331 332 /// A rewrite pattern that tests the undo mechanism when erasing a block. 333 struct TestUndoBlockErase : public ConversionPattern { 334 TestUndoBlockErase(MLIRContext *ctx) 335 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 336 337 LogicalResult 338 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 339 ConversionPatternRewriter &rewriter) const final { 340 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 341 rewriter.setInsertionPointToStart(secondBlock); 342 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 343 rewriter.eraseBlock(secondBlock); 344 rewriter.updateRootInPlace(op, [] {}); 345 return success(); 346 } 347 }; 348 349 //===----------------------------------------------------------------------===// 350 // Type-Conversion Rewrite Testing 351 352 /// This patterns erases a region operation that has had a type conversion. 353 struct TestDropOpSignatureConversion : public ConversionPattern { 354 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 355 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 356 LogicalResult 357 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 358 ConversionPatternRewriter &rewriter) const override { 359 Region ®ion = op->getRegion(0); 360 Block *entry = ®ion.front(); 361 362 // Convert the original entry arguments. 363 TypeConverter &converter = *getTypeConverter(); 364 TypeConverter::SignatureConversion result(entry->getNumArguments()); 365 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 366 result)) || 367 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 368 return failure(); 369 370 // Convert the region signature and just drop the operation. 371 rewriter.eraseOp(op); 372 return success(); 373 } 374 }; 375 /// This pattern simply updates the operands of the given operation. 376 struct TestPassthroughInvalidOp : public ConversionPattern { 377 TestPassthroughInvalidOp(MLIRContext *ctx) 378 : ConversionPattern("test.invalid", 1, ctx) {} 379 LogicalResult 380 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 381 ConversionPatternRewriter &rewriter) const final { 382 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 383 llvm::None); 384 return success(); 385 } 386 }; 387 /// This pattern handles the case of a split return value. 388 struct TestSplitReturnType : public ConversionPattern { 389 TestSplitReturnType(MLIRContext *ctx) 390 : ConversionPattern("test.return", 1, ctx) {} 391 LogicalResult 392 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 393 ConversionPatternRewriter &rewriter) const final { 394 // Check for a return of F32. 395 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 396 return failure(); 397 398 // Check if the first operation is a cast operation, if it is we use the 399 // results directly. 400 auto *defOp = operands[0].getDefiningOp(); 401 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 402 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 403 return success(); 404 } 405 406 // Otherwise, fail to match. 407 return failure(); 408 } 409 }; 410 411 //===----------------------------------------------------------------------===// 412 // Multi-Level Type-Conversion Rewrite Testing 413 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 414 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 415 : ConversionPattern("test.type_producer", 1, ctx) {} 416 LogicalResult 417 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 418 ConversionPatternRewriter &rewriter) const final { 419 // If the type is I32, change the type to F32. 420 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 421 return failure(); 422 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 423 return success(); 424 } 425 }; 426 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 427 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 428 : ConversionPattern("test.type_producer", 1, ctx) {} 429 LogicalResult 430 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 431 ConversionPatternRewriter &rewriter) const final { 432 // If the type is F32, change the type to F64. 433 if (!Type(*op->result_type_begin()).isF32()) 434 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 435 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 436 return success(); 437 } 438 }; 439 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 440 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 441 : ConversionPattern("test.type_producer", 10, ctx) {} 442 LogicalResult 443 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 444 ConversionPatternRewriter &rewriter) const final { 445 // Always convert to B16, even though it is not a legal type. This tests 446 // that values are unmapped correctly. 447 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 448 return success(); 449 } 450 }; 451 struct TestUpdateConsumerType : public ConversionPattern { 452 TestUpdateConsumerType(MLIRContext *ctx) 453 : ConversionPattern("test.type_consumer", 1, ctx) {} 454 LogicalResult 455 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 456 ConversionPatternRewriter &rewriter) const final { 457 // Verify that the incoming operand has been successfully remapped to F64. 458 if (!operands[0].getType().isF64()) 459 return failure(); 460 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 461 return success(); 462 } 463 }; 464 465 //===----------------------------------------------------------------------===// 466 // Non-Root Replacement Rewrite Testing 467 /// This pattern generates an invalid operation, but replaces it before the 468 /// pattern is finished. This checks that we don't need to legalize the 469 /// temporary op. 470 struct TestNonRootReplacement : public RewritePattern { 471 TestNonRootReplacement(MLIRContext *ctx) 472 : RewritePattern("test.replace_non_root", 1, ctx) {} 473 474 LogicalResult matchAndRewrite(Operation *op, 475 PatternRewriter &rewriter) const final { 476 auto resultType = *op->result_type_begin(); 477 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 478 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 479 480 rewriter.replaceOp(illegalOp, {legalOp}); 481 rewriter.replaceOp(op, {illegalOp}); 482 return success(); 483 } 484 }; 485 486 //===----------------------------------------------------------------------===// 487 // Recursive Rewrite Testing 488 /// This pattern is applied to the same operation multiple times, but has a 489 /// bounded recursion. 490 struct TestBoundedRecursiveRewrite 491 : public OpRewritePattern<TestRecursiveRewriteOp> { 492 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 493 494 void initialize() { 495 // The conversion target handles bounding the recursion of this pattern. 496 setHasBoundedRewriteRecursion(); 497 } 498 499 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 500 PatternRewriter &rewriter) const final { 501 // Decrement the depth of the op in-place. 502 rewriter.updateRootInPlace(op, [&] { 503 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 504 }); 505 return success(); 506 } 507 }; 508 509 struct TestNestedOpCreationUndoRewrite 510 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 511 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 512 513 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 514 PatternRewriter &rewriter) const final { 515 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 516 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 517 return success(); 518 }; 519 }; 520 521 // This pattern matches `test.blackhole` and delete this op and its producer. 522 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 523 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 524 525 LogicalResult matchAndRewrite(BlackHoleOp op, 526 PatternRewriter &rewriter) const final { 527 Operation *producer = op.getOperand().getDefiningOp(); 528 // Always erase the user before the producer, the framework should handle 529 // this correctly. 530 rewriter.eraseOp(op); 531 rewriter.eraseOp(producer); 532 return success(); 533 }; 534 }; 535 } // namespace 536 537 namespace { 538 struct TestTypeConverter : public TypeConverter { 539 using TypeConverter::TypeConverter; 540 TestTypeConverter() { 541 addConversion(convertType); 542 addArgumentMaterialization(materializeCast); 543 addSourceMaterialization(materializeCast); 544 545 /// Materialize the cast for one-to-one conversion from i64 to f64. 546 const auto materializeOneToOneCast = 547 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 548 Location loc) -> Optional<Value> { 549 if (resultType.getWidth() == 42 && inputs.size() == 1) 550 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 551 return llvm::None; 552 }; 553 addArgumentMaterialization(materializeOneToOneCast); 554 } 555 556 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 557 // Drop I16 types. 558 if (t.isSignlessInteger(16)) 559 return success(); 560 561 // Convert I64 to F64. 562 if (t.isSignlessInteger(64)) { 563 results.push_back(FloatType::getF64(t.getContext())); 564 return success(); 565 } 566 567 // Convert I42 to I43. 568 if (t.isInteger(42)) { 569 results.push_back(IntegerType::get(t.getContext(), 43)); 570 return success(); 571 } 572 573 // Split F32 into F16,F16. 574 if (t.isF32()) { 575 results.assign(2, FloatType::getF16(t.getContext())); 576 return success(); 577 } 578 579 // Otherwise, convert the type directly. 580 results.push_back(t); 581 return success(); 582 } 583 584 /// Hook for materializing a conversion. This is necessary because we generate 585 /// 1->N type mappings. 586 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 587 ValueRange inputs, Location loc) { 588 if (inputs.size() == 1) 589 return inputs[0]; 590 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 591 } 592 }; 593 594 struct TestLegalizePatternDriver 595 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 596 StringRef getArgument() const final { return "test-legalize-patterns"; } 597 StringRef getDescription() const final { 598 return "Run test dialect legalization patterns"; 599 } 600 /// The mode of conversion to use with the driver. 601 enum class ConversionMode { Analysis, Full, Partial }; 602 603 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 604 605 void runOnOperation() override { 606 TestTypeConverter converter; 607 mlir::RewritePatternSet patterns(&getContext()); 608 populateWithGenerated(patterns); 609 patterns 610 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 611 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 612 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 613 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 614 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 615 TestNonRootReplacement, TestBoundedRecursiveRewrite, 616 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 617 &getContext()); 618 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 619 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 620 mlir::populateCallOpTypeConversionPattern(patterns, converter); 621 622 // Define the conversion target used for the test. 623 ConversionTarget target(getContext()); 624 target.addLegalOp<ModuleOp>(); 625 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 626 TerminatorOp>(); 627 target 628 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 629 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 630 // Don't allow F32 operands. 631 return llvm::none_of(op.getOperandTypes(), 632 [](Type type) { return type.isF32(); }); 633 }); 634 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 635 return converter.isSignatureLegal(op.getType()) && 636 converter.isLegal(&op.getBody()); 637 }); 638 639 // Expect the type_producer/type_consumer operations to only operate on f64. 640 target.addDynamicallyLegalOp<TestTypeProducerOp>( 641 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 642 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 643 return op.getOperand().getType().isF64(); 644 }); 645 646 // Check support for marking certain operations as recursively legal. 647 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 648 return static_cast<bool>( 649 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 650 }); 651 652 // Mark the bound recursion operation as dynamically legal. 653 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 654 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 655 656 // Handle a partial conversion. 657 if (mode == ConversionMode::Partial) { 658 DenseSet<Operation *> unlegalizedOps; 659 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 660 &unlegalizedOps); 661 // Emit remarks for each legalizable operation. 662 for (auto *op : unlegalizedOps) 663 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 664 return; 665 } 666 667 // Handle a full conversion. 668 if (mode == ConversionMode::Full) { 669 // Check support for marking unknown operations as dynamically legal. 670 target.markUnknownOpDynamicallyLegal([](Operation *op) { 671 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 672 }); 673 674 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 675 return; 676 } 677 678 // Otherwise, handle an analysis conversion. 679 assert(mode == ConversionMode::Analysis); 680 681 // Analyze the convertible operations. 682 DenseSet<Operation *> legalizedOps; 683 if (failed(applyAnalysisConversion(getOperation(), target, 684 std::move(patterns), legalizedOps))) 685 return signalPassFailure(); 686 687 // Emit remarks for each legalizable operation. 688 for (auto *op : legalizedOps) 689 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 690 } 691 692 /// The mode of conversion to use. 693 ConversionMode mode; 694 }; 695 } // end anonymous namespace 696 697 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 698 legalizerConversionMode( 699 "test-legalize-mode", 700 llvm::cl::desc("The legalization mode to use with the test driver"), 701 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 702 llvm::cl::values( 703 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 704 "analysis", "Perform an analysis conversion"), 705 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 706 "Perform a full conversion"), 707 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 708 "partial", "Perform a partial conversion"))); 709 710 //===----------------------------------------------------------------------===// 711 // ConversionPatternRewriter::getRemappedValue testing. This method is used 712 // to get the remapped value of an original value that was replaced using 713 // ConversionPatternRewriter. 714 namespace { 715 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 716 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 717 /// operand twice. 718 /// 719 /// Example: 720 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 721 /// is replaced with: 722 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 723 struct OneVResOneVOperandOp1Converter 724 : public OpConversionPattern<OneVResOneVOperandOp1> { 725 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 726 727 LogicalResult 728 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 729 ConversionPatternRewriter &rewriter) const override { 730 auto origOps = op.getOperands(); 731 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 732 "One operand expected"); 733 Value origOp = *origOps.begin(); 734 SmallVector<Value, 2> remappedOperands; 735 // Replicate the remapped original operand twice. Note that we don't used 736 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 737 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 738 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 739 740 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 741 remappedOperands); 742 return success(); 743 } 744 }; 745 746 struct TestRemappedValue 747 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 748 StringRef getArgument() const final { return "test-remapped-value"; } 749 StringRef getDescription() const final { 750 return "Test public remapped value mechanism in ConversionPatternRewriter"; 751 } 752 void runOnFunction() override { 753 mlir::RewritePatternSet patterns(&getContext()); 754 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 755 756 mlir::ConversionTarget target(getContext()); 757 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 758 // We make OneVResOneVOperandOp1 legal only when it has more that one 759 // operand. This will trigger the conversion that will replace one-operand 760 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 761 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 762 [](Operation *op) -> bool { 763 return std::distance(op->operand_begin(), op->operand_end()) > 1; 764 }); 765 766 if (failed(mlir::applyFullConversion(getFunction(), target, 767 std::move(patterns)))) { 768 signalPassFailure(); 769 } 770 } 771 }; 772 } // end anonymous namespace 773 774 //===----------------------------------------------------------------------===// 775 // Test patterns without a specific root operation kind 776 //===----------------------------------------------------------------------===// 777 778 namespace { 779 /// This pattern matches and removes any operation in the test dialect. 780 struct RemoveTestDialectOps : public RewritePattern { 781 RemoveTestDialectOps(MLIRContext *context) 782 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 783 784 LogicalResult matchAndRewrite(Operation *op, 785 PatternRewriter &rewriter) const override { 786 if (!isa<TestDialect>(op->getDialect())) 787 return failure(); 788 rewriter.eraseOp(op); 789 return success(); 790 } 791 }; 792 793 struct TestUnknownRootOpDriver 794 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 795 StringRef getArgument() const final { 796 return "test-legalize-unknown-root-patterns"; 797 } 798 StringRef getDescription() const final { 799 return "Test public remapped value mechanism in ConversionPatternRewriter"; 800 } 801 void runOnFunction() override { 802 mlir::RewritePatternSet patterns(&getContext()); 803 patterns.add<RemoveTestDialectOps>(&getContext()); 804 805 mlir::ConversionTarget target(getContext()); 806 target.addIllegalDialect<TestDialect>(); 807 if (failed( 808 applyPartialConversion(getFunction(), target, std::move(patterns)))) 809 signalPassFailure(); 810 } 811 }; 812 } // end anonymous namespace 813 814 //===----------------------------------------------------------------------===// 815 // Test type conversions 816 //===----------------------------------------------------------------------===// 817 818 namespace { 819 struct TestTypeConversionProducer 820 : public OpConversionPattern<TestTypeProducerOp> { 821 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 822 LogicalResult 823 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 824 ConversionPatternRewriter &rewriter) const final { 825 Type resultType = op.getType(); 826 if (resultType.isa<FloatType>()) 827 resultType = rewriter.getF64Type(); 828 else if (resultType.isInteger(16)) 829 resultType = rewriter.getIntegerType(64); 830 else 831 return failure(); 832 833 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 834 return success(); 835 } 836 }; 837 838 /// Call signature conversion and then fail the rewrite to trigger the undo 839 /// mechanism. 840 struct TestSignatureConversionUndo 841 : public OpConversionPattern<TestSignatureConversionUndoOp> { 842 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 843 844 LogicalResult 845 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 846 ConversionPatternRewriter &rewriter) const final { 847 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 848 return failure(); 849 } 850 }; 851 852 /// Just forward the operands to the root op. This is essentially a no-op 853 /// pattern that is used to trigger target materialization. 854 struct TestTypeConsumerForward 855 : public OpConversionPattern<TestTypeConsumerOp> { 856 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 857 858 LogicalResult 859 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 860 ConversionPatternRewriter &rewriter) const final { 861 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 862 return success(); 863 } 864 }; 865 866 struct TestTypeConversionAnotherProducer 867 : public OpRewritePattern<TestAnotherTypeProducerOp> { 868 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 869 870 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 871 PatternRewriter &rewriter) const final { 872 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 873 return success(); 874 } 875 }; 876 877 struct TestTypeConversionDriver 878 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 879 void getDependentDialects(DialectRegistry ®istry) const override { 880 registry.insert<TestDialect>(); 881 } 882 StringRef getArgument() const final { 883 return "test-legalize-type-conversion"; 884 } 885 StringRef getDescription() const final { 886 return "Test various type conversion functionalities in DialectConversion"; 887 } 888 889 void runOnOperation() override { 890 // Initialize the type converter. 891 TypeConverter converter; 892 893 /// Add the legal set of type conversions. 894 converter.addConversion([](Type type) -> Type { 895 // Treat F64 as legal. 896 if (type.isF64()) 897 return type; 898 // Allow converting BF16/F16/F32 to F64. 899 if (type.isBF16() || type.isF16() || type.isF32()) 900 return FloatType::getF64(type.getContext()); 901 // Otherwise, the type is illegal. 902 return nullptr; 903 }); 904 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 905 // Drop all integer types. 906 return success(); 907 }); 908 909 /// Add the legal set of type materializations. 910 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 911 ValueRange inputs, 912 Location loc) -> Value { 913 // Allow casting from F64 back to F32. 914 if (!resultType.isF16() && inputs.size() == 1 && 915 inputs[0].getType().isF64()) 916 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 917 // Allow producing an i32 or i64 from nothing. 918 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 919 inputs.empty()) 920 return builder.create<TestTypeProducerOp>(loc, resultType); 921 // Allow producing an i64 from an integer. 922 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 923 inputs[0].getType().isa<IntegerType>()) 924 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 925 // Otherwise, fail. 926 return nullptr; 927 }); 928 929 // Initialize the conversion target. 930 mlir::ConversionTarget target(getContext()); 931 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 932 return op.getType().isF64() || op.getType().isInteger(64); 933 }); 934 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 935 return converter.isSignatureLegal(op.getType()) && 936 converter.isLegal(&op.getBody()); 937 }); 938 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 939 // Allow casts from F64 to F32. 940 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 941 }); 942 943 // Initialize the set of rewrite patterns. 944 RewritePatternSet patterns(&getContext()); 945 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 946 TestSignatureConversionUndo>(converter, &getContext()); 947 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 948 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 949 950 if (failed(applyPartialConversion(getOperation(), target, 951 std::move(patterns)))) 952 signalPassFailure(); 953 } 954 }; 955 } // end anonymous namespace 956 957 //===----------------------------------------------------------------------===// 958 // Test Block Merging 959 //===----------------------------------------------------------------------===// 960 961 namespace { 962 /// A rewriter pattern that tests that blocks can be merged. 963 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 964 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 965 966 LogicalResult 967 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 968 ConversionPatternRewriter &rewriter) const final { 969 Block &firstBlock = op.body().front(); 970 Operation *branchOp = firstBlock.getTerminator(); 971 Block *secondBlock = &*(std::next(op.body().begin())); 972 auto succOperands = branchOp->getOperands(); 973 SmallVector<Value, 2> replacements(succOperands); 974 rewriter.eraseOp(branchOp); 975 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 976 rewriter.updateRootInPlace(op, [] {}); 977 return success(); 978 } 979 }; 980 981 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 982 struct TestUndoBlocksMerge : public ConversionPattern { 983 TestUndoBlocksMerge(MLIRContext *ctx) 984 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 985 LogicalResult 986 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 987 ConversionPatternRewriter &rewriter) const final { 988 Block &firstBlock = op->getRegion(0).front(); 989 Operation *branchOp = firstBlock.getTerminator(); 990 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 991 rewriter.setInsertionPointToStart(secondBlock); 992 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 993 auto succOperands = branchOp->getOperands(); 994 SmallVector<Value, 2> replacements(succOperands); 995 rewriter.eraseOp(branchOp); 996 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 997 rewriter.updateRootInPlace(op, [] {}); 998 return success(); 999 } 1000 }; 1001 1002 /// A rewrite mechanism to inline the body of the op into its parent, when both 1003 /// ops can have a single block. 1004 struct TestMergeSingleBlockOps 1005 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1006 using OpConversionPattern< 1007 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1008 1009 LogicalResult 1010 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 1011 ConversionPatternRewriter &rewriter) const final { 1012 SingleBlockImplicitTerminatorOp parentOp = 1013 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1014 if (!parentOp) 1015 return failure(); 1016 Block &innerBlock = op.region().front(); 1017 TerminatorOp innerTerminator = 1018 cast<TerminatorOp>(innerBlock.getTerminator()); 1019 rewriter.mergeBlockBefore(&innerBlock, op); 1020 rewriter.eraseOp(innerTerminator); 1021 rewriter.eraseOp(op); 1022 rewriter.updateRootInPlace(op, [] {}); 1023 return success(); 1024 } 1025 }; 1026 1027 struct TestMergeBlocksPatternDriver 1028 : public PassWrapper<TestMergeBlocksPatternDriver, 1029 OperationPass<ModuleOp>> { 1030 StringRef getArgument() const final { return "test-merge-blocks"; } 1031 StringRef getDescription() const final { 1032 return "Test Merging operation in ConversionPatternRewriter"; 1033 } 1034 void runOnOperation() override { 1035 MLIRContext *context = &getContext(); 1036 mlir::RewritePatternSet patterns(context); 1037 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1038 context); 1039 ConversionTarget target(*context); 1040 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1041 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1042 target.addIllegalOp<ILLegalOpF>(); 1043 1044 /// Expect the op to have a single block after legalization. 1045 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1046 [&](TestMergeBlocksOp op) -> bool { 1047 return llvm::hasSingleElement(op.body()); 1048 }); 1049 1050 /// Only allow `test.br` within test.merge_blocks op. 1051 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1052 return op->getParentOfType<TestMergeBlocksOp>(); 1053 }); 1054 1055 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1056 /// inlined. 1057 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1058 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1059 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1060 }); 1061 1062 DenseSet<Operation *> unlegalizedOps; 1063 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1064 &unlegalizedOps); 1065 for (auto *op : unlegalizedOps) 1066 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1067 } 1068 }; 1069 } // namespace 1070 1071 //===----------------------------------------------------------------------===// 1072 // Test Selective Replacement 1073 //===----------------------------------------------------------------------===// 1074 1075 namespace { 1076 /// A rewrite mechanism to inline the body of the op into its parent, when both 1077 /// ops can have a single block. 1078 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1079 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1080 1081 LogicalResult matchAndRewrite(TestCastOp op, 1082 PatternRewriter &rewriter) const final { 1083 if (op.getNumOperands() != 2) 1084 return failure(); 1085 OperandRange operands = op.getOperands(); 1086 1087 // Replace non-terminator uses with the first operand. 1088 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1089 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1090 }); 1091 // Replace everything else with the second operand if the operation isn't 1092 // dead. 1093 rewriter.replaceOp(op, op.getOperand(1)); 1094 return success(); 1095 } 1096 }; 1097 1098 struct TestSelectiveReplacementPatternDriver 1099 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1100 OperationPass<>> { 1101 StringRef getArgument() const final { 1102 return "test-pattern-selective-replacement"; 1103 } 1104 StringRef getDescription() const final { 1105 return "Test selective replacement in the PatternRewriter"; 1106 } 1107 void runOnOperation() override { 1108 MLIRContext *context = &getContext(); 1109 mlir::RewritePatternSet patterns(context); 1110 patterns.add<TestSelectiveOpReplacementPattern>(context); 1111 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1112 std::move(patterns)); 1113 } 1114 }; 1115 } // namespace 1116 1117 //===----------------------------------------------------------------------===// 1118 // PassRegistration 1119 //===----------------------------------------------------------------------===// 1120 1121 namespace mlir { 1122 namespace test { 1123 void registerPatternsTestPass() { 1124 PassRegistration<TestReturnTypeDriver>(); 1125 1126 PassRegistration<TestDerivedAttributeDriver>(); 1127 1128 PassRegistration<TestPatternDriver>(); 1129 1130 PassRegistration<TestLegalizePatternDriver>([] { 1131 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1132 }); 1133 1134 PassRegistration<TestRemappedValue>(); 1135 1136 PassRegistration<TestUnknownRootOpDriver>(); 1137 1138 PassRegistration<TestTypeConversionDriver>(); 1139 1140 PassRegistration<TestMergeBlocksPatternDriver>(); 1141 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1142 } 1143 } // namespace test 1144 } // namespace mlir 1145