1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 // Test that natives calls are only called once during rewrites. 48 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 49 // This let us check the number of times OpM_Test was called by inspecting 50 // the returned value in the MLIR output. 51 static int64_t opMIncreasingValue = 314159265; 52 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 53 int64_t i = opMIncreasingValue++; 54 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 55 } 56 57 namespace { 58 #include "TestPatterns.inc" 59 } // end anonymous namespace 60 61 //===----------------------------------------------------------------------===// 62 // Test Reduce Pattern Interface 63 //===----------------------------------------------------------------------===// 64 65 void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) { 66 populateWithGenerated(patterns); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Canonicalizer Driver. 71 //===----------------------------------------------------------------------===// 72 73 namespace { 74 struct FoldingPattern : public RewritePattern { 75 public: 76 FoldingPattern(MLIRContext *context) 77 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 78 /*benefit=*/1, context) {} 79 80 LogicalResult matchAndRewrite(Operation *op, 81 PatternRewriter &rewriter) const override { 82 // Exercise OperationFolder API for a single-result operation that is folded 83 // upon construction. The operation being created through the folder has an 84 // in-place folder, and it should be still present in the output. 85 // Furthermore, the folder should not crash when attempting to recover the 86 // (unchanged) operation result. 87 OperationFolder folder(op->getContext()); 88 Value result = folder.create<TestOpInPlaceFold>( 89 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 90 rewriter.getI32IntegerAttr(0)); 91 assert(result); 92 rewriter.replaceOp(op, result); 93 return success(); 94 } 95 }; 96 97 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 98 void runOnFunction() override { 99 mlir::RewritePatternSet patterns(&getContext()); 100 populateWithGenerated(patterns); 101 102 // Verify named pattern is generated with expected name. 103 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 104 105 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 106 } 107 }; 108 } // end anonymous namespace 109 110 //===----------------------------------------------------------------------===// 111 // ReturnType Driver. 112 //===----------------------------------------------------------------------===// 113 114 namespace { 115 // Generate ops for each instance where the type can be successfully inferred. 116 template <typename OpTy> 117 static void invokeCreateWithInferredReturnType(Operation *op) { 118 auto *context = op->getContext(); 119 auto fop = op->getParentOfType<FuncOp>(); 120 auto location = UnknownLoc::get(context); 121 OpBuilder b(op); 122 b.setInsertionPointAfter(op); 123 124 // Use permutations of 2 args as operands. 125 assert(fop.getNumArguments() >= 2); 126 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 127 for (int j = 0; j < e; ++j) { 128 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 129 SmallVector<Type, 2> inferredReturnTypes; 130 if (succeeded(OpTy::inferReturnTypes( 131 context, llvm::None, values, op->getAttrDictionary(), 132 op->getRegions(), inferredReturnTypes))) { 133 OperationState state(location, OpTy::getOperationName()); 134 // TODO: Expand to regions. 135 OpTy::build(b, state, values, op->getAttrs()); 136 (void)b.createOperation(state); 137 } 138 } 139 } 140 } 141 142 static void reifyReturnShape(Operation *op) { 143 OpBuilder b(op); 144 145 // Use permutations of 2 args as operands. 146 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 147 SmallVector<Value, 2> shapes; 148 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 149 !llvm::hasSingleElement(shapes)) 150 return; 151 for (auto it : llvm::enumerate(shapes)) { 152 op->emitRemark() << "value " << it.index() << ": " 153 << it.value().getDefiningOp(); 154 } 155 } 156 157 struct TestReturnTypeDriver 158 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 159 void getDependentDialects(DialectRegistry ®istry) const override { 160 registry.insert<memref::MemRefDialect>(); 161 } 162 163 void runOnFunction() override { 164 if (getFunction().getName() == "testCreateFunctions") { 165 std::vector<Operation *> ops; 166 // Collect ops to avoid triggering on inserted ops. 167 for (auto &op : getFunction().getBody().front()) 168 ops.push_back(&op); 169 // Generate test patterns for each, but skip terminator. 170 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 171 // Test create method of each of the Op classes below. The resultant 172 // output would be in reverse order underneath `op` from which 173 // the attributes and regions are used. 174 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 175 invokeCreateWithInferredReturnType< 176 OpWithShapedTypeInferTypeInterfaceOp>(op); 177 }; 178 return; 179 } 180 if (getFunction().getName() == "testReifyFunctions") { 181 std::vector<Operation *> ops; 182 // Collect ops to avoid triggering on inserted ops. 183 for (auto &op : getFunction().getBody().front()) 184 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 185 ops.push_back(&op); 186 // Generate test patterns for each, but skip terminator. 187 for (auto *op : ops) 188 reifyReturnShape(op); 189 } 190 } 191 }; 192 } // end anonymous namespace 193 194 namespace { 195 struct TestDerivedAttributeDriver 196 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 197 void runOnFunction() override; 198 }; 199 } // end anonymous namespace 200 201 void TestDerivedAttributeDriver::runOnFunction() { 202 getFunction().walk([](DerivedAttributeOpInterface dOp) { 203 auto dAttr = dOp.materializeDerivedAttributes(); 204 if (!dAttr) 205 return; 206 for (auto d : dAttr) 207 dOp.emitRemark() << d.first << " = " << d.second; 208 }); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // Legalization Driver. 213 //===----------------------------------------------------------------------===// 214 215 namespace { 216 //===----------------------------------------------------------------------===// 217 // Region-Block Rewrite Testing 218 219 /// This pattern is a simple pattern that inlines the first region of a given 220 /// operation into the parent region. 221 struct TestRegionRewriteBlockMovement : public ConversionPattern { 222 TestRegionRewriteBlockMovement(MLIRContext *ctx) 223 : ConversionPattern("test.region", 1, ctx) {} 224 225 LogicalResult 226 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 227 ConversionPatternRewriter &rewriter) const final { 228 // Inline this region into the parent region. 229 auto &parentRegion = *op->getParentRegion(); 230 auto &opRegion = op->getRegion(0); 231 if (op->getAttr("legalizer.should_clone")) 232 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 233 else 234 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 235 236 if (op->getAttr("legalizer.erase_old_blocks")) { 237 while (!opRegion.empty()) 238 rewriter.eraseBlock(&opRegion.front()); 239 } 240 241 // Drop this operation. 242 rewriter.eraseOp(op); 243 return success(); 244 } 245 }; 246 /// This pattern is a simple pattern that generates a region containing an 247 /// illegal operation. 248 struct TestRegionRewriteUndo : public RewritePattern { 249 TestRegionRewriteUndo(MLIRContext *ctx) 250 : RewritePattern("test.region_builder", 1, ctx) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const final { 254 // Create the region operation with an entry block containing arguments. 255 OperationState newRegion(op->getLoc(), "test.region"); 256 newRegion.addRegion(); 257 auto *regionOp = rewriter.createOperation(newRegion); 258 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 259 entryBlock->addArgument(rewriter.getIntegerType(64)); 260 261 // Add an explicitly illegal operation to ensure the conversion fails. 262 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 263 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 264 265 // Drop this operation. 266 rewriter.eraseOp(op); 267 return success(); 268 } 269 }; 270 /// A simple pattern that creates a block at the end of the parent region of the 271 /// matched operation. 272 struct TestCreateBlock : public RewritePattern { 273 TestCreateBlock(MLIRContext *ctx) 274 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const final { 278 Region ®ion = *op->getParentRegion(); 279 Type i32Type = rewriter.getIntegerType(32); 280 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 281 rewriter.create<TerminatorOp>(op->getLoc()); 282 rewriter.replaceOp(op, {}); 283 return success(); 284 } 285 }; 286 287 /// A simple pattern that creates a block containing an invalid operation in 288 /// order to trigger the block creation undo mechanism. 289 struct TestCreateIllegalBlock : public RewritePattern { 290 TestCreateIllegalBlock(MLIRContext *ctx) 291 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 292 293 LogicalResult matchAndRewrite(Operation *op, 294 PatternRewriter &rewriter) const final { 295 Region ®ion = *op->getParentRegion(); 296 Type i32Type = rewriter.getIntegerType(32); 297 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 298 // Create an illegal op to ensure the conversion fails. 299 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 300 rewriter.create<TerminatorOp>(op->getLoc()); 301 rewriter.replaceOp(op, {}); 302 return success(); 303 } 304 }; 305 306 /// A simple pattern that tests the undo mechanism when replacing the uses of a 307 /// block argument. 308 struct TestUndoBlockArgReplace : public ConversionPattern { 309 TestUndoBlockArgReplace(MLIRContext *ctx) 310 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 311 312 LogicalResult 313 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 314 ConversionPatternRewriter &rewriter) const final { 315 auto illegalOp = 316 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 317 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 318 illegalOp); 319 rewriter.updateRootInPlace(op, [] {}); 320 return success(); 321 } 322 }; 323 324 /// A rewrite pattern that tests the undo mechanism when erasing a block. 325 struct TestUndoBlockErase : public ConversionPattern { 326 TestUndoBlockErase(MLIRContext *ctx) 327 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 328 329 LogicalResult 330 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 331 ConversionPatternRewriter &rewriter) const final { 332 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 333 rewriter.setInsertionPointToStart(secondBlock); 334 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 335 rewriter.eraseBlock(secondBlock); 336 rewriter.updateRootInPlace(op, [] {}); 337 return success(); 338 } 339 }; 340 341 //===----------------------------------------------------------------------===// 342 // Type-Conversion Rewrite Testing 343 344 /// This patterns erases a region operation that has had a type conversion. 345 struct TestDropOpSignatureConversion : public ConversionPattern { 346 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 347 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 348 LogicalResult 349 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 350 ConversionPatternRewriter &rewriter) const override { 351 Region ®ion = op->getRegion(0); 352 Block *entry = ®ion.front(); 353 354 // Convert the original entry arguments. 355 TypeConverter &converter = *getTypeConverter(); 356 TypeConverter::SignatureConversion result(entry->getNumArguments()); 357 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 358 result)) || 359 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 360 return failure(); 361 362 // Convert the region signature and just drop the operation. 363 rewriter.eraseOp(op); 364 return success(); 365 } 366 }; 367 /// This pattern simply updates the operands of the given operation. 368 struct TestPassthroughInvalidOp : public ConversionPattern { 369 TestPassthroughInvalidOp(MLIRContext *ctx) 370 : ConversionPattern("test.invalid", 1, ctx) {} 371 LogicalResult 372 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 373 ConversionPatternRewriter &rewriter) const final { 374 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 375 llvm::None); 376 return success(); 377 } 378 }; 379 /// This pattern handles the case of a split return value. 380 struct TestSplitReturnType : public ConversionPattern { 381 TestSplitReturnType(MLIRContext *ctx) 382 : ConversionPattern("test.return", 1, ctx) {} 383 LogicalResult 384 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 385 ConversionPatternRewriter &rewriter) const final { 386 // Check for a return of F32. 387 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 388 return failure(); 389 390 // Check if the first operation is a cast operation, if it is we use the 391 // results directly. 392 auto *defOp = operands[0].getDefiningOp(); 393 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 394 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 395 return success(); 396 } 397 398 // Otherwise, fail to match. 399 return failure(); 400 } 401 }; 402 403 //===----------------------------------------------------------------------===// 404 // Multi-Level Type-Conversion Rewrite Testing 405 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 406 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 407 : ConversionPattern("test.type_producer", 1, ctx) {} 408 LogicalResult 409 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 410 ConversionPatternRewriter &rewriter) const final { 411 // If the type is I32, change the type to F32. 412 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 413 return failure(); 414 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 415 return success(); 416 } 417 }; 418 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 419 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 420 : ConversionPattern("test.type_producer", 1, ctx) {} 421 LogicalResult 422 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423 ConversionPatternRewriter &rewriter) const final { 424 // If the type is F32, change the type to F64. 425 if (!Type(*op->result_type_begin()).isF32()) 426 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 427 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 428 return success(); 429 } 430 }; 431 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 432 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 433 : ConversionPattern("test.type_producer", 10, ctx) {} 434 LogicalResult 435 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 436 ConversionPatternRewriter &rewriter) const final { 437 // Always convert to B16, even though it is not a legal type. This tests 438 // that values are unmapped correctly. 439 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 440 return success(); 441 } 442 }; 443 struct TestUpdateConsumerType : public ConversionPattern { 444 TestUpdateConsumerType(MLIRContext *ctx) 445 : ConversionPattern("test.type_consumer", 1, ctx) {} 446 LogicalResult 447 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 448 ConversionPatternRewriter &rewriter) const final { 449 // Verify that the incoming operand has been successfully remapped to F64. 450 if (!operands[0].getType().isF64()) 451 return failure(); 452 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 453 return success(); 454 } 455 }; 456 457 //===----------------------------------------------------------------------===// 458 // Non-Root Replacement Rewrite Testing 459 /// This pattern generates an invalid operation, but replaces it before the 460 /// pattern is finished. This checks that we don't need to legalize the 461 /// temporary op. 462 struct TestNonRootReplacement : public RewritePattern { 463 TestNonRootReplacement(MLIRContext *ctx) 464 : RewritePattern("test.replace_non_root", 1, ctx) {} 465 466 LogicalResult matchAndRewrite(Operation *op, 467 PatternRewriter &rewriter) const final { 468 auto resultType = *op->result_type_begin(); 469 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 470 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 471 472 rewriter.replaceOp(illegalOp, {legalOp}); 473 rewriter.replaceOp(op, {illegalOp}); 474 return success(); 475 } 476 }; 477 478 //===----------------------------------------------------------------------===// 479 // Recursive Rewrite Testing 480 /// This pattern is applied to the same operation multiple times, but has a 481 /// bounded recursion. 482 struct TestBoundedRecursiveRewrite 483 : public OpRewritePattern<TestRecursiveRewriteOp> { 484 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 485 486 void initialize() { 487 // The conversion target handles bounding the recursion of this pattern. 488 setHasBoundedRewriteRecursion(); 489 } 490 491 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 492 PatternRewriter &rewriter) const final { 493 // Decrement the depth of the op in-place. 494 rewriter.updateRootInPlace(op, [&] { 495 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 496 }); 497 return success(); 498 } 499 }; 500 501 struct TestNestedOpCreationUndoRewrite 502 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 503 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 504 505 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 506 PatternRewriter &rewriter) const final { 507 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 508 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 509 return success(); 510 }; 511 }; 512 513 // This pattern matches `test.blackhole` and delete this op and its producer. 514 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 515 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 516 517 LogicalResult matchAndRewrite(BlackHoleOp op, 518 PatternRewriter &rewriter) const final { 519 Operation *producer = op.getOperand().getDefiningOp(); 520 // Always erase the user before the producer, the framework should handle 521 // this correctly. 522 rewriter.eraseOp(op); 523 rewriter.eraseOp(producer); 524 return success(); 525 }; 526 }; 527 } // namespace 528 529 namespace { 530 struct TestTypeConverter : public TypeConverter { 531 using TypeConverter::TypeConverter; 532 TestTypeConverter() { 533 addConversion(convertType); 534 addArgumentMaterialization(materializeCast); 535 addSourceMaterialization(materializeCast); 536 537 /// Materialize the cast for one-to-one conversion from i64 to f64. 538 const auto materializeOneToOneCast = 539 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 540 Location loc) -> Optional<Value> { 541 if (resultType.getWidth() == 42 && inputs.size() == 1) 542 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 543 return llvm::None; 544 }; 545 addArgumentMaterialization(materializeOneToOneCast); 546 } 547 548 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 549 // Drop I16 types. 550 if (t.isSignlessInteger(16)) 551 return success(); 552 553 // Convert I64 to F64. 554 if (t.isSignlessInteger(64)) { 555 results.push_back(FloatType::getF64(t.getContext())); 556 return success(); 557 } 558 559 // Convert I42 to I43. 560 if (t.isInteger(42)) { 561 results.push_back(IntegerType::get(t.getContext(), 43)); 562 return success(); 563 } 564 565 // Split F32 into F16,F16. 566 if (t.isF32()) { 567 results.assign(2, FloatType::getF16(t.getContext())); 568 return success(); 569 } 570 571 // Otherwise, convert the type directly. 572 results.push_back(t); 573 return success(); 574 } 575 576 /// Hook for materializing a conversion. This is necessary because we generate 577 /// 1->N type mappings. 578 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 579 ValueRange inputs, Location loc) { 580 if (inputs.size() == 1) 581 return inputs[0]; 582 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 583 } 584 }; 585 586 struct TestLegalizePatternDriver 587 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 588 /// The mode of conversion to use with the driver. 589 enum class ConversionMode { Analysis, Full, Partial }; 590 591 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 592 593 void runOnOperation() override { 594 TestTypeConverter converter; 595 mlir::RewritePatternSet patterns(&getContext()); 596 populateWithGenerated(patterns); 597 patterns 598 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 599 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 600 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 601 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 602 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 603 TestNonRootReplacement, TestBoundedRecursiveRewrite, 604 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 605 &getContext()); 606 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 607 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 608 mlir::populateCallOpTypeConversionPattern(patterns, converter); 609 610 // Define the conversion target used for the test. 611 ConversionTarget target(getContext()); 612 target.addLegalOp<ModuleOp>(); 613 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 614 TerminatorOp>(); 615 target 616 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 617 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 618 // Don't allow F32 operands. 619 return llvm::none_of(op.getOperandTypes(), 620 [](Type type) { return type.isF32(); }); 621 }); 622 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 623 return converter.isSignatureLegal(op.getType()) && 624 converter.isLegal(&op.getBody()); 625 }); 626 627 // Expect the type_producer/type_consumer operations to only operate on f64. 628 target.addDynamicallyLegalOp<TestTypeProducerOp>( 629 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 630 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 631 return op.getOperand().getType().isF64(); 632 }); 633 634 // Check support for marking certain operations as recursively legal. 635 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 636 return static_cast<bool>( 637 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 638 }); 639 640 // Mark the bound recursion operation as dynamically legal. 641 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 642 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 643 644 // Handle a partial conversion. 645 if (mode == ConversionMode::Partial) { 646 DenseSet<Operation *> unlegalizedOps; 647 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 648 &unlegalizedOps); 649 // Emit remarks for each legalizable operation. 650 for (auto *op : unlegalizedOps) 651 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 652 return; 653 } 654 655 // Handle a full conversion. 656 if (mode == ConversionMode::Full) { 657 // Check support for marking unknown operations as dynamically legal. 658 target.markUnknownOpDynamicallyLegal([](Operation *op) { 659 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 660 }); 661 662 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 663 return; 664 } 665 666 // Otherwise, handle an analysis conversion. 667 assert(mode == ConversionMode::Analysis); 668 669 // Analyze the convertible operations. 670 DenseSet<Operation *> legalizedOps; 671 if (failed(applyAnalysisConversion(getOperation(), target, 672 std::move(patterns), legalizedOps))) 673 return signalPassFailure(); 674 675 // Emit remarks for each legalizable operation. 676 for (auto *op : legalizedOps) 677 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 678 } 679 680 /// The mode of conversion to use. 681 ConversionMode mode; 682 }; 683 } // end anonymous namespace 684 685 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 686 legalizerConversionMode( 687 "test-legalize-mode", 688 llvm::cl::desc("The legalization mode to use with the test driver"), 689 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 690 llvm::cl::values( 691 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 692 "analysis", "Perform an analysis conversion"), 693 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 694 "Perform a full conversion"), 695 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 696 "partial", "Perform a partial conversion"))); 697 698 //===----------------------------------------------------------------------===// 699 // ConversionPatternRewriter::getRemappedValue testing. This method is used 700 // to get the remapped value of an original value that was replaced using 701 // ConversionPatternRewriter. 702 namespace { 703 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 704 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 705 /// operand twice. 706 /// 707 /// Example: 708 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 709 /// is replaced with: 710 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 711 struct OneVResOneVOperandOp1Converter 712 : public OpConversionPattern<OneVResOneVOperandOp1> { 713 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 714 715 LogicalResult 716 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 717 ConversionPatternRewriter &rewriter) const override { 718 auto origOps = op.getOperands(); 719 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 720 "One operand expected"); 721 Value origOp = *origOps.begin(); 722 SmallVector<Value, 2> remappedOperands; 723 // Replicate the remapped original operand twice. Note that we don't used 724 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 725 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 726 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 727 728 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 729 remappedOperands); 730 return success(); 731 } 732 }; 733 734 struct TestRemappedValue 735 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 736 void runOnFunction() override { 737 mlir::RewritePatternSet patterns(&getContext()); 738 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 739 740 mlir::ConversionTarget target(getContext()); 741 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 742 // We make OneVResOneVOperandOp1 legal only when it has more that one 743 // operand. This will trigger the conversion that will replace one-operand 744 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 745 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 746 [](Operation *op) -> bool { 747 return std::distance(op->operand_begin(), op->operand_end()) > 1; 748 }); 749 750 if (failed(mlir::applyFullConversion(getFunction(), target, 751 std::move(patterns)))) { 752 signalPassFailure(); 753 } 754 } 755 }; 756 } // end anonymous namespace 757 758 //===----------------------------------------------------------------------===// 759 // Test patterns without a specific root operation kind 760 //===----------------------------------------------------------------------===// 761 762 namespace { 763 /// This pattern matches and removes any operation in the test dialect. 764 struct RemoveTestDialectOps : public RewritePattern { 765 RemoveTestDialectOps(MLIRContext *context) 766 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 767 768 LogicalResult matchAndRewrite(Operation *op, 769 PatternRewriter &rewriter) const override { 770 if (!isa<TestDialect>(op->getDialect())) 771 return failure(); 772 rewriter.eraseOp(op); 773 return success(); 774 } 775 }; 776 777 struct TestUnknownRootOpDriver 778 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 779 void runOnFunction() override { 780 mlir::RewritePatternSet patterns(&getContext()); 781 patterns.add<RemoveTestDialectOps>(&getContext()); 782 783 mlir::ConversionTarget target(getContext()); 784 target.addIllegalDialect<TestDialect>(); 785 if (failed( 786 applyPartialConversion(getFunction(), target, std::move(patterns)))) 787 signalPassFailure(); 788 } 789 }; 790 } // end anonymous namespace 791 792 //===----------------------------------------------------------------------===// 793 // Test type conversions 794 //===----------------------------------------------------------------------===// 795 796 namespace { 797 struct TestTypeConversionProducer 798 : public OpConversionPattern<TestTypeProducerOp> { 799 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 800 LogicalResult 801 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 802 ConversionPatternRewriter &rewriter) const final { 803 Type resultType = op.getType(); 804 if (resultType.isa<FloatType>()) 805 resultType = rewriter.getF64Type(); 806 else if (resultType.isInteger(16)) 807 resultType = rewriter.getIntegerType(64); 808 else 809 return failure(); 810 811 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 812 return success(); 813 } 814 }; 815 816 /// Call signature conversion and then fail the rewrite to trigger the undo 817 /// mechanism. 818 struct TestSignatureConversionUndo 819 : public OpConversionPattern<TestSignatureConversionUndoOp> { 820 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 821 822 LogicalResult 823 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 824 ConversionPatternRewriter &rewriter) const final { 825 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 826 return failure(); 827 } 828 }; 829 830 /// Just forward the operands to the root op. This is essentially a no-op 831 /// pattern that is used to trigger target materialization. 832 struct TestTypeConsumerForward 833 : public OpConversionPattern<TestTypeConsumerOp> { 834 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 835 836 LogicalResult 837 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 838 ConversionPatternRewriter &rewriter) const final { 839 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 840 return success(); 841 } 842 }; 843 844 struct TestTypeConversionAnotherProducer 845 : public OpRewritePattern<TestAnotherTypeProducerOp> { 846 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 847 848 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 849 PatternRewriter &rewriter) const final { 850 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 851 return success(); 852 } 853 }; 854 855 struct TestTypeConversionDriver 856 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 857 void getDependentDialects(DialectRegistry ®istry) const override { 858 registry.insert<TestDialect>(); 859 } 860 861 void runOnOperation() override { 862 // Initialize the type converter. 863 TypeConverter converter; 864 865 /// Add the legal set of type conversions. 866 converter.addConversion([](Type type) -> Type { 867 // Treat F64 as legal. 868 if (type.isF64()) 869 return type; 870 // Allow converting BF16/F16/F32 to F64. 871 if (type.isBF16() || type.isF16() || type.isF32()) 872 return FloatType::getF64(type.getContext()); 873 // Otherwise, the type is illegal. 874 return nullptr; 875 }); 876 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 877 // Drop all integer types. 878 return success(); 879 }); 880 881 /// Add the legal set of type materializations. 882 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 883 ValueRange inputs, 884 Location loc) -> Value { 885 // Allow casting from F64 back to F32. 886 if (!resultType.isF16() && inputs.size() == 1 && 887 inputs[0].getType().isF64()) 888 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 889 // Allow producing an i32 or i64 from nothing. 890 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 891 inputs.empty()) 892 return builder.create<TestTypeProducerOp>(loc, resultType); 893 // Allow producing an i64 from an integer. 894 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 895 inputs[0].getType().isa<IntegerType>()) 896 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 897 // Otherwise, fail. 898 return nullptr; 899 }); 900 901 // Initialize the conversion target. 902 mlir::ConversionTarget target(getContext()); 903 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 904 return op.getType().isF64() || op.getType().isInteger(64); 905 }); 906 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 907 return converter.isSignatureLegal(op.getType()) && 908 converter.isLegal(&op.getBody()); 909 }); 910 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 911 // Allow casts from F64 to F32. 912 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 913 }); 914 915 // Initialize the set of rewrite patterns. 916 RewritePatternSet patterns(&getContext()); 917 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 918 TestSignatureConversionUndo>(converter, &getContext()); 919 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 920 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 921 922 if (failed(applyPartialConversion(getOperation(), target, 923 std::move(patterns)))) 924 signalPassFailure(); 925 } 926 }; 927 } // end anonymous namespace 928 929 //===----------------------------------------------------------------------===// 930 // Test Block Merging 931 //===----------------------------------------------------------------------===// 932 933 namespace { 934 /// A rewriter pattern that tests that blocks can be merged. 935 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 936 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 937 938 LogicalResult 939 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 940 ConversionPatternRewriter &rewriter) const final { 941 Block &firstBlock = op.body().front(); 942 Operation *branchOp = firstBlock.getTerminator(); 943 Block *secondBlock = &*(std::next(op.body().begin())); 944 auto succOperands = branchOp->getOperands(); 945 SmallVector<Value, 2> replacements(succOperands); 946 rewriter.eraseOp(branchOp); 947 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 948 rewriter.updateRootInPlace(op, [] {}); 949 return success(); 950 } 951 }; 952 953 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 954 struct TestUndoBlocksMerge : public ConversionPattern { 955 TestUndoBlocksMerge(MLIRContext *ctx) 956 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 957 LogicalResult 958 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 959 ConversionPatternRewriter &rewriter) const final { 960 Block &firstBlock = op->getRegion(0).front(); 961 Operation *branchOp = firstBlock.getTerminator(); 962 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 963 rewriter.setInsertionPointToStart(secondBlock); 964 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 965 auto succOperands = branchOp->getOperands(); 966 SmallVector<Value, 2> replacements(succOperands); 967 rewriter.eraseOp(branchOp); 968 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 969 rewriter.updateRootInPlace(op, [] {}); 970 return success(); 971 } 972 }; 973 974 /// A rewrite mechanism to inline the body of the op into its parent, when both 975 /// ops can have a single block. 976 struct TestMergeSingleBlockOps 977 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 978 using OpConversionPattern< 979 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 980 981 LogicalResult 982 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 983 ConversionPatternRewriter &rewriter) const final { 984 SingleBlockImplicitTerminatorOp parentOp = 985 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 986 if (!parentOp) 987 return failure(); 988 Block &innerBlock = op.region().front(); 989 TerminatorOp innerTerminator = 990 cast<TerminatorOp>(innerBlock.getTerminator()); 991 rewriter.mergeBlockBefore(&innerBlock, op); 992 rewriter.eraseOp(innerTerminator); 993 rewriter.eraseOp(op); 994 rewriter.updateRootInPlace(op, [] {}); 995 return success(); 996 } 997 }; 998 999 struct TestMergeBlocksPatternDriver 1000 : public PassWrapper<TestMergeBlocksPatternDriver, 1001 OperationPass<ModuleOp>> { 1002 void runOnOperation() override { 1003 MLIRContext *context = &getContext(); 1004 mlir::RewritePatternSet patterns(context); 1005 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1006 context); 1007 ConversionTarget target(*context); 1008 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1009 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1010 target.addIllegalOp<ILLegalOpF>(); 1011 1012 /// Expect the op to have a single block after legalization. 1013 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1014 [&](TestMergeBlocksOp op) -> bool { 1015 return llvm::hasSingleElement(op.body()); 1016 }); 1017 1018 /// Only allow `test.br` within test.merge_blocks op. 1019 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1020 return op->getParentOfType<TestMergeBlocksOp>(); 1021 }); 1022 1023 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1024 /// inlined. 1025 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1026 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1027 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1028 }); 1029 1030 DenseSet<Operation *> unlegalizedOps; 1031 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1032 &unlegalizedOps); 1033 for (auto *op : unlegalizedOps) 1034 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1035 } 1036 }; 1037 } // namespace 1038 1039 //===----------------------------------------------------------------------===// 1040 // Test Selective Replacement 1041 //===----------------------------------------------------------------------===// 1042 1043 namespace { 1044 /// A rewrite mechanism to inline the body of the op into its parent, when both 1045 /// ops can have a single block. 1046 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1047 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1048 1049 LogicalResult matchAndRewrite(TestCastOp op, 1050 PatternRewriter &rewriter) const final { 1051 if (op.getNumOperands() != 2) 1052 return failure(); 1053 OperandRange operands = op.getOperands(); 1054 1055 // Replace non-terminator uses with the first operand. 1056 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1057 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1058 }); 1059 // Replace everything else with the second operand if the operation isn't 1060 // dead. 1061 rewriter.replaceOp(op, op.getOperand(1)); 1062 return success(); 1063 } 1064 }; 1065 1066 struct TestSelectiveReplacementPatternDriver 1067 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1068 OperationPass<>> { 1069 void runOnOperation() override { 1070 MLIRContext *context = &getContext(); 1071 mlir::RewritePatternSet patterns(context); 1072 patterns.add<TestSelectiveOpReplacementPattern>(context); 1073 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1074 std::move(patterns)); 1075 } 1076 }; 1077 } // namespace 1078 1079 //===----------------------------------------------------------------------===// 1080 // PassRegistration 1081 //===----------------------------------------------------------------------===// 1082 1083 namespace mlir { 1084 namespace test { 1085 void registerPatternsTestPass() { 1086 PassRegistration<TestReturnTypeDriver>("test-return-type", 1087 "Run return type functions"); 1088 1089 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1090 "Run test derived attributes"); 1091 1092 PassRegistration<TestPatternDriver>("test-patterns", 1093 "Run test dialect patterns"); 1094 1095 PassRegistration<TestLegalizePatternDriver>( 1096 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1097 return std::make_unique<TestLegalizePatternDriver>( 1098 legalizerConversionMode); 1099 }); 1100 1101 PassRegistration<TestRemappedValue>( 1102 "test-remapped-value", 1103 "Test public remapped value mechanism in ConversionPatternRewriter"); 1104 1105 PassRegistration<TestUnknownRootOpDriver>( 1106 "test-legalize-unknown-root-patterns", 1107 "Test public remapped value mechanism in ConversionPatternRewriter"); 1108 1109 PassRegistration<TestTypeConversionDriver>( 1110 "test-legalize-type-conversion", 1111 "Test various type conversion functionalities in DialectConversion"); 1112 1113 PassRegistration<TestMergeBlocksPatternDriver>{ 1114 "test-merge-blocks", 1115 "Test Merging operation in ConversionPatternRewriter"}; 1116 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1117 "test-pattern-selective-replacement", 1118 "Test selective replacement in the PatternRewriter"}; 1119 } 1120 } // namespace test 1121 } // namespace mlir 1122