1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 // Test that natives calls are only called once during rewrites. 48 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 49 // This let us check the number of times OpM_Test was called by inspecting 50 // the returned value in the MLIR output. 51 static int64_t opMIncreasingValue = 314159265; 52 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 53 int64_t i = opMIncreasingValue++; 54 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 55 } 56 57 namespace { 58 #include "TestPatterns.inc" 59 } // end anonymous namespace 60 61 //===----------------------------------------------------------------------===// 62 // Canonicalizer Driver. 63 //===----------------------------------------------------------------------===// 64 65 namespace { 66 struct FoldingPattern : public RewritePattern { 67 public: 68 FoldingPattern(MLIRContext *context) 69 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 70 /*benefit=*/1, context) {} 71 72 LogicalResult matchAndRewrite(Operation *op, 73 PatternRewriter &rewriter) const override { 74 // Exercise OperationFolder API for a single-result operation that is folded 75 // upon construction. The operation being created through the folder has an 76 // in-place folder, and it should be still present in the output. 77 // Furthermore, the folder should not crash when attempting to recover the 78 // (unchanged) operation result. 79 OperationFolder folder(op->getContext()); 80 Value result = folder.create<TestOpInPlaceFold>( 81 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 82 rewriter.getI32IntegerAttr(0)); 83 assert(result); 84 rewriter.replaceOp(op, result); 85 return success(); 86 } 87 }; 88 89 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 90 void runOnFunction() override { 91 mlir::RewritePatternSet patterns(&getContext()); 92 populateWithGenerated(patterns); 93 94 // Verify named pattern is generated with expected name. 95 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 96 97 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 98 } 99 }; 100 } // end anonymous namespace 101 102 //===----------------------------------------------------------------------===// 103 // ReturnType Driver. 104 //===----------------------------------------------------------------------===// 105 106 namespace { 107 // Generate ops for each instance where the type can be successfully inferred. 108 template <typename OpTy> 109 static void invokeCreateWithInferredReturnType(Operation *op) { 110 auto *context = op->getContext(); 111 auto fop = op->getParentOfType<FuncOp>(); 112 auto location = UnknownLoc::get(context); 113 OpBuilder b(op); 114 b.setInsertionPointAfter(op); 115 116 // Use permutations of 2 args as operands. 117 assert(fop.getNumArguments() >= 2); 118 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 119 for (int j = 0; j < e; ++j) { 120 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 121 SmallVector<Type, 2> inferredReturnTypes; 122 if (succeeded(OpTy::inferReturnTypes( 123 context, llvm::None, values, op->getAttrDictionary(), 124 op->getRegions(), inferredReturnTypes))) { 125 OperationState state(location, OpTy::getOperationName()); 126 // TODO: Expand to regions. 127 OpTy::build(b, state, values, op->getAttrs()); 128 (void)b.createOperation(state); 129 } 130 } 131 } 132 } 133 134 static void reifyReturnShape(Operation *op) { 135 OpBuilder b(op); 136 137 // Use permutations of 2 args as operands. 138 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 139 SmallVector<Value, 2> shapes; 140 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 141 !llvm::hasSingleElement(shapes)) 142 return; 143 for (auto it : llvm::enumerate(shapes)) { 144 op->emitRemark() << "value " << it.index() << ": " 145 << it.value().getDefiningOp(); 146 } 147 } 148 149 struct TestReturnTypeDriver 150 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 151 void getDependentDialects(DialectRegistry ®istry) const override { 152 registry.insert<memref::MemRefDialect>(); 153 } 154 155 void runOnFunction() override { 156 if (getFunction().getName() == "testCreateFunctions") { 157 std::vector<Operation *> ops; 158 // Collect ops to avoid triggering on inserted ops. 159 for (auto &op : getFunction().getBody().front()) 160 ops.push_back(&op); 161 // Generate test patterns for each, but skip terminator. 162 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 163 // Test create method of each of the Op classes below. The resultant 164 // output would be in reverse order underneath `op` from which 165 // the attributes and regions are used. 166 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 167 invokeCreateWithInferredReturnType< 168 OpWithShapedTypeInferTypeInterfaceOp>(op); 169 }; 170 return; 171 } 172 if (getFunction().getName() == "testReifyFunctions") { 173 std::vector<Operation *> ops; 174 // Collect ops to avoid triggering on inserted ops. 175 for (auto &op : getFunction().getBody().front()) 176 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 177 ops.push_back(&op); 178 // Generate test patterns for each, but skip terminator. 179 for (auto *op : ops) 180 reifyReturnShape(op); 181 } 182 } 183 }; 184 } // end anonymous namespace 185 186 namespace { 187 struct TestDerivedAttributeDriver 188 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 189 void runOnFunction() override; 190 }; 191 } // end anonymous namespace 192 193 void TestDerivedAttributeDriver::runOnFunction() { 194 getFunction().walk([](DerivedAttributeOpInterface dOp) { 195 auto dAttr = dOp.materializeDerivedAttributes(); 196 if (!dAttr) 197 return; 198 for (auto d : dAttr) 199 dOp.emitRemark() << d.first << " = " << d.second; 200 }); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // Legalization Driver. 205 //===----------------------------------------------------------------------===// 206 207 namespace { 208 //===----------------------------------------------------------------------===// 209 // Region-Block Rewrite Testing 210 211 /// This pattern is a simple pattern that inlines the first region of a given 212 /// operation into the parent region. 213 struct TestRegionRewriteBlockMovement : public ConversionPattern { 214 TestRegionRewriteBlockMovement(MLIRContext *ctx) 215 : ConversionPattern("test.region", 1, ctx) {} 216 217 LogicalResult 218 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 219 ConversionPatternRewriter &rewriter) const final { 220 // Inline this region into the parent region. 221 auto &parentRegion = *op->getParentRegion(); 222 auto &opRegion = op->getRegion(0); 223 if (op->getAttr("legalizer.should_clone")) 224 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 225 else 226 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 227 228 if (op->getAttr("legalizer.erase_old_blocks")) { 229 while (!opRegion.empty()) 230 rewriter.eraseBlock(&opRegion.front()); 231 } 232 233 // Drop this operation. 234 rewriter.eraseOp(op); 235 return success(); 236 } 237 }; 238 /// This pattern is a simple pattern that generates a region containing an 239 /// illegal operation. 240 struct TestRegionRewriteUndo : public RewritePattern { 241 TestRegionRewriteUndo(MLIRContext *ctx) 242 : RewritePattern("test.region_builder", 1, ctx) {} 243 244 LogicalResult matchAndRewrite(Operation *op, 245 PatternRewriter &rewriter) const final { 246 // Create the region operation with an entry block containing arguments. 247 OperationState newRegion(op->getLoc(), "test.region"); 248 newRegion.addRegion(); 249 auto *regionOp = rewriter.createOperation(newRegion); 250 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 251 entryBlock->addArgument(rewriter.getIntegerType(64)); 252 253 // Add an explicitly illegal operation to ensure the conversion fails. 254 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 255 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 256 257 // Drop this operation. 258 rewriter.eraseOp(op); 259 return success(); 260 } 261 }; 262 /// A simple pattern that creates a block at the end of the parent region of the 263 /// matched operation. 264 struct TestCreateBlock : public RewritePattern { 265 TestCreateBlock(MLIRContext *ctx) 266 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 267 268 LogicalResult matchAndRewrite(Operation *op, 269 PatternRewriter &rewriter) const final { 270 Region ®ion = *op->getParentRegion(); 271 Type i32Type = rewriter.getIntegerType(32); 272 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 273 rewriter.create<TerminatorOp>(op->getLoc()); 274 rewriter.replaceOp(op, {}); 275 return success(); 276 } 277 }; 278 279 /// A simple pattern that creates a block containing an invalid operation in 280 /// order to trigger the block creation undo mechanism. 281 struct TestCreateIllegalBlock : public RewritePattern { 282 TestCreateIllegalBlock(MLIRContext *ctx) 283 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 284 285 LogicalResult matchAndRewrite(Operation *op, 286 PatternRewriter &rewriter) const final { 287 Region ®ion = *op->getParentRegion(); 288 Type i32Type = rewriter.getIntegerType(32); 289 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 290 // Create an illegal op to ensure the conversion fails. 291 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 292 rewriter.create<TerminatorOp>(op->getLoc()); 293 rewriter.replaceOp(op, {}); 294 return success(); 295 } 296 }; 297 298 /// A simple pattern that tests the undo mechanism when replacing the uses of a 299 /// block argument. 300 struct TestUndoBlockArgReplace : public ConversionPattern { 301 TestUndoBlockArgReplace(MLIRContext *ctx) 302 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 303 304 LogicalResult 305 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 306 ConversionPatternRewriter &rewriter) const final { 307 auto illegalOp = 308 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 309 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 310 illegalOp); 311 rewriter.updateRootInPlace(op, [] {}); 312 return success(); 313 } 314 }; 315 316 /// A rewrite pattern that tests the undo mechanism when erasing a block. 317 struct TestUndoBlockErase : public ConversionPattern { 318 TestUndoBlockErase(MLIRContext *ctx) 319 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 320 321 LogicalResult 322 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 323 ConversionPatternRewriter &rewriter) const final { 324 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 325 rewriter.setInsertionPointToStart(secondBlock); 326 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 327 rewriter.eraseBlock(secondBlock); 328 rewriter.updateRootInPlace(op, [] {}); 329 return success(); 330 } 331 }; 332 333 //===----------------------------------------------------------------------===// 334 // Type-Conversion Rewrite Testing 335 336 /// This patterns erases a region operation that has had a type conversion. 337 struct TestDropOpSignatureConversion : public ConversionPattern { 338 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 339 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 340 LogicalResult 341 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 342 ConversionPatternRewriter &rewriter) const override { 343 Region ®ion = op->getRegion(0); 344 Block *entry = ®ion.front(); 345 346 // Convert the original entry arguments. 347 TypeConverter &converter = *getTypeConverter(); 348 TypeConverter::SignatureConversion result(entry->getNumArguments()); 349 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 350 result)) || 351 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 352 return failure(); 353 354 // Convert the region signature and just drop the operation. 355 rewriter.eraseOp(op); 356 return success(); 357 } 358 }; 359 /// This pattern simply updates the operands of the given operation. 360 struct TestPassthroughInvalidOp : public ConversionPattern { 361 TestPassthroughInvalidOp(MLIRContext *ctx) 362 : ConversionPattern("test.invalid", 1, ctx) {} 363 LogicalResult 364 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 365 ConversionPatternRewriter &rewriter) const final { 366 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 367 llvm::None); 368 return success(); 369 } 370 }; 371 /// This pattern handles the case of a split return value. 372 struct TestSplitReturnType : public ConversionPattern { 373 TestSplitReturnType(MLIRContext *ctx) 374 : ConversionPattern("test.return", 1, ctx) {} 375 LogicalResult 376 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 377 ConversionPatternRewriter &rewriter) const final { 378 // Check for a return of F32. 379 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 380 return failure(); 381 382 // Check if the first operation is a cast operation, if it is we use the 383 // results directly. 384 auto *defOp = operands[0].getDefiningOp(); 385 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 386 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 387 return success(); 388 } 389 390 // Otherwise, fail to match. 391 return failure(); 392 } 393 }; 394 395 //===----------------------------------------------------------------------===// 396 // Multi-Level Type-Conversion Rewrite Testing 397 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 398 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 399 : ConversionPattern("test.type_producer", 1, ctx) {} 400 LogicalResult 401 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 402 ConversionPatternRewriter &rewriter) const final { 403 // If the type is I32, change the type to F32. 404 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 405 return failure(); 406 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 407 return success(); 408 } 409 }; 410 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 411 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 412 : ConversionPattern("test.type_producer", 1, ctx) {} 413 LogicalResult 414 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 415 ConversionPatternRewriter &rewriter) const final { 416 // If the type is F32, change the type to F64. 417 if (!Type(*op->result_type_begin()).isF32()) 418 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 419 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 420 return success(); 421 } 422 }; 423 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 424 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 425 : ConversionPattern("test.type_producer", 10, ctx) {} 426 LogicalResult 427 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 428 ConversionPatternRewriter &rewriter) const final { 429 // Always convert to B16, even though it is not a legal type. This tests 430 // that values are unmapped correctly. 431 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 432 return success(); 433 } 434 }; 435 struct TestUpdateConsumerType : public ConversionPattern { 436 TestUpdateConsumerType(MLIRContext *ctx) 437 : ConversionPattern("test.type_consumer", 1, ctx) {} 438 LogicalResult 439 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 440 ConversionPatternRewriter &rewriter) const final { 441 // Verify that the incoming operand has been successfully remapped to F64. 442 if (!operands[0].getType().isF64()) 443 return failure(); 444 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 445 return success(); 446 } 447 }; 448 449 //===----------------------------------------------------------------------===// 450 // Non-Root Replacement Rewrite Testing 451 /// This pattern generates an invalid operation, but replaces it before the 452 /// pattern is finished. This checks that we don't need to legalize the 453 /// temporary op. 454 struct TestNonRootReplacement : public RewritePattern { 455 TestNonRootReplacement(MLIRContext *ctx) 456 : RewritePattern("test.replace_non_root", 1, ctx) {} 457 458 LogicalResult matchAndRewrite(Operation *op, 459 PatternRewriter &rewriter) const final { 460 auto resultType = *op->result_type_begin(); 461 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 462 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 463 464 rewriter.replaceOp(illegalOp, {legalOp}); 465 rewriter.replaceOp(op, {illegalOp}); 466 return success(); 467 } 468 }; 469 470 //===----------------------------------------------------------------------===// 471 // Recursive Rewrite Testing 472 /// This pattern is applied to the same operation multiple times, but has a 473 /// bounded recursion. 474 struct TestBoundedRecursiveRewrite 475 : public OpRewritePattern<TestRecursiveRewriteOp> { 476 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 477 478 void initialize() { 479 // The conversion target handles bounding the recursion of this pattern. 480 setHasBoundedRewriteRecursion(); 481 } 482 483 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 484 PatternRewriter &rewriter) const final { 485 // Decrement the depth of the op in-place. 486 rewriter.updateRootInPlace(op, [&] { 487 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 488 }); 489 return success(); 490 } 491 }; 492 493 struct TestNestedOpCreationUndoRewrite 494 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 495 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 496 497 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 498 PatternRewriter &rewriter) const final { 499 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 500 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 501 return success(); 502 }; 503 }; 504 505 // This pattern matches `test.blackhole` and delete this op and its producer. 506 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 507 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 508 509 LogicalResult matchAndRewrite(BlackHoleOp op, 510 PatternRewriter &rewriter) const final { 511 Operation *producer = op.getOperand().getDefiningOp(); 512 // Always erase the user before the producer, the framework should handle 513 // this correctly. 514 rewriter.eraseOp(op); 515 rewriter.eraseOp(producer); 516 return success(); 517 }; 518 }; 519 } // namespace 520 521 namespace { 522 struct TestTypeConverter : public TypeConverter { 523 using TypeConverter::TypeConverter; 524 TestTypeConverter() { 525 addConversion(convertType); 526 addArgumentMaterialization(materializeCast); 527 addSourceMaterialization(materializeCast); 528 529 /// Materialize the cast for one-to-one conversion from i64 to f64. 530 const auto materializeOneToOneCast = 531 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 532 Location loc) -> Optional<Value> { 533 if (resultType.getWidth() == 42 && inputs.size() == 1) 534 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 535 return llvm::None; 536 }; 537 addArgumentMaterialization(materializeOneToOneCast); 538 } 539 540 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 541 // Drop I16 types. 542 if (t.isSignlessInteger(16)) 543 return success(); 544 545 // Convert I64 to F64. 546 if (t.isSignlessInteger(64)) { 547 results.push_back(FloatType::getF64(t.getContext())); 548 return success(); 549 } 550 551 // Convert I42 to I43. 552 if (t.isInteger(42)) { 553 results.push_back(IntegerType::get(t.getContext(), 43)); 554 return success(); 555 } 556 557 // Split F32 into F16,F16. 558 if (t.isF32()) { 559 results.assign(2, FloatType::getF16(t.getContext())); 560 return success(); 561 } 562 563 // Otherwise, convert the type directly. 564 results.push_back(t); 565 return success(); 566 } 567 568 /// Hook for materializing a conversion. This is necessary because we generate 569 /// 1->N type mappings. 570 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 571 ValueRange inputs, Location loc) { 572 if (inputs.size() == 1) 573 return inputs[0]; 574 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 575 } 576 }; 577 578 struct TestLegalizePatternDriver 579 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 580 /// The mode of conversion to use with the driver. 581 enum class ConversionMode { Analysis, Full, Partial }; 582 583 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 584 585 void runOnOperation() override { 586 TestTypeConverter converter; 587 mlir::RewritePatternSet patterns(&getContext()); 588 populateWithGenerated(patterns); 589 patterns 590 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 591 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 592 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 593 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 594 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 595 TestNonRootReplacement, TestBoundedRecursiveRewrite, 596 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 597 &getContext()); 598 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 599 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 600 mlir::populateCallOpTypeConversionPattern(patterns, converter); 601 602 // Define the conversion target used for the test. 603 ConversionTarget target(getContext()); 604 target.addLegalOp<ModuleOp>(); 605 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 606 TerminatorOp>(); 607 target 608 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 609 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 610 // Don't allow F32 operands. 611 return llvm::none_of(op.getOperandTypes(), 612 [](Type type) { return type.isF32(); }); 613 }); 614 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 615 return converter.isSignatureLegal(op.getType()) && 616 converter.isLegal(&op.getBody()); 617 }); 618 619 // Expect the type_producer/type_consumer operations to only operate on f64. 620 target.addDynamicallyLegalOp<TestTypeProducerOp>( 621 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 622 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 623 return op.getOperand().getType().isF64(); 624 }); 625 626 // Check support for marking certain operations as recursively legal. 627 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 628 return static_cast<bool>( 629 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 630 }); 631 632 // Mark the bound recursion operation as dynamically legal. 633 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 634 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 635 636 // Handle a partial conversion. 637 if (mode == ConversionMode::Partial) { 638 DenseSet<Operation *> unlegalizedOps; 639 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 640 &unlegalizedOps); 641 // Emit remarks for each legalizable operation. 642 for (auto *op : unlegalizedOps) 643 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 644 return; 645 } 646 647 // Handle a full conversion. 648 if (mode == ConversionMode::Full) { 649 // Check support for marking unknown operations as dynamically legal. 650 target.markUnknownOpDynamicallyLegal([](Operation *op) { 651 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 652 }); 653 654 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 655 return; 656 } 657 658 // Otherwise, handle an analysis conversion. 659 assert(mode == ConversionMode::Analysis); 660 661 // Analyze the convertible operations. 662 DenseSet<Operation *> legalizedOps; 663 if (failed(applyAnalysisConversion(getOperation(), target, 664 std::move(patterns), legalizedOps))) 665 return signalPassFailure(); 666 667 // Emit remarks for each legalizable operation. 668 for (auto *op : legalizedOps) 669 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 670 } 671 672 /// The mode of conversion to use. 673 ConversionMode mode; 674 }; 675 } // end anonymous namespace 676 677 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 678 legalizerConversionMode( 679 "test-legalize-mode", 680 llvm::cl::desc("The legalization mode to use with the test driver"), 681 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 682 llvm::cl::values( 683 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 684 "analysis", "Perform an analysis conversion"), 685 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 686 "Perform a full conversion"), 687 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 688 "partial", "Perform a partial conversion"))); 689 690 //===----------------------------------------------------------------------===// 691 // ConversionPatternRewriter::getRemappedValue testing. This method is used 692 // to get the remapped value of an original value that was replaced using 693 // ConversionPatternRewriter. 694 namespace { 695 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 696 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 697 /// operand twice. 698 /// 699 /// Example: 700 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 701 /// is replaced with: 702 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 703 struct OneVResOneVOperandOp1Converter 704 : public OpConversionPattern<OneVResOneVOperandOp1> { 705 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 706 707 LogicalResult 708 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 709 ConversionPatternRewriter &rewriter) const override { 710 auto origOps = op.getOperands(); 711 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 712 "One operand expected"); 713 Value origOp = *origOps.begin(); 714 SmallVector<Value, 2> remappedOperands; 715 // Replicate the remapped original operand twice. Note that we don't used 716 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 717 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 718 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 719 720 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 721 remappedOperands); 722 return success(); 723 } 724 }; 725 726 struct TestRemappedValue 727 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 728 void runOnFunction() override { 729 mlir::RewritePatternSet patterns(&getContext()); 730 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 731 732 mlir::ConversionTarget target(getContext()); 733 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 734 // We make OneVResOneVOperandOp1 legal only when it has more that one 735 // operand. This will trigger the conversion that will replace one-operand 736 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 737 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 738 [](Operation *op) -> bool { 739 return std::distance(op->operand_begin(), op->operand_end()) > 1; 740 }); 741 742 if (failed(mlir::applyFullConversion(getFunction(), target, 743 std::move(patterns)))) { 744 signalPassFailure(); 745 } 746 } 747 }; 748 } // end anonymous namespace 749 750 //===----------------------------------------------------------------------===// 751 // Test patterns without a specific root operation kind 752 //===----------------------------------------------------------------------===// 753 754 namespace { 755 /// This pattern matches and removes any operation in the test dialect. 756 struct RemoveTestDialectOps : public RewritePattern { 757 RemoveTestDialectOps(MLIRContext *context) 758 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 759 760 LogicalResult matchAndRewrite(Operation *op, 761 PatternRewriter &rewriter) const override { 762 if (!isa<TestDialect>(op->getDialect())) 763 return failure(); 764 rewriter.eraseOp(op); 765 return success(); 766 } 767 }; 768 769 struct TestUnknownRootOpDriver 770 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 771 void runOnFunction() override { 772 mlir::RewritePatternSet patterns(&getContext()); 773 patterns.add<RemoveTestDialectOps>(&getContext()); 774 775 mlir::ConversionTarget target(getContext()); 776 target.addIllegalDialect<TestDialect>(); 777 if (failed( 778 applyPartialConversion(getFunction(), target, std::move(patterns)))) 779 signalPassFailure(); 780 } 781 }; 782 } // end anonymous namespace 783 784 //===----------------------------------------------------------------------===// 785 // Test type conversions 786 //===----------------------------------------------------------------------===// 787 788 namespace { 789 struct TestTypeConversionProducer 790 : public OpConversionPattern<TestTypeProducerOp> { 791 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 792 LogicalResult 793 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 794 ConversionPatternRewriter &rewriter) const final { 795 Type resultType = op.getType(); 796 if (resultType.isa<FloatType>()) 797 resultType = rewriter.getF64Type(); 798 else if (resultType.isInteger(16)) 799 resultType = rewriter.getIntegerType(64); 800 else 801 return failure(); 802 803 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 804 return success(); 805 } 806 }; 807 808 /// Call signature conversion and then fail the rewrite to trigger the undo 809 /// mechanism. 810 struct TestSignatureConversionUndo 811 : public OpConversionPattern<TestSignatureConversionUndoOp> { 812 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 813 814 LogicalResult 815 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 816 ConversionPatternRewriter &rewriter) const final { 817 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 818 return failure(); 819 } 820 }; 821 822 /// Just forward the operands to the root op. This is essentially a no-op 823 /// pattern that is used to trigger target materialization. 824 struct TestTypeConsumerForward 825 : public OpConversionPattern<TestTypeConsumerOp> { 826 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 827 828 LogicalResult 829 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 830 ConversionPatternRewriter &rewriter) const final { 831 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 832 return success(); 833 } 834 }; 835 836 struct TestTypeConversionAnotherProducer 837 : public OpRewritePattern<TestAnotherTypeProducerOp> { 838 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 839 840 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 841 PatternRewriter &rewriter) const final { 842 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 843 return success(); 844 } 845 }; 846 847 struct TestTypeConversionDriver 848 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 849 void getDependentDialects(DialectRegistry ®istry) const override { 850 registry.insert<TestDialect>(); 851 } 852 853 void runOnOperation() override { 854 // Initialize the type converter. 855 TypeConverter converter; 856 857 /// Add the legal set of type conversions. 858 converter.addConversion([](Type type) -> Type { 859 // Treat F64 as legal. 860 if (type.isF64()) 861 return type; 862 // Allow converting BF16/F16/F32 to F64. 863 if (type.isBF16() || type.isF16() || type.isF32()) 864 return FloatType::getF64(type.getContext()); 865 // Otherwise, the type is illegal. 866 return nullptr; 867 }); 868 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 869 // Drop all integer types. 870 return success(); 871 }); 872 873 /// Add the legal set of type materializations. 874 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 875 ValueRange inputs, 876 Location loc) -> Value { 877 // Allow casting from F64 back to F32. 878 if (!resultType.isF16() && inputs.size() == 1 && 879 inputs[0].getType().isF64()) 880 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 881 // Allow producing an i32 or i64 from nothing. 882 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 883 inputs.empty()) 884 return builder.create<TestTypeProducerOp>(loc, resultType); 885 // Allow producing an i64 from an integer. 886 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 887 inputs[0].getType().isa<IntegerType>()) 888 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 889 // Otherwise, fail. 890 return nullptr; 891 }); 892 893 // Initialize the conversion target. 894 mlir::ConversionTarget target(getContext()); 895 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 896 return op.getType().isF64() || op.getType().isInteger(64); 897 }); 898 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 899 return converter.isSignatureLegal(op.getType()) && 900 converter.isLegal(&op.getBody()); 901 }); 902 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 903 // Allow casts from F64 to F32. 904 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 905 }); 906 907 // Initialize the set of rewrite patterns. 908 RewritePatternSet patterns(&getContext()); 909 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 910 TestSignatureConversionUndo>(converter, &getContext()); 911 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 912 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 913 914 if (failed(applyPartialConversion(getOperation(), target, 915 std::move(patterns)))) 916 signalPassFailure(); 917 } 918 }; 919 } // end anonymous namespace 920 921 //===----------------------------------------------------------------------===// 922 // Test Block Merging 923 //===----------------------------------------------------------------------===// 924 925 namespace { 926 /// A rewriter pattern that tests that blocks can be merged. 927 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 928 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 929 930 LogicalResult 931 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 932 ConversionPatternRewriter &rewriter) const final { 933 Block &firstBlock = op.body().front(); 934 Operation *branchOp = firstBlock.getTerminator(); 935 Block *secondBlock = &*(std::next(op.body().begin())); 936 auto succOperands = branchOp->getOperands(); 937 SmallVector<Value, 2> replacements(succOperands); 938 rewriter.eraseOp(branchOp); 939 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 940 rewriter.updateRootInPlace(op, [] {}); 941 return success(); 942 } 943 }; 944 945 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 946 struct TestUndoBlocksMerge : public ConversionPattern { 947 TestUndoBlocksMerge(MLIRContext *ctx) 948 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 949 LogicalResult 950 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 951 ConversionPatternRewriter &rewriter) const final { 952 Block &firstBlock = op->getRegion(0).front(); 953 Operation *branchOp = firstBlock.getTerminator(); 954 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 955 rewriter.setInsertionPointToStart(secondBlock); 956 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 957 auto succOperands = branchOp->getOperands(); 958 SmallVector<Value, 2> replacements(succOperands); 959 rewriter.eraseOp(branchOp); 960 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 961 rewriter.updateRootInPlace(op, [] {}); 962 return success(); 963 } 964 }; 965 966 /// A rewrite mechanism to inline the body of the op into its parent, when both 967 /// ops can have a single block. 968 struct TestMergeSingleBlockOps 969 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 970 using OpConversionPattern< 971 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 972 973 LogicalResult 974 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 975 ConversionPatternRewriter &rewriter) const final { 976 SingleBlockImplicitTerminatorOp parentOp = 977 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 978 if (!parentOp) 979 return failure(); 980 Block &innerBlock = op.region().front(); 981 TerminatorOp innerTerminator = 982 cast<TerminatorOp>(innerBlock.getTerminator()); 983 rewriter.mergeBlockBefore(&innerBlock, op); 984 rewriter.eraseOp(innerTerminator); 985 rewriter.eraseOp(op); 986 rewriter.updateRootInPlace(op, [] {}); 987 return success(); 988 } 989 }; 990 991 struct TestMergeBlocksPatternDriver 992 : public PassWrapper<TestMergeBlocksPatternDriver, 993 OperationPass<ModuleOp>> { 994 void runOnOperation() override { 995 MLIRContext *context = &getContext(); 996 mlir::RewritePatternSet patterns(context); 997 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 998 context); 999 ConversionTarget target(*context); 1000 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1001 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1002 target.addIllegalOp<ILLegalOpF>(); 1003 1004 /// Expect the op to have a single block after legalization. 1005 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1006 [&](TestMergeBlocksOp op) -> bool { 1007 return llvm::hasSingleElement(op.body()); 1008 }); 1009 1010 /// Only allow `test.br` within test.merge_blocks op. 1011 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1012 return op->getParentOfType<TestMergeBlocksOp>(); 1013 }); 1014 1015 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1016 /// inlined. 1017 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1018 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1019 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1020 }); 1021 1022 DenseSet<Operation *> unlegalizedOps; 1023 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1024 &unlegalizedOps); 1025 for (auto *op : unlegalizedOps) 1026 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1027 } 1028 }; 1029 } // namespace 1030 1031 //===----------------------------------------------------------------------===// 1032 // Test Selective Replacement 1033 //===----------------------------------------------------------------------===// 1034 1035 namespace { 1036 /// A rewrite mechanism to inline the body of the op into its parent, when both 1037 /// ops can have a single block. 1038 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1039 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1040 1041 LogicalResult matchAndRewrite(TestCastOp op, 1042 PatternRewriter &rewriter) const final { 1043 if (op.getNumOperands() != 2) 1044 return failure(); 1045 OperandRange operands = op.getOperands(); 1046 1047 // Replace non-terminator uses with the first operand. 1048 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1049 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1050 }); 1051 // Replace everything else with the second operand if the operation isn't 1052 // dead. 1053 rewriter.replaceOp(op, op.getOperand(1)); 1054 return success(); 1055 } 1056 }; 1057 1058 struct TestSelectiveReplacementPatternDriver 1059 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1060 OperationPass<>> { 1061 void runOnOperation() override { 1062 MLIRContext *context = &getContext(); 1063 mlir::RewritePatternSet patterns(context); 1064 patterns.add<TestSelectiveOpReplacementPattern>(context); 1065 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1066 std::move(patterns)); 1067 } 1068 }; 1069 } // namespace 1070 1071 //===----------------------------------------------------------------------===// 1072 // PassRegistration 1073 //===----------------------------------------------------------------------===// 1074 1075 namespace mlir { 1076 namespace test { 1077 void registerPatternsTestPass() { 1078 PassRegistration<TestReturnTypeDriver>("test-return-type", 1079 "Run return type functions"); 1080 1081 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1082 "Run test derived attributes"); 1083 1084 PassRegistration<TestPatternDriver>("test-patterns", 1085 "Run test dialect patterns"); 1086 1087 PassRegistration<TestLegalizePatternDriver>( 1088 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1089 return std::make_unique<TestLegalizePatternDriver>( 1090 legalizerConversionMode); 1091 }); 1092 1093 PassRegistration<TestRemappedValue>( 1094 "test-remapped-value", 1095 "Test public remapped value mechanism in ConversionPatternRewriter"); 1096 1097 PassRegistration<TestUnknownRootOpDriver>( 1098 "test-legalize-unknown-root-patterns", 1099 "Test public remapped value mechanism in ConversionPatternRewriter"); 1100 1101 PassRegistration<TestTypeConversionDriver>( 1102 "test-legalize-type-conversion", 1103 "Test various type conversion functionalities in DialectConversion"); 1104 1105 PassRegistration<TestMergeBlocksPatternDriver>{ 1106 "test-merge-blocks", 1107 "Test Merging operation in ConversionPatternRewriter"}; 1108 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1109 "test-pattern-selective-replacement", 1110 "Test selective replacement in the PatternRewriter"}; 1111 } 1112 } // namespace test 1113 } // namespace mlir 1114