1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 // Test that natives calls are only called once during rewrites. 39 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 40 // This let us check the number of times OpM_Test was called by inspecting 41 // the returned value in the MLIR output. 42 static int64_t opMIncreasingValue = 314159265; 43 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 44 int64_t i = opMIncreasingValue++; 45 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 46 } 47 48 namespace { 49 #include "TestPatterns.inc" 50 } // end anonymous namespace 51 52 //===----------------------------------------------------------------------===// 53 // Canonicalizer Driver. 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 struct FoldingPattern : public RewritePattern { 58 public: 59 FoldingPattern(MLIRContext *context) 60 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 61 /*benefit=*/1, context) {} 62 63 LogicalResult matchAndRewrite(Operation *op, 64 PatternRewriter &rewriter) const override { 65 // Exercise OperationFolder API for a single-result operation that is folded 66 // upon construction. The operation being created through the folder has an 67 // in-place folder, and it should be still present in the output. 68 // Furthermore, the folder should not crash when attempting to recover the 69 // (unchanged) operation result. 70 OperationFolder folder(op->getContext()); 71 Value result = folder.create<TestOpInPlaceFold>( 72 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 73 rewriter.getI32IntegerAttr(0)); 74 assert(result); 75 rewriter.replaceOp(op, result); 76 return success(); 77 } 78 }; 79 80 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 81 void runOnFunction() override { 82 mlir::RewritePatternSet patterns(&getContext()); 83 populateWithGenerated(patterns); 84 85 // Verify named pattern is generated with expected name. 86 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 87 88 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 89 } 90 }; 91 } // end anonymous namespace 92 93 //===----------------------------------------------------------------------===// 94 // ReturnType Driver. 95 //===----------------------------------------------------------------------===// 96 97 namespace { 98 // Generate ops for each instance where the type can be successfully inferred. 99 template <typename OpTy> 100 static void invokeCreateWithInferredReturnType(Operation *op) { 101 auto *context = op->getContext(); 102 auto fop = op->getParentOfType<FuncOp>(); 103 auto location = UnknownLoc::get(context); 104 OpBuilder b(op); 105 b.setInsertionPointAfter(op); 106 107 // Use permutations of 2 args as operands. 108 assert(fop.getNumArguments() >= 2); 109 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 110 for (int j = 0; j < e; ++j) { 111 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 112 SmallVector<Type, 2> inferredReturnTypes; 113 if (succeeded(OpTy::inferReturnTypes( 114 context, llvm::None, values, op->getAttrDictionary(), 115 op->getRegions(), inferredReturnTypes))) { 116 OperationState state(location, OpTy::getOperationName()); 117 // TODO: Expand to regions. 118 OpTy::build(b, state, values, op->getAttrs()); 119 (void)b.createOperation(state); 120 } 121 } 122 } 123 } 124 125 static void reifyReturnShape(Operation *op) { 126 OpBuilder b(op); 127 128 // Use permutations of 2 args as operands. 129 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 130 SmallVector<Value, 2> shapes; 131 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) || 132 !llvm::hasSingleElement(shapes)) 133 return; 134 for (auto it : llvm::enumerate(shapes)) { 135 op->emitRemark() << "value " << it.index() << ": " 136 << it.value().getDefiningOp(); 137 } 138 } 139 140 struct TestReturnTypeDriver 141 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 142 void getDependentDialects(DialectRegistry ®istry) const override { 143 registry.insert<memref::MemRefDialect>(); 144 } 145 146 void runOnFunction() override { 147 if (getFunction().getName() == "testCreateFunctions") { 148 std::vector<Operation *> ops; 149 // Collect ops to avoid triggering on inserted ops. 150 for (auto &op : getFunction().getBody().front()) 151 ops.push_back(&op); 152 // Generate test patterns for each, but skip terminator. 153 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 154 // Test create method of each of the Op classes below. The resultant 155 // output would be in reverse order underneath `op` from which 156 // the attributes and regions are used. 157 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 158 invokeCreateWithInferredReturnType< 159 OpWithShapedTypeInferTypeInterfaceOp>(op); 160 }; 161 return; 162 } 163 if (getFunction().getName() == "testReifyFunctions") { 164 std::vector<Operation *> ops; 165 // Collect ops to avoid triggering on inserted ops. 166 for (auto &op : getFunction().getBody().front()) 167 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 168 ops.push_back(&op); 169 // Generate test patterns for each, but skip terminator. 170 for (auto *op : ops) 171 reifyReturnShape(op); 172 } 173 } 174 }; 175 } // end anonymous namespace 176 177 namespace { 178 struct TestDerivedAttributeDriver 179 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 180 void runOnFunction() override; 181 }; 182 } // end anonymous namespace 183 184 void TestDerivedAttributeDriver::runOnFunction() { 185 getFunction().walk([](DerivedAttributeOpInterface dOp) { 186 auto dAttr = dOp.materializeDerivedAttributes(); 187 if (!dAttr) 188 return; 189 for (auto d : dAttr) 190 dOp.emitRemark() << d.first << " = " << d.second; 191 }); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // Legalization Driver. 196 //===----------------------------------------------------------------------===// 197 198 namespace { 199 //===----------------------------------------------------------------------===// 200 // Region-Block Rewrite Testing 201 202 /// This pattern is a simple pattern that inlines the first region of a given 203 /// operation into the parent region. 204 struct TestRegionRewriteBlockMovement : public ConversionPattern { 205 TestRegionRewriteBlockMovement(MLIRContext *ctx) 206 : ConversionPattern("test.region", 1, ctx) {} 207 208 LogicalResult 209 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 210 ConversionPatternRewriter &rewriter) const final { 211 // Inline this region into the parent region. 212 auto &parentRegion = *op->getParentRegion(); 213 auto &opRegion = op->getRegion(0); 214 if (op->getAttr("legalizer.should_clone")) 215 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 216 else 217 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 218 219 if (op->getAttr("legalizer.erase_old_blocks")) { 220 while (!opRegion.empty()) 221 rewriter.eraseBlock(&opRegion.front()); 222 } 223 224 // Drop this operation. 225 rewriter.eraseOp(op); 226 return success(); 227 } 228 }; 229 /// This pattern is a simple pattern that generates a region containing an 230 /// illegal operation. 231 struct TestRegionRewriteUndo : public RewritePattern { 232 TestRegionRewriteUndo(MLIRContext *ctx) 233 : RewritePattern("test.region_builder", 1, ctx) {} 234 235 LogicalResult matchAndRewrite(Operation *op, 236 PatternRewriter &rewriter) const final { 237 // Create the region operation with an entry block containing arguments. 238 OperationState newRegion(op->getLoc(), "test.region"); 239 newRegion.addRegion(); 240 auto *regionOp = rewriter.createOperation(newRegion); 241 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 242 entryBlock->addArgument(rewriter.getIntegerType(64)); 243 244 // Add an explicitly illegal operation to ensure the conversion fails. 245 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 246 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 247 248 // Drop this operation. 249 rewriter.eraseOp(op); 250 return success(); 251 } 252 }; 253 /// A simple pattern that creates a block at the end of the parent region of the 254 /// matched operation. 255 struct TestCreateBlock : public RewritePattern { 256 TestCreateBlock(MLIRContext *ctx) 257 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 258 259 LogicalResult matchAndRewrite(Operation *op, 260 PatternRewriter &rewriter) const final { 261 Region ®ion = *op->getParentRegion(); 262 Type i32Type = rewriter.getIntegerType(32); 263 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 264 rewriter.create<TerminatorOp>(op->getLoc()); 265 rewriter.replaceOp(op, {}); 266 return success(); 267 } 268 }; 269 270 /// A simple pattern that creates a block containing an invalid operation in 271 /// order to trigger the block creation undo mechanism. 272 struct TestCreateIllegalBlock : public RewritePattern { 273 TestCreateIllegalBlock(MLIRContext *ctx) 274 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 275 276 LogicalResult matchAndRewrite(Operation *op, 277 PatternRewriter &rewriter) const final { 278 Region ®ion = *op->getParentRegion(); 279 Type i32Type = rewriter.getIntegerType(32); 280 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 281 // Create an illegal op to ensure the conversion fails. 282 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 283 rewriter.create<TerminatorOp>(op->getLoc()); 284 rewriter.replaceOp(op, {}); 285 return success(); 286 } 287 }; 288 289 /// A simple pattern that tests the undo mechanism when replacing the uses of a 290 /// block argument. 291 struct TestUndoBlockArgReplace : public ConversionPattern { 292 TestUndoBlockArgReplace(MLIRContext *ctx) 293 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 294 295 LogicalResult 296 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 297 ConversionPatternRewriter &rewriter) const final { 298 auto illegalOp = 299 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 300 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 301 illegalOp); 302 rewriter.updateRootInPlace(op, [] {}); 303 return success(); 304 } 305 }; 306 307 /// A rewrite pattern that tests the undo mechanism when erasing a block. 308 struct TestUndoBlockErase : public ConversionPattern { 309 TestUndoBlockErase(MLIRContext *ctx) 310 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 311 312 LogicalResult 313 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 314 ConversionPatternRewriter &rewriter) const final { 315 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 316 rewriter.setInsertionPointToStart(secondBlock); 317 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 318 rewriter.eraseBlock(secondBlock); 319 rewriter.updateRootInPlace(op, [] {}); 320 return success(); 321 } 322 }; 323 324 //===----------------------------------------------------------------------===// 325 // Type-Conversion Rewrite Testing 326 327 /// This patterns erases a region operation that has had a type conversion. 328 struct TestDropOpSignatureConversion : public ConversionPattern { 329 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 330 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 331 LogicalResult 332 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 333 ConversionPatternRewriter &rewriter) const override { 334 Region ®ion = op->getRegion(0); 335 Block *entry = ®ion.front(); 336 337 // Convert the original entry arguments. 338 TypeConverter &converter = *getTypeConverter(); 339 TypeConverter::SignatureConversion result(entry->getNumArguments()); 340 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 341 result)) || 342 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 343 return failure(); 344 345 // Convert the region signature and just drop the operation. 346 rewriter.eraseOp(op); 347 return success(); 348 } 349 }; 350 /// This pattern simply updates the operands of the given operation. 351 struct TestPassthroughInvalidOp : public ConversionPattern { 352 TestPassthroughInvalidOp(MLIRContext *ctx) 353 : ConversionPattern("test.invalid", 1, ctx) {} 354 LogicalResult 355 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const final { 357 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 358 llvm::None); 359 return success(); 360 } 361 }; 362 /// This pattern handles the case of a split return value. 363 struct TestSplitReturnType : public ConversionPattern { 364 TestSplitReturnType(MLIRContext *ctx) 365 : ConversionPattern("test.return", 1, ctx) {} 366 LogicalResult 367 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 368 ConversionPatternRewriter &rewriter) const final { 369 // Check for a return of F32. 370 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 371 return failure(); 372 373 // Check if the first operation is a cast operation, if it is we use the 374 // results directly. 375 auto *defOp = operands[0].getDefiningOp(); 376 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 377 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 378 return success(); 379 } 380 381 // Otherwise, fail to match. 382 return failure(); 383 } 384 }; 385 386 //===----------------------------------------------------------------------===// 387 // Multi-Level Type-Conversion Rewrite Testing 388 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 389 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 390 : ConversionPattern("test.type_producer", 1, ctx) {} 391 LogicalResult 392 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 393 ConversionPatternRewriter &rewriter) const final { 394 // If the type is I32, change the type to F32. 395 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 396 return failure(); 397 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 398 return success(); 399 } 400 }; 401 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 402 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 403 : ConversionPattern("test.type_producer", 1, ctx) {} 404 LogicalResult 405 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 406 ConversionPatternRewriter &rewriter) const final { 407 // If the type is F32, change the type to F64. 408 if (!Type(*op->result_type_begin()).isF32()) 409 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 410 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 411 return success(); 412 } 413 }; 414 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 415 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 416 : ConversionPattern("test.type_producer", 10, ctx) {} 417 LogicalResult 418 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 419 ConversionPatternRewriter &rewriter) const final { 420 // Always convert to B16, even though it is not a legal type. This tests 421 // that values are unmapped correctly. 422 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 423 return success(); 424 } 425 }; 426 struct TestUpdateConsumerType : public ConversionPattern { 427 TestUpdateConsumerType(MLIRContext *ctx) 428 : ConversionPattern("test.type_consumer", 1, ctx) {} 429 LogicalResult 430 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 431 ConversionPatternRewriter &rewriter) const final { 432 // Verify that the incoming operand has been successfully remapped to F64. 433 if (!operands[0].getType().isF64()) 434 return failure(); 435 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 436 return success(); 437 } 438 }; 439 440 //===----------------------------------------------------------------------===// 441 // Non-Root Replacement Rewrite Testing 442 /// This pattern generates an invalid operation, but replaces it before the 443 /// pattern is finished. This checks that we don't need to legalize the 444 /// temporary op. 445 struct TestNonRootReplacement : public RewritePattern { 446 TestNonRootReplacement(MLIRContext *ctx) 447 : RewritePattern("test.replace_non_root", 1, ctx) {} 448 449 LogicalResult matchAndRewrite(Operation *op, 450 PatternRewriter &rewriter) const final { 451 auto resultType = *op->result_type_begin(); 452 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 453 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 454 455 rewriter.replaceOp(illegalOp, {legalOp}); 456 rewriter.replaceOp(op, {illegalOp}); 457 return success(); 458 } 459 }; 460 461 //===----------------------------------------------------------------------===// 462 // Recursive Rewrite Testing 463 /// This pattern is applied to the same operation multiple times, but has a 464 /// bounded recursion. 465 struct TestBoundedRecursiveRewrite 466 : public OpRewritePattern<TestRecursiveRewriteOp> { 467 TestBoundedRecursiveRewrite(MLIRContext *ctx) 468 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 469 // The conversion target handles bounding the recursion of this pattern. 470 setHasBoundedRewriteRecursion(); 471 } 472 473 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 474 PatternRewriter &rewriter) const final { 475 // Decrement the depth of the op in-place. 476 rewriter.updateRootInPlace(op, [&] { 477 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 478 }); 479 return success(); 480 } 481 }; 482 483 struct TestNestedOpCreationUndoRewrite 484 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 485 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 486 487 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 488 PatternRewriter &rewriter) const final { 489 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 490 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 491 return success(); 492 }; 493 }; 494 } // namespace 495 496 namespace { 497 struct TestTypeConverter : public TypeConverter { 498 using TypeConverter::TypeConverter; 499 TestTypeConverter() { 500 addConversion(convertType); 501 addArgumentMaterialization(materializeCast); 502 addSourceMaterialization(materializeCast); 503 504 /// Materialize the cast for one-to-one conversion from i64 to f64. 505 const auto materializeOneToOneCast = 506 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 507 Location loc) -> Optional<Value> { 508 if (resultType.getWidth() == 42 && inputs.size() == 1) 509 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 510 return llvm::None; 511 }; 512 addArgumentMaterialization(materializeOneToOneCast); 513 } 514 515 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 516 // Drop I16 types. 517 if (t.isSignlessInteger(16)) 518 return success(); 519 520 // Convert I64 to F64. 521 if (t.isSignlessInteger(64)) { 522 results.push_back(FloatType::getF64(t.getContext())); 523 return success(); 524 } 525 526 // Convert I42 to I43. 527 if (t.isInteger(42)) { 528 results.push_back(IntegerType::get(t.getContext(), 43)); 529 return success(); 530 } 531 532 // Split F32 into F16,F16. 533 if (t.isF32()) { 534 results.assign(2, FloatType::getF16(t.getContext())); 535 return success(); 536 } 537 538 // Otherwise, convert the type directly. 539 results.push_back(t); 540 return success(); 541 } 542 543 /// Hook for materializing a conversion. This is necessary because we generate 544 /// 1->N type mappings. 545 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 546 ValueRange inputs, Location loc) { 547 if (inputs.size() == 1) 548 return inputs[0]; 549 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 550 } 551 }; 552 553 struct TestLegalizePatternDriver 554 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 555 /// The mode of conversion to use with the driver. 556 enum class ConversionMode { Analysis, Full, Partial }; 557 558 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 559 560 void runOnOperation() override { 561 TestTypeConverter converter; 562 mlir::RewritePatternSet patterns(&getContext()); 563 populateWithGenerated(patterns); 564 patterns 565 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 566 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 567 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 568 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 569 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 570 TestNonRootReplacement, TestBoundedRecursiveRewrite, 571 TestNestedOpCreationUndoRewrite>(&getContext()); 572 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 573 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 574 mlir::populateCallOpTypeConversionPattern(patterns, converter); 575 576 // Define the conversion target used for the test. 577 ConversionTarget target(getContext()); 578 target.addLegalOp<ModuleOp>(); 579 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 580 TerminatorOp>(); 581 target 582 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 583 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 584 // Don't allow F32 operands. 585 return llvm::none_of(op.getOperandTypes(), 586 [](Type type) { return type.isF32(); }); 587 }); 588 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 589 return converter.isSignatureLegal(op.getType()) && 590 converter.isLegal(&op.getBody()); 591 }); 592 593 // Expect the type_producer/type_consumer operations to only operate on f64. 594 target.addDynamicallyLegalOp<TestTypeProducerOp>( 595 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 596 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 597 return op.getOperand().getType().isF64(); 598 }); 599 600 // Check support for marking certain operations as recursively legal. 601 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 602 return static_cast<bool>( 603 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 604 }); 605 606 // Mark the bound recursion operation as dynamically legal. 607 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 608 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 609 610 // Handle a partial conversion. 611 if (mode == ConversionMode::Partial) { 612 DenseSet<Operation *> unlegalizedOps; 613 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 614 &unlegalizedOps); 615 // Emit remarks for each legalizable operation. 616 for (auto *op : unlegalizedOps) 617 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 618 return; 619 } 620 621 // Handle a full conversion. 622 if (mode == ConversionMode::Full) { 623 // Check support for marking unknown operations as dynamically legal. 624 target.markUnknownOpDynamicallyLegal([](Operation *op) { 625 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 626 }); 627 628 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 629 return; 630 } 631 632 // Otherwise, handle an analysis conversion. 633 assert(mode == ConversionMode::Analysis); 634 635 // Analyze the convertible operations. 636 DenseSet<Operation *> legalizedOps; 637 if (failed(applyAnalysisConversion(getOperation(), target, 638 std::move(patterns), legalizedOps))) 639 return signalPassFailure(); 640 641 // Emit remarks for each legalizable operation. 642 for (auto *op : legalizedOps) 643 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 644 } 645 646 /// The mode of conversion to use. 647 ConversionMode mode; 648 }; 649 } // end anonymous namespace 650 651 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 652 legalizerConversionMode( 653 "test-legalize-mode", 654 llvm::cl::desc("The legalization mode to use with the test driver"), 655 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 656 llvm::cl::values( 657 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 658 "analysis", "Perform an analysis conversion"), 659 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 660 "Perform a full conversion"), 661 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 662 "partial", "Perform a partial conversion"))); 663 664 //===----------------------------------------------------------------------===// 665 // ConversionPatternRewriter::getRemappedValue testing. This method is used 666 // to get the remapped value of an original value that was replaced using 667 // ConversionPatternRewriter. 668 namespace { 669 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 670 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 671 /// operand twice. 672 /// 673 /// Example: 674 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 675 /// is replaced with: 676 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 677 struct OneVResOneVOperandOp1Converter 678 : public OpConversionPattern<OneVResOneVOperandOp1> { 679 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 680 681 LogicalResult 682 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 683 ConversionPatternRewriter &rewriter) const override { 684 auto origOps = op.getOperands(); 685 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 686 "One operand expected"); 687 Value origOp = *origOps.begin(); 688 SmallVector<Value, 2> remappedOperands; 689 // Replicate the remapped original operand twice. Note that we don't used 690 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 691 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 692 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 693 694 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 695 remappedOperands); 696 return success(); 697 } 698 }; 699 700 struct TestRemappedValue 701 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 702 void runOnFunction() override { 703 mlir::RewritePatternSet patterns(&getContext()); 704 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 705 706 mlir::ConversionTarget target(getContext()); 707 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 708 // We make OneVResOneVOperandOp1 legal only when it has more that one 709 // operand. This will trigger the conversion that will replace one-operand 710 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 711 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 712 [](Operation *op) -> bool { 713 return std::distance(op->operand_begin(), op->operand_end()) > 1; 714 }); 715 716 if (failed(mlir::applyFullConversion(getFunction(), target, 717 std::move(patterns)))) { 718 signalPassFailure(); 719 } 720 } 721 }; 722 } // end anonymous namespace 723 724 //===----------------------------------------------------------------------===// 725 // Test patterns without a specific root operation kind 726 //===----------------------------------------------------------------------===// 727 728 namespace { 729 /// This pattern matches and removes any operation in the test dialect. 730 struct RemoveTestDialectOps : public RewritePattern { 731 RemoveTestDialectOps(MLIRContext *context) 732 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 733 734 LogicalResult matchAndRewrite(Operation *op, 735 PatternRewriter &rewriter) const override { 736 if (!isa<TestDialect>(op->getDialect())) 737 return failure(); 738 rewriter.eraseOp(op); 739 return success(); 740 } 741 }; 742 743 struct TestUnknownRootOpDriver 744 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 745 void runOnFunction() override { 746 mlir::RewritePatternSet patterns(&getContext()); 747 patterns.add<RemoveTestDialectOps>(&getContext()); 748 749 mlir::ConversionTarget target(getContext()); 750 target.addIllegalDialect<TestDialect>(); 751 if (failed( 752 applyPartialConversion(getFunction(), target, std::move(patterns)))) 753 signalPassFailure(); 754 } 755 }; 756 } // end anonymous namespace 757 758 //===----------------------------------------------------------------------===// 759 // Test type conversions 760 //===----------------------------------------------------------------------===// 761 762 namespace { 763 struct TestTypeConversionProducer 764 : public OpConversionPattern<TestTypeProducerOp> { 765 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 766 LogicalResult 767 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 768 ConversionPatternRewriter &rewriter) const final { 769 Type resultType = op.getType(); 770 if (resultType.isa<FloatType>()) 771 resultType = rewriter.getF64Type(); 772 else if (resultType.isInteger(16)) 773 resultType = rewriter.getIntegerType(64); 774 else 775 return failure(); 776 777 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 778 return success(); 779 } 780 }; 781 782 /// Call signature conversion and then fail the rewrite to trigger the undo 783 /// mechanism. 784 struct TestSignatureConversionUndo 785 : public OpConversionPattern<TestSignatureConversionUndoOp> { 786 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 787 788 LogicalResult 789 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 790 ConversionPatternRewriter &rewriter) const final { 791 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 792 return failure(); 793 } 794 }; 795 796 /// Just forward the operands to the root op. This is essentially a no-op 797 /// pattern that is used to trigger target materialization. 798 struct TestTypeConsumerForward 799 : public OpConversionPattern<TestTypeConsumerOp> { 800 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 801 802 LogicalResult 803 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 804 ConversionPatternRewriter &rewriter) const final { 805 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 806 return success(); 807 } 808 }; 809 810 struct TestTypeConversionAnotherProducer 811 : public OpRewritePattern<TestAnotherTypeProducerOp> { 812 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 813 814 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 815 PatternRewriter &rewriter) const final { 816 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 817 return success(); 818 } 819 }; 820 821 struct TestTypeConversionDriver 822 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 823 void getDependentDialects(DialectRegistry ®istry) const override { 824 registry.insert<TestDialect>(); 825 } 826 827 void runOnOperation() override { 828 // Initialize the type converter. 829 TypeConverter converter; 830 831 /// Add the legal set of type conversions. 832 converter.addConversion([](Type type) -> Type { 833 // Treat F64 as legal. 834 if (type.isF64()) 835 return type; 836 // Allow converting BF16/F16/F32 to F64. 837 if (type.isBF16() || type.isF16() || type.isF32()) 838 return FloatType::getF64(type.getContext()); 839 // Otherwise, the type is illegal. 840 return nullptr; 841 }); 842 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 843 // Drop all integer types. 844 return success(); 845 }); 846 847 /// Add the legal set of type materializations. 848 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 849 ValueRange inputs, 850 Location loc) -> Value { 851 // Allow casting from F64 back to F32. 852 if (!resultType.isF16() && inputs.size() == 1 && 853 inputs[0].getType().isF64()) 854 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 855 // Allow producing an i32 or i64 from nothing. 856 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 857 inputs.empty()) 858 return builder.create<TestTypeProducerOp>(loc, resultType); 859 // Allow producing an i64 from an integer. 860 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 861 inputs[0].getType().isa<IntegerType>()) 862 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 863 // Otherwise, fail. 864 return nullptr; 865 }); 866 867 // Initialize the conversion target. 868 mlir::ConversionTarget target(getContext()); 869 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 870 return op.getType().isF64() || op.getType().isInteger(64); 871 }); 872 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 873 return converter.isSignatureLegal(op.getType()) && 874 converter.isLegal(&op.getBody()); 875 }); 876 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 877 // Allow casts from F64 to F32. 878 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 879 }); 880 881 // Initialize the set of rewrite patterns. 882 RewritePatternSet patterns(&getContext()); 883 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 884 TestSignatureConversionUndo>(converter, &getContext()); 885 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 886 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 887 888 if (failed(applyPartialConversion(getOperation(), target, 889 std::move(patterns)))) 890 signalPassFailure(); 891 } 892 }; 893 } // end anonymous namespace 894 895 //===----------------------------------------------------------------------===// 896 // Test Block Merging 897 //===----------------------------------------------------------------------===// 898 899 namespace { 900 /// A rewriter pattern that tests that blocks can be merged. 901 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 902 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 903 904 LogicalResult 905 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 906 ConversionPatternRewriter &rewriter) const final { 907 Block &firstBlock = op.body().front(); 908 Operation *branchOp = firstBlock.getTerminator(); 909 Block *secondBlock = &*(std::next(op.body().begin())); 910 auto succOperands = branchOp->getOperands(); 911 SmallVector<Value, 2> replacements(succOperands); 912 rewriter.eraseOp(branchOp); 913 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 914 rewriter.updateRootInPlace(op, [] {}); 915 return success(); 916 } 917 }; 918 919 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 920 struct TestUndoBlocksMerge : public ConversionPattern { 921 TestUndoBlocksMerge(MLIRContext *ctx) 922 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 923 LogicalResult 924 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 925 ConversionPatternRewriter &rewriter) const final { 926 Block &firstBlock = op->getRegion(0).front(); 927 Operation *branchOp = firstBlock.getTerminator(); 928 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 929 rewriter.setInsertionPointToStart(secondBlock); 930 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 931 auto succOperands = branchOp->getOperands(); 932 SmallVector<Value, 2> replacements(succOperands); 933 rewriter.eraseOp(branchOp); 934 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 935 rewriter.updateRootInPlace(op, [] {}); 936 return success(); 937 } 938 }; 939 940 /// A rewrite mechanism to inline the body of the op into its parent, when both 941 /// ops can have a single block. 942 struct TestMergeSingleBlockOps 943 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 944 using OpConversionPattern< 945 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 946 947 LogicalResult 948 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 949 ConversionPatternRewriter &rewriter) const final { 950 SingleBlockImplicitTerminatorOp parentOp = 951 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 952 if (!parentOp) 953 return failure(); 954 Block &innerBlock = op.region().front(); 955 TerminatorOp innerTerminator = 956 cast<TerminatorOp>(innerBlock.getTerminator()); 957 rewriter.mergeBlockBefore(&innerBlock, op); 958 rewriter.eraseOp(innerTerminator); 959 rewriter.eraseOp(op); 960 rewriter.updateRootInPlace(op, [] {}); 961 return success(); 962 } 963 }; 964 965 struct TestMergeBlocksPatternDriver 966 : public PassWrapper<TestMergeBlocksPatternDriver, 967 OperationPass<ModuleOp>> { 968 void runOnOperation() override { 969 MLIRContext *context = &getContext(); 970 mlir::RewritePatternSet patterns(context); 971 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 972 context); 973 ConversionTarget target(*context); 974 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 975 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 976 target.addIllegalOp<ILLegalOpF>(); 977 978 /// Expect the op to have a single block after legalization. 979 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 980 [&](TestMergeBlocksOp op) -> bool { 981 return llvm::hasSingleElement(op.body()); 982 }); 983 984 /// Only allow `test.br` within test.merge_blocks op. 985 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 986 return op->getParentOfType<TestMergeBlocksOp>(); 987 }); 988 989 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 990 /// inlined. 991 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 992 [&](SingleBlockImplicitTerminatorOp op) -> bool { 993 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 994 }); 995 996 DenseSet<Operation *> unlegalizedOps; 997 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 998 &unlegalizedOps); 999 for (auto *op : unlegalizedOps) 1000 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1001 } 1002 }; 1003 } // namespace 1004 1005 //===----------------------------------------------------------------------===// 1006 // Test Selective Replacement 1007 //===----------------------------------------------------------------------===// 1008 1009 namespace { 1010 /// A rewrite mechanism to inline the body of the op into its parent, when both 1011 /// ops can have a single block. 1012 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1013 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1014 1015 LogicalResult matchAndRewrite(TestCastOp op, 1016 PatternRewriter &rewriter) const final { 1017 if (op.getNumOperands() != 2) 1018 return failure(); 1019 OperandRange operands = op.getOperands(); 1020 1021 // Replace non-terminator uses with the first operand. 1022 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1023 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1024 }); 1025 // Replace everything else with the second operand if the operation isn't 1026 // dead. 1027 rewriter.replaceOp(op, op.getOperand(1)); 1028 return success(); 1029 } 1030 }; 1031 1032 struct TestSelectiveReplacementPatternDriver 1033 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1034 OperationPass<>> { 1035 void runOnOperation() override { 1036 MLIRContext *context = &getContext(); 1037 mlir::RewritePatternSet patterns(context); 1038 patterns.add<TestSelectiveOpReplacementPattern>(context); 1039 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1040 std::move(patterns)); 1041 } 1042 }; 1043 } // namespace 1044 1045 //===----------------------------------------------------------------------===// 1046 // PassRegistration 1047 //===----------------------------------------------------------------------===// 1048 1049 namespace mlir { 1050 namespace test { 1051 void registerPatternsTestPass() { 1052 PassRegistration<TestReturnTypeDriver>("test-return-type", 1053 "Run return type functions"); 1054 1055 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1056 "Run test derived attributes"); 1057 1058 PassRegistration<TestPatternDriver>("test-patterns", 1059 "Run test dialect patterns"); 1060 1061 PassRegistration<TestLegalizePatternDriver>( 1062 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1063 return std::make_unique<TestLegalizePatternDriver>( 1064 legalizerConversionMode); 1065 }); 1066 1067 PassRegistration<TestRemappedValue>( 1068 "test-remapped-value", 1069 "Test public remapped value mechanism in ConversionPatternRewriter"); 1070 1071 PassRegistration<TestUnknownRootOpDriver>( 1072 "test-legalize-unknown-root-patterns", 1073 "Test public remapped value mechanism in ConversionPatternRewriter"); 1074 1075 PassRegistration<TestTypeConversionDriver>( 1076 "test-legalize-type-conversion", 1077 "Test various type conversion functionalities in DialectConversion"); 1078 1079 PassRegistration<TestMergeBlocksPatternDriver>{ 1080 "test-merge-blocks", 1081 "Test Merging operation in ConversionPatternRewriter"}; 1082 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1083 "test-pattern-selective-replacement", 1084 "Test selective replacement in the PatternRewriter"}; 1085 } 1086 } // namespace test 1087 } // namespace mlir 1088