1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 // Test that natives calls are only called once during rewrites. 48 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 49 // This let us check the number of times OpM_Test was called by inspecting 50 // the returned value in the MLIR output. 51 static int64_t opMIncreasingValue = 314159265; 52 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 53 int64_t i = opMIncreasingValue++; 54 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 55 } 56 57 namespace { 58 #include "TestPatterns.inc" 59 } // end anonymous namespace 60 61 //===----------------------------------------------------------------------===// 62 // Canonicalizer Driver. 63 //===----------------------------------------------------------------------===// 64 65 namespace { 66 struct FoldingPattern : public RewritePattern { 67 public: 68 FoldingPattern(MLIRContext *context) 69 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 70 /*benefit=*/1, context) {} 71 72 LogicalResult matchAndRewrite(Operation *op, 73 PatternRewriter &rewriter) const override { 74 // Exercise OperationFolder API for a single-result operation that is folded 75 // upon construction. The operation being created through the folder has an 76 // in-place folder, and it should be still present in the output. 77 // Furthermore, the folder should not crash when attempting to recover the 78 // (unchanged) operation result. 79 OperationFolder folder(op->getContext()); 80 Value result = folder.create<TestOpInPlaceFold>( 81 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 82 rewriter.getI32IntegerAttr(0)); 83 assert(result); 84 rewriter.replaceOp(op, result); 85 return success(); 86 } 87 }; 88 89 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 90 void runOnFunction() override { 91 mlir::RewritePatternSet patterns(&getContext()); 92 populateWithGenerated(patterns); 93 94 // Verify named pattern is generated with expected name. 95 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 96 97 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 98 } 99 }; 100 } // end anonymous namespace 101 102 //===----------------------------------------------------------------------===// 103 // ReturnType Driver. 104 //===----------------------------------------------------------------------===// 105 106 namespace { 107 // Generate ops for each instance where the type can be successfully inferred. 108 template <typename OpTy> 109 static void invokeCreateWithInferredReturnType(Operation *op) { 110 auto *context = op->getContext(); 111 auto fop = op->getParentOfType<FuncOp>(); 112 auto location = UnknownLoc::get(context); 113 OpBuilder b(op); 114 b.setInsertionPointAfter(op); 115 116 // Use permutations of 2 args as operands. 117 assert(fop.getNumArguments() >= 2); 118 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 119 for (int j = 0; j < e; ++j) { 120 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 121 SmallVector<Type, 2> inferredReturnTypes; 122 if (succeeded(OpTy::inferReturnTypes( 123 context, llvm::None, values, op->getAttrDictionary(), 124 op->getRegions(), inferredReturnTypes))) { 125 OperationState state(location, OpTy::getOperationName()); 126 // TODO: Expand to regions. 127 OpTy::build(b, state, values, op->getAttrs()); 128 (void)b.createOperation(state); 129 } 130 } 131 } 132 } 133 134 static void reifyReturnShape(Operation *op) { 135 OpBuilder b(op); 136 137 // Use permutations of 2 args as operands. 138 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 139 SmallVector<Value, 2> shapes; 140 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) || 141 !llvm::hasSingleElement(shapes)) 142 return; 143 for (auto it : llvm::enumerate(shapes)) { 144 op->emitRemark() << "value " << it.index() << ": " 145 << it.value().getDefiningOp(); 146 } 147 } 148 149 struct TestReturnTypeDriver 150 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 151 void getDependentDialects(DialectRegistry ®istry) const override { 152 registry.insert<memref::MemRefDialect>(); 153 } 154 155 void runOnFunction() override { 156 if (getFunction().getName() == "testCreateFunctions") { 157 std::vector<Operation *> ops; 158 // Collect ops to avoid triggering on inserted ops. 159 for (auto &op : getFunction().getBody().front()) 160 ops.push_back(&op); 161 // Generate test patterns for each, but skip terminator. 162 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 163 // Test create method of each of the Op classes below. The resultant 164 // output would be in reverse order underneath `op` from which 165 // the attributes and regions are used. 166 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 167 invokeCreateWithInferredReturnType< 168 OpWithShapedTypeInferTypeInterfaceOp>(op); 169 }; 170 return; 171 } 172 if (getFunction().getName() == "testReifyFunctions") { 173 std::vector<Operation *> ops; 174 // Collect ops to avoid triggering on inserted ops. 175 for (auto &op : getFunction().getBody().front()) 176 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 177 ops.push_back(&op); 178 // Generate test patterns for each, but skip terminator. 179 for (auto *op : ops) 180 reifyReturnShape(op); 181 } 182 } 183 }; 184 } // end anonymous namespace 185 186 namespace { 187 struct TestDerivedAttributeDriver 188 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 189 void runOnFunction() override; 190 }; 191 } // end anonymous namespace 192 193 void TestDerivedAttributeDriver::runOnFunction() { 194 getFunction().walk([](DerivedAttributeOpInterface dOp) { 195 auto dAttr = dOp.materializeDerivedAttributes(); 196 if (!dAttr) 197 return; 198 for (auto d : dAttr) 199 dOp.emitRemark() << d.first << " = " << d.second; 200 }); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // Legalization Driver. 205 //===----------------------------------------------------------------------===// 206 207 namespace { 208 //===----------------------------------------------------------------------===// 209 // Region-Block Rewrite Testing 210 211 /// This pattern is a simple pattern that inlines the first region of a given 212 /// operation into the parent region. 213 struct TestRegionRewriteBlockMovement : public ConversionPattern { 214 TestRegionRewriteBlockMovement(MLIRContext *ctx) 215 : ConversionPattern("test.region", 1, ctx) {} 216 217 LogicalResult 218 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 219 ConversionPatternRewriter &rewriter) const final { 220 // Inline this region into the parent region. 221 auto &parentRegion = *op->getParentRegion(); 222 auto &opRegion = op->getRegion(0); 223 if (op->getAttr("legalizer.should_clone")) 224 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 225 else 226 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 227 228 if (op->getAttr("legalizer.erase_old_blocks")) { 229 while (!opRegion.empty()) 230 rewriter.eraseBlock(&opRegion.front()); 231 } 232 233 // Drop this operation. 234 rewriter.eraseOp(op); 235 return success(); 236 } 237 }; 238 /// This pattern is a simple pattern that generates a region containing an 239 /// illegal operation. 240 struct TestRegionRewriteUndo : public RewritePattern { 241 TestRegionRewriteUndo(MLIRContext *ctx) 242 : RewritePattern("test.region_builder", 1, ctx) {} 243 244 LogicalResult matchAndRewrite(Operation *op, 245 PatternRewriter &rewriter) const final { 246 // Create the region operation with an entry block containing arguments. 247 OperationState newRegion(op->getLoc(), "test.region"); 248 newRegion.addRegion(); 249 auto *regionOp = rewriter.createOperation(newRegion); 250 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 251 entryBlock->addArgument(rewriter.getIntegerType(64)); 252 253 // Add an explicitly illegal operation to ensure the conversion fails. 254 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 255 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 256 257 // Drop this operation. 258 rewriter.eraseOp(op); 259 return success(); 260 } 261 }; 262 /// A simple pattern that creates a block at the end of the parent region of the 263 /// matched operation. 264 struct TestCreateBlock : public RewritePattern { 265 TestCreateBlock(MLIRContext *ctx) 266 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 267 268 LogicalResult matchAndRewrite(Operation *op, 269 PatternRewriter &rewriter) const final { 270 Region ®ion = *op->getParentRegion(); 271 Type i32Type = rewriter.getIntegerType(32); 272 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 273 rewriter.create<TerminatorOp>(op->getLoc()); 274 rewriter.replaceOp(op, {}); 275 return success(); 276 } 277 }; 278 279 /// A simple pattern that creates a block containing an invalid operation in 280 /// order to trigger the block creation undo mechanism. 281 struct TestCreateIllegalBlock : public RewritePattern { 282 TestCreateIllegalBlock(MLIRContext *ctx) 283 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 284 285 LogicalResult matchAndRewrite(Operation *op, 286 PatternRewriter &rewriter) const final { 287 Region ®ion = *op->getParentRegion(); 288 Type i32Type = rewriter.getIntegerType(32); 289 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 290 // Create an illegal op to ensure the conversion fails. 291 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 292 rewriter.create<TerminatorOp>(op->getLoc()); 293 rewriter.replaceOp(op, {}); 294 return success(); 295 } 296 }; 297 298 /// A simple pattern that tests the undo mechanism when replacing the uses of a 299 /// block argument. 300 struct TestUndoBlockArgReplace : public ConversionPattern { 301 TestUndoBlockArgReplace(MLIRContext *ctx) 302 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 303 304 LogicalResult 305 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 306 ConversionPatternRewriter &rewriter) const final { 307 auto illegalOp = 308 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 309 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 310 illegalOp); 311 rewriter.updateRootInPlace(op, [] {}); 312 return success(); 313 } 314 }; 315 316 /// A rewrite pattern that tests the undo mechanism when erasing a block. 317 struct TestUndoBlockErase : public ConversionPattern { 318 TestUndoBlockErase(MLIRContext *ctx) 319 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 320 321 LogicalResult 322 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 323 ConversionPatternRewriter &rewriter) const final { 324 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 325 rewriter.setInsertionPointToStart(secondBlock); 326 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 327 rewriter.eraseBlock(secondBlock); 328 rewriter.updateRootInPlace(op, [] {}); 329 return success(); 330 } 331 }; 332 333 //===----------------------------------------------------------------------===// 334 // Type-Conversion Rewrite Testing 335 336 /// This patterns erases a region operation that has had a type conversion. 337 struct TestDropOpSignatureConversion : public ConversionPattern { 338 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 339 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 340 LogicalResult 341 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 342 ConversionPatternRewriter &rewriter) const override { 343 Region ®ion = op->getRegion(0); 344 Block *entry = ®ion.front(); 345 346 // Convert the original entry arguments. 347 TypeConverter &converter = *getTypeConverter(); 348 TypeConverter::SignatureConversion result(entry->getNumArguments()); 349 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 350 result)) || 351 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 352 return failure(); 353 354 // Convert the region signature and just drop the operation. 355 rewriter.eraseOp(op); 356 return success(); 357 } 358 }; 359 /// This pattern simply updates the operands of the given operation. 360 struct TestPassthroughInvalidOp : public ConversionPattern { 361 TestPassthroughInvalidOp(MLIRContext *ctx) 362 : ConversionPattern("test.invalid", 1, ctx) {} 363 LogicalResult 364 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 365 ConversionPatternRewriter &rewriter) const final { 366 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 367 llvm::None); 368 return success(); 369 } 370 }; 371 /// This pattern handles the case of a split return value. 372 struct TestSplitReturnType : public ConversionPattern { 373 TestSplitReturnType(MLIRContext *ctx) 374 : ConversionPattern("test.return", 1, ctx) {} 375 LogicalResult 376 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 377 ConversionPatternRewriter &rewriter) const final { 378 // Check for a return of F32. 379 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 380 return failure(); 381 382 // Check if the first operation is a cast operation, if it is we use the 383 // results directly. 384 auto *defOp = operands[0].getDefiningOp(); 385 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 386 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 387 return success(); 388 } 389 390 // Otherwise, fail to match. 391 return failure(); 392 } 393 }; 394 395 //===----------------------------------------------------------------------===// 396 // Multi-Level Type-Conversion Rewrite Testing 397 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 398 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 399 : ConversionPattern("test.type_producer", 1, ctx) {} 400 LogicalResult 401 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 402 ConversionPatternRewriter &rewriter) const final { 403 // If the type is I32, change the type to F32. 404 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 405 return failure(); 406 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 407 return success(); 408 } 409 }; 410 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 411 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 412 : ConversionPattern("test.type_producer", 1, ctx) {} 413 LogicalResult 414 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 415 ConversionPatternRewriter &rewriter) const final { 416 // If the type is F32, change the type to F64. 417 if (!Type(*op->result_type_begin()).isF32()) 418 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 419 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 420 return success(); 421 } 422 }; 423 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 424 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 425 : ConversionPattern("test.type_producer", 10, ctx) {} 426 LogicalResult 427 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 428 ConversionPatternRewriter &rewriter) const final { 429 // Always convert to B16, even though it is not a legal type. This tests 430 // that values are unmapped correctly. 431 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 432 return success(); 433 } 434 }; 435 struct TestUpdateConsumerType : public ConversionPattern { 436 TestUpdateConsumerType(MLIRContext *ctx) 437 : ConversionPattern("test.type_consumer", 1, ctx) {} 438 LogicalResult 439 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 440 ConversionPatternRewriter &rewriter) const final { 441 // Verify that the incoming operand has been successfully remapped to F64. 442 if (!operands[0].getType().isF64()) 443 return failure(); 444 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 445 return success(); 446 } 447 }; 448 449 //===----------------------------------------------------------------------===// 450 // Non-Root Replacement Rewrite Testing 451 /// This pattern generates an invalid operation, but replaces it before the 452 /// pattern is finished. This checks that we don't need to legalize the 453 /// temporary op. 454 struct TestNonRootReplacement : public RewritePattern { 455 TestNonRootReplacement(MLIRContext *ctx) 456 : RewritePattern("test.replace_non_root", 1, ctx) {} 457 458 LogicalResult matchAndRewrite(Operation *op, 459 PatternRewriter &rewriter) const final { 460 auto resultType = *op->result_type_begin(); 461 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 462 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 463 464 rewriter.replaceOp(illegalOp, {legalOp}); 465 rewriter.replaceOp(op, {illegalOp}); 466 return success(); 467 } 468 }; 469 470 //===----------------------------------------------------------------------===// 471 // Recursive Rewrite Testing 472 /// This pattern is applied to the same operation multiple times, but has a 473 /// bounded recursion. 474 struct TestBoundedRecursiveRewrite 475 : public OpRewritePattern<TestRecursiveRewriteOp> { 476 TestBoundedRecursiveRewrite(MLIRContext *ctx) 477 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 478 // The conversion target handles bounding the recursion of this pattern. 479 setHasBoundedRewriteRecursion(); 480 } 481 482 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 483 PatternRewriter &rewriter) const final { 484 // Decrement the depth of the op in-place. 485 rewriter.updateRootInPlace(op, [&] { 486 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 487 }); 488 return success(); 489 } 490 }; 491 492 struct TestNestedOpCreationUndoRewrite 493 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 494 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 495 496 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 497 PatternRewriter &rewriter) const final { 498 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 499 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 500 return success(); 501 }; 502 }; 503 504 // This pattern matches `test.blackhole` and delete this op and its producer. 505 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 506 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 507 508 LogicalResult matchAndRewrite(BlackHoleOp op, 509 PatternRewriter &rewriter) const final { 510 Operation *producer = op.getOperand().getDefiningOp(); 511 // Always erase the user before the producer, the framework should handle 512 // this correctly. 513 rewriter.eraseOp(op); 514 rewriter.eraseOp(producer); 515 return success(); 516 }; 517 }; 518 } // namespace 519 520 namespace { 521 struct TestTypeConverter : public TypeConverter { 522 using TypeConverter::TypeConverter; 523 TestTypeConverter() { 524 addConversion(convertType); 525 addArgumentMaterialization(materializeCast); 526 addSourceMaterialization(materializeCast); 527 528 /// Materialize the cast for one-to-one conversion from i64 to f64. 529 const auto materializeOneToOneCast = 530 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 531 Location loc) -> Optional<Value> { 532 if (resultType.getWidth() == 42 && inputs.size() == 1) 533 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 534 return llvm::None; 535 }; 536 addArgumentMaterialization(materializeOneToOneCast); 537 } 538 539 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 540 // Drop I16 types. 541 if (t.isSignlessInteger(16)) 542 return success(); 543 544 // Convert I64 to F64. 545 if (t.isSignlessInteger(64)) { 546 results.push_back(FloatType::getF64(t.getContext())); 547 return success(); 548 } 549 550 // Convert I42 to I43. 551 if (t.isInteger(42)) { 552 results.push_back(IntegerType::get(t.getContext(), 43)); 553 return success(); 554 } 555 556 // Split F32 into F16,F16. 557 if (t.isF32()) { 558 results.assign(2, FloatType::getF16(t.getContext())); 559 return success(); 560 } 561 562 // Otherwise, convert the type directly. 563 results.push_back(t); 564 return success(); 565 } 566 567 /// Hook for materializing a conversion. This is necessary because we generate 568 /// 1->N type mappings. 569 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 570 ValueRange inputs, Location loc) { 571 if (inputs.size() == 1) 572 return inputs[0]; 573 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 574 } 575 }; 576 577 struct TestLegalizePatternDriver 578 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 579 /// The mode of conversion to use with the driver. 580 enum class ConversionMode { Analysis, Full, Partial }; 581 582 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 583 584 void runOnOperation() override { 585 TestTypeConverter converter; 586 mlir::RewritePatternSet patterns(&getContext()); 587 populateWithGenerated(patterns); 588 patterns 589 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 590 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 591 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 592 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 593 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 594 TestNonRootReplacement, TestBoundedRecursiveRewrite, 595 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 596 &getContext()); 597 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 598 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 599 mlir::populateCallOpTypeConversionPattern(patterns, converter); 600 601 // Define the conversion target used for the test. 602 ConversionTarget target(getContext()); 603 target.addLegalOp<ModuleOp>(); 604 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 605 TerminatorOp>(); 606 target 607 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 608 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 609 // Don't allow F32 operands. 610 return llvm::none_of(op.getOperandTypes(), 611 [](Type type) { return type.isF32(); }); 612 }); 613 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 614 return converter.isSignatureLegal(op.getType()) && 615 converter.isLegal(&op.getBody()); 616 }); 617 618 // Expect the type_producer/type_consumer operations to only operate on f64. 619 target.addDynamicallyLegalOp<TestTypeProducerOp>( 620 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 621 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 622 return op.getOperand().getType().isF64(); 623 }); 624 625 // Check support for marking certain operations as recursively legal. 626 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 627 return static_cast<bool>( 628 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 629 }); 630 631 // Mark the bound recursion operation as dynamically legal. 632 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 633 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 634 635 // Handle a partial conversion. 636 if (mode == ConversionMode::Partial) { 637 DenseSet<Operation *> unlegalizedOps; 638 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 639 &unlegalizedOps); 640 // Emit remarks for each legalizable operation. 641 for (auto *op : unlegalizedOps) 642 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 643 return; 644 } 645 646 // Handle a full conversion. 647 if (mode == ConversionMode::Full) { 648 // Check support for marking unknown operations as dynamically legal. 649 target.markUnknownOpDynamicallyLegal([](Operation *op) { 650 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 651 }); 652 653 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 654 return; 655 } 656 657 // Otherwise, handle an analysis conversion. 658 assert(mode == ConversionMode::Analysis); 659 660 // Analyze the convertible operations. 661 DenseSet<Operation *> legalizedOps; 662 if (failed(applyAnalysisConversion(getOperation(), target, 663 std::move(patterns), legalizedOps))) 664 return signalPassFailure(); 665 666 // Emit remarks for each legalizable operation. 667 for (auto *op : legalizedOps) 668 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 669 } 670 671 /// The mode of conversion to use. 672 ConversionMode mode; 673 }; 674 } // end anonymous namespace 675 676 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 677 legalizerConversionMode( 678 "test-legalize-mode", 679 llvm::cl::desc("The legalization mode to use with the test driver"), 680 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 681 llvm::cl::values( 682 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 683 "analysis", "Perform an analysis conversion"), 684 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 685 "Perform a full conversion"), 686 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 687 "partial", "Perform a partial conversion"))); 688 689 //===----------------------------------------------------------------------===// 690 // ConversionPatternRewriter::getRemappedValue testing. This method is used 691 // to get the remapped value of an original value that was replaced using 692 // ConversionPatternRewriter. 693 namespace { 694 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 695 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 696 /// operand twice. 697 /// 698 /// Example: 699 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 700 /// is replaced with: 701 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 702 struct OneVResOneVOperandOp1Converter 703 : public OpConversionPattern<OneVResOneVOperandOp1> { 704 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 705 706 LogicalResult 707 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 708 ConversionPatternRewriter &rewriter) const override { 709 auto origOps = op.getOperands(); 710 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 711 "One operand expected"); 712 Value origOp = *origOps.begin(); 713 SmallVector<Value, 2> remappedOperands; 714 // Replicate the remapped original operand twice. Note that we don't used 715 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 716 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 717 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 718 719 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 720 remappedOperands); 721 return success(); 722 } 723 }; 724 725 struct TestRemappedValue 726 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 727 void runOnFunction() override { 728 mlir::RewritePatternSet patterns(&getContext()); 729 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 730 731 mlir::ConversionTarget target(getContext()); 732 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 733 // We make OneVResOneVOperandOp1 legal only when it has more that one 734 // operand. This will trigger the conversion that will replace one-operand 735 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 736 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 737 [](Operation *op) -> bool { 738 return std::distance(op->operand_begin(), op->operand_end()) > 1; 739 }); 740 741 if (failed(mlir::applyFullConversion(getFunction(), target, 742 std::move(patterns)))) { 743 signalPassFailure(); 744 } 745 } 746 }; 747 } // end anonymous namespace 748 749 //===----------------------------------------------------------------------===// 750 // Test patterns without a specific root operation kind 751 //===----------------------------------------------------------------------===// 752 753 namespace { 754 /// This pattern matches and removes any operation in the test dialect. 755 struct RemoveTestDialectOps : public RewritePattern { 756 RemoveTestDialectOps(MLIRContext *context) 757 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 758 759 LogicalResult matchAndRewrite(Operation *op, 760 PatternRewriter &rewriter) const override { 761 if (!isa<TestDialect>(op->getDialect())) 762 return failure(); 763 rewriter.eraseOp(op); 764 return success(); 765 } 766 }; 767 768 struct TestUnknownRootOpDriver 769 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 770 void runOnFunction() override { 771 mlir::RewritePatternSet patterns(&getContext()); 772 patterns.add<RemoveTestDialectOps>(&getContext()); 773 774 mlir::ConversionTarget target(getContext()); 775 target.addIllegalDialect<TestDialect>(); 776 if (failed( 777 applyPartialConversion(getFunction(), target, std::move(patterns)))) 778 signalPassFailure(); 779 } 780 }; 781 } // end anonymous namespace 782 783 //===----------------------------------------------------------------------===// 784 // Test type conversions 785 //===----------------------------------------------------------------------===// 786 787 namespace { 788 struct TestTypeConversionProducer 789 : public OpConversionPattern<TestTypeProducerOp> { 790 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 791 LogicalResult 792 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 793 ConversionPatternRewriter &rewriter) const final { 794 Type resultType = op.getType(); 795 if (resultType.isa<FloatType>()) 796 resultType = rewriter.getF64Type(); 797 else if (resultType.isInteger(16)) 798 resultType = rewriter.getIntegerType(64); 799 else 800 return failure(); 801 802 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 803 return success(); 804 } 805 }; 806 807 /// Call signature conversion and then fail the rewrite to trigger the undo 808 /// mechanism. 809 struct TestSignatureConversionUndo 810 : public OpConversionPattern<TestSignatureConversionUndoOp> { 811 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 812 813 LogicalResult 814 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 815 ConversionPatternRewriter &rewriter) const final { 816 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 817 return failure(); 818 } 819 }; 820 821 /// Just forward the operands to the root op. This is essentially a no-op 822 /// pattern that is used to trigger target materialization. 823 struct TestTypeConsumerForward 824 : public OpConversionPattern<TestTypeConsumerOp> { 825 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 826 827 LogicalResult 828 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const final { 830 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 831 return success(); 832 } 833 }; 834 835 struct TestTypeConversionAnotherProducer 836 : public OpRewritePattern<TestAnotherTypeProducerOp> { 837 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 838 839 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 840 PatternRewriter &rewriter) const final { 841 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 842 return success(); 843 } 844 }; 845 846 struct TestTypeConversionDriver 847 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 848 void getDependentDialects(DialectRegistry ®istry) const override { 849 registry.insert<TestDialect>(); 850 } 851 852 void runOnOperation() override { 853 // Initialize the type converter. 854 TypeConverter converter; 855 856 /// Add the legal set of type conversions. 857 converter.addConversion([](Type type) -> Type { 858 // Treat F64 as legal. 859 if (type.isF64()) 860 return type; 861 // Allow converting BF16/F16/F32 to F64. 862 if (type.isBF16() || type.isF16() || type.isF32()) 863 return FloatType::getF64(type.getContext()); 864 // Otherwise, the type is illegal. 865 return nullptr; 866 }); 867 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 868 // Drop all integer types. 869 return success(); 870 }); 871 872 /// Add the legal set of type materializations. 873 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 874 ValueRange inputs, 875 Location loc) -> Value { 876 // Allow casting from F64 back to F32. 877 if (!resultType.isF16() && inputs.size() == 1 && 878 inputs[0].getType().isF64()) 879 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 880 // Allow producing an i32 or i64 from nothing. 881 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 882 inputs.empty()) 883 return builder.create<TestTypeProducerOp>(loc, resultType); 884 // Allow producing an i64 from an integer. 885 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 886 inputs[0].getType().isa<IntegerType>()) 887 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 888 // Otherwise, fail. 889 return nullptr; 890 }); 891 892 // Initialize the conversion target. 893 mlir::ConversionTarget target(getContext()); 894 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 895 return op.getType().isF64() || op.getType().isInteger(64); 896 }); 897 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 898 return converter.isSignatureLegal(op.getType()) && 899 converter.isLegal(&op.getBody()); 900 }); 901 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 902 // Allow casts from F64 to F32. 903 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 904 }); 905 906 // Initialize the set of rewrite patterns. 907 RewritePatternSet patterns(&getContext()); 908 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 909 TestSignatureConversionUndo>(converter, &getContext()); 910 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 911 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 912 913 if (failed(applyPartialConversion(getOperation(), target, 914 std::move(patterns)))) 915 signalPassFailure(); 916 } 917 }; 918 } // end anonymous namespace 919 920 //===----------------------------------------------------------------------===// 921 // Test Block Merging 922 //===----------------------------------------------------------------------===// 923 924 namespace { 925 /// A rewriter pattern that tests that blocks can be merged. 926 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 927 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 928 929 LogicalResult 930 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 931 ConversionPatternRewriter &rewriter) const final { 932 Block &firstBlock = op.body().front(); 933 Operation *branchOp = firstBlock.getTerminator(); 934 Block *secondBlock = &*(std::next(op.body().begin())); 935 auto succOperands = branchOp->getOperands(); 936 SmallVector<Value, 2> replacements(succOperands); 937 rewriter.eraseOp(branchOp); 938 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 939 rewriter.updateRootInPlace(op, [] {}); 940 return success(); 941 } 942 }; 943 944 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 945 struct TestUndoBlocksMerge : public ConversionPattern { 946 TestUndoBlocksMerge(MLIRContext *ctx) 947 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 948 LogicalResult 949 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 950 ConversionPatternRewriter &rewriter) const final { 951 Block &firstBlock = op->getRegion(0).front(); 952 Operation *branchOp = firstBlock.getTerminator(); 953 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 954 rewriter.setInsertionPointToStart(secondBlock); 955 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 956 auto succOperands = branchOp->getOperands(); 957 SmallVector<Value, 2> replacements(succOperands); 958 rewriter.eraseOp(branchOp); 959 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 960 rewriter.updateRootInPlace(op, [] {}); 961 return success(); 962 } 963 }; 964 965 /// A rewrite mechanism to inline the body of the op into its parent, when both 966 /// ops can have a single block. 967 struct TestMergeSingleBlockOps 968 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 969 using OpConversionPattern< 970 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 971 972 LogicalResult 973 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 974 ConversionPatternRewriter &rewriter) const final { 975 SingleBlockImplicitTerminatorOp parentOp = 976 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 977 if (!parentOp) 978 return failure(); 979 Block &innerBlock = op.region().front(); 980 TerminatorOp innerTerminator = 981 cast<TerminatorOp>(innerBlock.getTerminator()); 982 rewriter.mergeBlockBefore(&innerBlock, op); 983 rewriter.eraseOp(innerTerminator); 984 rewriter.eraseOp(op); 985 rewriter.updateRootInPlace(op, [] {}); 986 return success(); 987 } 988 }; 989 990 struct TestMergeBlocksPatternDriver 991 : public PassWrapper<TestMergeBlocksPatternDriver, 992 OperationPass<ModuleOp>> { 993 void runOnOperation() override { 994 MLIRContext *context = &getContext(); 995 mlir::RewritePatternSet patterns(context); 996 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 997 context); 998 ConversionTarget target(*context); 999 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1000 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1001 target.addIllegalOp<ILLegalOpF>(); 1002 1003 /// Expect the op to have a single block after legalization. 1004 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1005 [&](TestMergeBlocksOp op) -> bool { 1006 return llvm::hasSingleElement(op.body()); 1007 }); 1008 1009 /// Only allow `test.br` within test.merge_blocks op. 1010 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1011 return op->getParentOfType<TestMergeBlocksOp>(); 1012 }); 1013 1014 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1015 /// inlined. 1016 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1017 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1018 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1019 }); 1020 1021 DenseSet<Operation *> unlegalizedOps; 1022 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1023 &unlegalizedOps); 1024 for (auto *op : unlegalizedOps) 1025 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1026 } 1027 }; 1028 } // namespace 1029 1030 //===----------------------------------------------------------------------===// 1031 // Test Selective Replacement 1032 //===----------------------------------------------------------------------===// 1033 1034 namespace { 1035 /// A rewrite mechanism to inline the body of the op into its parent, when both 1036 /// ops can have a single block. 1037 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1038 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1039 1040 LogicalResult matchAndRewrite(TestCastOp op, 1041 PatternRewriter &rewriter) const final { 1042 if (op.getNumOperands() != 2) 1043 return failure(); 1044 OperandRange operands = op.getOperands(); 1045 1046 // Replace non-terminator uses with the first operand. 1047 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1048 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1049 }); 1050 // Replace everything else with the second operand if the operation isn't 1051 // dead. 1052 rewriter.replaceOp(op, op.getOperand(1)); 1053 return success(); 1054 } 1055 }; 1056 1057 struct TestSelectiveReplacementPatternDriver 1058 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1059 OperationPass<>> { 1060 void runOnOperation() override { 1061 MLIRContext *context = &getContext(); 1062 mlir::RewritePatternSet patterns(context); 1063 patterns.add<TestSelectiveOpReplacementPattern>(context); 1064 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1065 std::move(patterns)); 1066 } 1067 }; 1068 } // namespace 1069 1070 //===----------------------------------------------------------------------===// 1071 // PassRegistration 1072 //===----------------------------------------------------------------------===// 1073 1074 namespace mlir { 1075 namespace test { 1076 void registerPatternsTestPass() { 1077 PassRegistration<TestReturnTypeDriver>("test-return-type", 1078 "Run return type functions"); 1079 1080 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1081 "Run test derived attributes"); 1082 1083 PassRegistration<TestPatternDriver>("test-patterns", 1084 "Run test dialect patterns"); 1085 1086 PassRegistration<TestLegalizePatternDriver>( 1087 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1088 return std::make_unique<TestLegalizePatternDriver>( 1089 legalizerConversionMode); 1090 }); 1091 1092 PassRegistration<TestRemappedValue>( 1093 "test-remapped-value", 1094 "Test public remapped value mechanism in ConversionPatternRewriter"); 1095 1096 PassRegistration<TestUnknownRootOpDriver>( 1097 "test-legalize-unknown-root-patterns", 1098 "Test public remapped value mechanism in ConversionPatternRewriter"); 1099 1100 PassRegistration<TestTypeConversionDriver>( 1101 "test-legalize-type-conversion", 1102 "Test various type conversion functionalities in DialectConversion"); 1103 1104 PassRegistration<TestMergeBlocksPatternDriver>{ 1105 "test-merge-blocks", 1106 "Test Merging operation in ConversionPatternRewriter"}; 1107 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1108 "test-pattern-selective-replacement", 1109 "Test selective replacement in the PatternRewriter"}; 1110 } 1111 } // namespace test 1112 } // namespace mlir 1113