1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 // Test that natives calls are only called once during rewrites. 39 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 40 // This let us check the number of times OpM_Test was called by inspecting 41 // the returned value in the MLIR output. 42 static int64_t opMIncreasingValue = 314159265; 43 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 44 int64_t i = opMIncreasingValue++; 45 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 46 } 47 48 namespace { 49 #include "TestPatterns.inc" 50 } // end anonymous namespace 51 52 //===----------------------------------------------------------------------===// 53 // Canonicalizer Driver. 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 struct FoldingPattern : public RewritePattern { 58 public: 59 FoldingPattern(MLIRContext *context) 60 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 61 /*benefit=*/1, context) {} 62 63 LogicalResult matchAndRewrite(Operation *op, 64 PatternRewriter &rewriter) const override { 65 // Exercise OperationFolder API for a single-result operation that is folded 66 // upon construction. The operation being created through the folder has an 67 // in-place folder, and it should be still present in the output. 68 // Furthermore, the folder should not crash when attempting to recover the 69 // (unchanged) operation result. 70 OperationFolder folder(op->getContext()); 71 Value result = folder.create<TestOpInPlaceFold>( 72 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 73 rewriter.getI32IntegerAttr(0)); 74 assert(result); 75 rewriter.replaceOp(op, result); 76 return success(); 77 } 78 }; 79 80 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 81 void runOnFunction() override { 82 mlir::RewritePatternSet patterns(&getContext()); 83 populateWithGenerated(patterns); 84 85 // Verify named pattern is generated with expected name. 86 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext()); 87 88 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 89 } 90 }; 91 } // end anonymous namespace 92 93 //===----------------------------------------------------------------------===// 94 // ReturnType Driver. 95 //===----------------------------------------------------------------------===// 96 97 namespace { 98 // Generate ops for each instance where the type can be successfully inferred. 99 template <typename OpTy> 100 static void invokeCreateWithInferredReturnType(Operation *op) { 101 auto *context = op->getContext(); 102 auto fop = op->getParentOfType<FuncOp>(); 103 auto location = UnknownLoc::get(context); 104 OpBuilder b(op); 105 b.setInsertionPointAfter(op); 106 107 // Use permutations of 2 args as operands. 108 assert(fop.getNumArguments() >= 2); 109 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 110 for (int j = 0; j < e; ++j) { 111 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 112 SmallVector<Type, 2> inferredReturnTypes; 113 if (succeeded(OpTy::inferReturnTypes( 114 context, llvm::None, values, op->getAttrDictionary(), 115 op->getRegions(), inferredReturnTypes))) { 116 OperationState state(location, OpTy::getOperationName()); 117 // TODO: Expand to regions. 118 OpTy::build(b, state, values, op->getAttrs()); 119 (void)b.createOperation(state); 120 } 121 } 122 } 123 } 124 125 static void reifyReturnShape(Operation *op) { 126 OpBuilder b(op); 127 128 // Use permutations of 2 args as operands. 129 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 130 SmallVector<Value, 2> shapes; 131 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 132 return; 133 for (auto it : llvm::enumerate(shapes)) 134 op->emitRemark() << "value " << it.index() << ": " 135 << it.value().getDefiningOp(); 136 } 137 138 struct TestReturnTypeDriver 139 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 140 void getDependentDialects(DialectRegistry ®istry) const override { 141 registry.insert<memref::MemRefDialect>(); 142 } 143 144 void runOnFunction() override { 145 if (getFunction().getName() == "testCreateFunctions") { 146 std::vector<Operation *> ops; 147 // Collect ops to avoid triggering on inserted ops. 148 for (auto &op : getFunction().getBody().front()) 149 ops.push_back(&op); 150 // Generate test patterns for each, but skip terminator. 151 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 152 // Test create method of each of the Op classes below. The resultant 153 // output would be in reverse order underneath `op` from which 154 // the attributes and regions are used. 155 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 156 invokeCreateWithInferredReturnType< 157 OpWithShapedTypeInferTypeInterfaceOp>(op); 158 }; 159 return; 160 } 161 if (getFunction().getName() == "testReifyFunctions") { 162 std::vector<Operation *> ops; 163 // Collect ops to avoid triggering on inserted ops. 164 for (auto &op : getFunction().getBody().front()) 165 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 166 ops.push_back(&op); 167 // Generate test patterns for each, but skip terminator. 168 for (auto *op : ops) 169 reifyReturnShape(op); 170 } 171 } 172 }; 173 } // end anonymous namespace 174 175 namespace { 176 struct TestDerivedAttributeDriver 177 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 178 void runOnFunction() override; 179 }; 180 } // end anonymous namespace 181 182 void TestDerivedAttributeDriver::runOnFunction() { 183 getFunction().walk([](DerivedAttributeOpInterface dOp) { 184 auto dAttr = dOp.materializeDerivedAttributes(); 185 if (!dAttr) 186 return; 187 for (auto d : dAttr) 188 dOp.emitRemark() << d.first << " = " << d.second; 189 }); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // Legalization Driver. 194 //===----------------------------------------------------------------------===// 195 196 namespace { 197 //===----------------------------------------------------------------------===// 198 // Region-Block Rewrite Testing 199 200 /// This pattern is a simple pattern that inlines the first region of a given 201 /// operation into the parent region. 202 struct TestRegionRewriteBlockMovement : public ConversionPattern { 203 TestRegionRewriteBlockMovement(MLIRContext *ctx) 204 : ConversionPattern("test.region", 1, ctx) {} 205 206 LogicalResult 207 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 208 ConversionPatternRewriter &rewriter) const final { 209 // Inline this region into the parent region. 210 auto &parentRegion = *op->getParentRegion(); 211 auto &opRegion = op->getRegion(0); 212 if (op->getAttr("legalizer.should_clone")) 213 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 214 else 215 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 216 217 if (op->getAttr("legalizer.erase_old_blocks")) { 218 while (!opRegion.empty()) 219 rewriter.eraseBlock(&opRegion.front()); 220 } 221 222 // Drop this operation. 223 rewriter.eraseOp(op); 224 return success(); 225 } 226 }; 227 /// This pattern is a simple pattern that generates a region containing an 228 /// illegal operation. 229 struct TestRegionRewriteUndo : public RewritePattern { 230 TestRegionRewriteUndo(MLIRContext *ctx) 231 : RewritePattern("test.region_builder", 1, ctx) {} 232 233 LogicalResult matchAndRewrite(Operation *op, 234 PatternRewriter &rewriter) const final { 235 // Create the region operation with an entry block containing arguments. 236 OperationState newRegion(op->getLoc(), "test.region"); 237 newRegion.addRegion(); 238 auto *regionOp = rewriter.createOperation(newRegion); 239 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 240 entryBlock->addArgument(rewriter.getIntegerType(64)); 241 242 // Add an explicitly illegal operation to ensure the conversion fails. 243 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 244 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 245 246 // Drop this operation. 247 rewriter.eraseOp(op); 248 return success(); 249 } 250 }; 251 /// A simple pattern that creates a block at the end of the parent region of the 252 /// matched operation. 253 struct TestCreateBlock : public RewritePattern { 254 TestCreateBlock(MLIRContext *ctx) 255 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 256 257 LogicalResult matchAndRewrite(Operation *op, 258 PatternRewriter &rewriter) const final { 259 Region ®ion = *op->getParentRegion(); 260 Type i32Type = rewriter.getIntegerType(32); 261 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 262 rewriter.create<TerminatorOp>(op->getLoc()); 263 rewriter.replaceOp(op, {}); 264 return success(); 265 } 266 }; 267 268 /// A simple pattern that creates a block containing an invalid operation in 269 /// order to trigger the block creation undo mechanism. 270 struct TestCreateIllegalBlock : public RewritePattern { 271 TestCreateIllegalBlock(MLIRContext *ctx) 272 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 273 274 LogicalResult matchAndRewrite(Operation *op, 275 PatternRewriter &rewriter) const final { 276 Region ®ion = *op->getParentRegion(); 277 Type i32Type = rewriter.getIntegerType(32); 278 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 279 // Create an illegal op to ensure the conversion fails. 280 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 281 rewriter.create<TerminatorOp>(op->getLoc()); 282 rewriter.replaceOp(op, {}); 283 return success(); 284 } 285 }; 286 287 /// A simple pattern that tests the undo mechanism when replacing the uses of a 288 /// block argument. 289 struct TestUndoBlockArgReplace : public ConversionPattern { 290 TestUndoBlockArgReplace(MLIRContext *ctx) 291 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 292 293 LogicalResult 294 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 295 ConversionPatternRewriter &rewriter) const final { 296 auto illegalOp = 297 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 298 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 299 illegalOp); 300 rewriter.updateRootInPlace(op, [] {}); 301 return success(); 302 } 303 }; 304 305 /// A rewrite pattern that tests the undo mechanism when erasing a block. 306 struct TestUndoBlockErase : public ConversionPattern { 307 TestUndoBlockErase(MLIRContext *ctx) 308 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 309 310 LogicalResult 311 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 312 ConversionPatternRewriter &rewriter) const final { 313 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 314 rewriter.setInsertionPointToStart(secondBlock); 315 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 316 rewriter.eraseBlock(secondBlock); 317 rewriter.updateRootInPlace(op, [] {}); 318 return success(); 319 } 320 }; 321 322 //===----------------------------------------------------------------------===// 323 // Type-Conversion Rewrite Testing 324 325 /// This patterns erases a region operation that has had a type conversion. 326 struct TestDropOpSignatureConversion : public ConversionPattern { 327 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 328 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 329 LogicalResult 330 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 331 ConversionPatternRewriter &rewriter) const override { 332 Region ®ion = op->getRegion(0); 333 Block *entry = ®ion.front(); 334 335 // Convert the original entry arguments. 336 TypeConverter &converter = *getTypeConverter(); 337 TypeConverter::SignatureConversion result(entry->getNumArguments()); 338 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 339 result)) || 340 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 341 return failure(); 342 343 // Convert the region signature and just drop the operation. 344 rewriter.eraseOp(op); 345 return success(); 346 } 347 }; 348 /// This pattern simply updates the operands of the given operation. 349 struct TestPassthroughInvalidOp : public ConversionPattern { 350 TestPassthroughInvalidOp(MLIRContext *ctx) 351 : ConversionPattern("test.invalid", 1, ctx) {} 352 LogicalResult 353 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 354 ConversionPatternRewriter &rewriter) const final { 355 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 356 llvm::None); 357 return success(); 358 } 359 }; 360 /// This pattern handles the case of a split return value. 361 struct TestSplitReturnType : public ConversionPattern { 362 TestSplitReturnType(MLIRContext *ctx) 363 : ConversionPattern("test.return", 1, ctx) {} 364 LogicalResult 365 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 366 ConversionPatternRewriter &rewriter) const final { 367 // Check for a return of F32. 368 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 369 return failure(); 370 371 // Check if the first operation is a cast operation, if it is we use the 372 // results directly. 373 auto *defOp = operands[0].getDefiningOp(); 374 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 375 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 376 return success(); 377 } 378 379 // Otherwise, fail to match. 380 return failure(); 381 } 382 }; 383 384 //===----------------------------------------------------------------------===// 385 // Multi-Level Type-Conversion Rewrite Testing 386 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 387 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 388 : ConversionPattern("test.type_producer", 1, ctx) {} 389 LogicalResult 390 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 391 ConversionPatternRewriter &rewriter) const final { 392 // If the type is I32, change the type to F32. 393 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 394 return failure(); 395 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 396 return success(); 397 } 398 }; 399 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 400 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 401 : ConversionPattern("test.type_producer", 1, ctx) {} 402 LogicalResult 403 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 404 ConversionPatternRewriter &rewriter) const final { 405 // If the type is F32, change the type to F64. 406 if (!Type(*op->result_type_begin()).isF32()) 407 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 408 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 409 return success(); 410 } 411 }; 412 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 413 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 414 : ConversionPattern("test.type_producer", 10, ctx) {} 415 LogicalResult 416 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 417 ConversionPatternRewriter &rewriter) const final { 418 // Always convert to B16, even though it is not a legal type. This tests 419 // that values are unmapped correctly. 420 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 421 return success(); 422 } 423 }; 424 struct TestUpdateConsumerType : public ConversionPattern { 425 TestUpdateConsumerType(MLIRContext *ctx) 426 : ConversionPattern("test.type_consumer", 1, ctx) {} 427 LogicalResult 428 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 429 ConversionPatternRewriter &rewriter) const final { 430 // Verify that the incoming operand has been successfully remapped to F64. 431 if (!operands[0].getType().isF64()) 432 return failure(); 433 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 434 return success(); 435 } 436 }; 437 438 //===----------------------------------------------------------------------===// 439 // Non-Root Replacement Rewrite Testing 440 /// This pattern generates an invalid operation, but replaces it before the 441 /// pattern is finished. This checks that we don't need to legalize the 442 /// temporary op. 443 struct TestNonRootReplacement : public RewritePattern { 444 TestNonRootReplacement(MLIRContext *ctx) 445 : RewritePattern("test.replace_non_root", 1, ctx) {} 446 447 LogicalResult matchAndRewrite(Operation *op, 448 PatternRewriter &rewriter) const final { 449 auto resultType = *op->result_type_begin(); 450 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 451 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 452 453 rewriter.replaceOp(illegalOp, {legalOp}); 454 rewriter.replaceOp(op, {illegalOp}); 455 return success(); 456 } 457 }; 458 459 //===----------------------------------------------------------------------===// 460 // Recursive Rewrite Testing 461 /// This pattern is applied to the same operation multiple times, but has a 462 /// bounded recursion. 463 struct TestBoundedRecursiveRewrite 464 : public OpRewritePattern<TestRecursiveRewriteOp> { 465 TestBoundedRecursiveRewrite(MLIRContext *ctx) 466 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 467 // The conversion target handles bounding the recursion of this pattern. 468 setHasBoundedRewriteRecursion(); 469 } 470 471 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 472 PatternRewriter &rewriter) const final { 473 // Decrement the depth of the op in-place. 474 rewriter.updateRootInPlace(op, [&] { 475 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 476 }); 477 return success(); 478 } 479 }; 480 481 struct TestNestedOpCreationUndoRewrite 482 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 483 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 484 485 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 486 PatternRewriter &rewriter) const final { 487 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 488 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 489 return success(); 490 }; 491 }; 492 } // namespace 493 494 namespace { 495 struct TestTypeConverter : public TypeConverter { 496 using TypeConverter::TypeConverter; 497 TestTypeConverter() { 498 addConversion(convertType); 499 addArgumentMaterialization(materializeCast); 500 addSourceMaterialization(materializeCast); 501 502 /// Materialize the cast for one-to-one conversion from i64 to f64. 503 const auto materializeOneToOneCast = 504 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 505 Location loc) -> Optional<Value> { 506 if (resultType.getWidth() == 42 && inputs.size() == 1) 507 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 508 return llvm::None; 509 }; 510 addArgumentMaterialization(materializeOneToOneCast); 511 } 512 513 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 514 // Drop I16 types. 515 if (t.isSignlessInteger(16)) 516 return success(); 517 518 // Convert I64 to F64. 519 if (t.isSignlessInteger(64)) { 520 results.push_back(FloatType::getF64(t.getContext())); 521 return success(); 522 } 523 524 // Convert I42 to I43. 525 if (t.isInteger(42)) { 526 results.push_back(IntegerType::get(t.getContext(), 43)); 527 return success(); 528 } 529 530 // Split F32 into F16,F16. 531 if (t.isF32()) { 532 results.assign(2, FloatType::getF16(t.getContext())); 533 return success(); 534 } 535 536 // Otherwise, convert the type directly. 537 results.push_back(t); 538 return success(); 539 } 540 541 /// Hook for materializing a conversion. This is necessary because we generate 542 /// 1->N type mappings. 543 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 544 ValueRange inputs, Location loc) { 545 if (inputs.size() == 1) 546 return inputs[0]; 547 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 548 } 549 }; 550 551 struct TestLegalizePatternDriver 552 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 553 /// The mode of conversion to use with the driver. 554 enum class ConversionMode { Analysis, Full, Partial }; 555 556 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 557 558 void runOnOperation() override { 559 TestTypeConverter converter; 560 mlir::RewritePatternSet patterns(&getContext()); 561 populateWithGenerated(patterns); 562 patterns 563 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 564 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 565 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 566 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 567 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 568 TestNonRootReplacement, TestBoundedRecursiveRewrite, 569 TestNestedOpCreationUndoRewrite>(&getContext()); 570 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 571 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 572 mlir::populateCallOpTypeConversionPattern(patterns, converter); 573 574 // Define the conversion target used for the test. 575 ConversionTarget target(getContext()); 576 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 577 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 578 TerminatorOp>(); 579 target 580 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 581 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 582 // Don't allow F32 operands. 583 return llvm::none_of(op.getOperandTypes(), 584 [](Type type) { return type.isF32(); }); 585 }); 586 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 587 return converter.isSignatureLegal(op.getType()) && 588 converter.isLegal(&op.getBody()); 589 }); 590 591 // Expect the type_producer/type_consumer operations to only operate on f64. 592 target.addDynamicallyLegalOp<TestTypeProducerOp>( 593 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 594 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 595 return op.getOperand().getType().isF64(); 596 }); 597 598 // Check support for marking certain operations as recursively legal. 599 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 600 return static_cast<bool>( 601 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 602 }); 603 604 // Mark the bound recursion operation as dynamically legal. 605 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 606 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 607 608 // Handle a partial conversion. 609 if (mode == ConversionMode::Partial) { 610 DenseSet<Operation *> unlegalizedOps; 611 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 612 &unlegalizedOps); 613 // Emit remarks for each legalizable operation. 614 for (auto *op : unlegalizedOps) 615 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 616 return; 617 } 618 619 // Handle a full conversion. 620 if (mode == ConversionMode::Full) { 621 // Check support for marking unknown operations as dynamically legal. 622 target.markUnknownOpDynamicallyLegal([](Operation *op) { 623 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 624 }); 625 626 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 627 return; 628 } 629 630 // Otherwise, handle an analysis conversion. 631 assert(mode == ConversionMode::Analysis); 632 633 // Analyze the convertible operations. 634 DenseSet<Operation *> legalizedOps; 635 if (failed(applyAnalysisConversion(getOperation(), target, 636 std::move(patterns), legalizedOps))) 637 return signalPassFailure(); 638 639 // Emit remarks for each legalizable operation. 640 for (auto *op : legalizedOps) 641 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 642 } 643 644 /// The mode of conversion to use. 645 ConversionMode mode; 646 }; 647 } // end anonymous namespace 648 649 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 650 legalizerConversionMode( 651 "test-legalize-mode", 652 llvm::cl::desc("The legalization mode to use with the test driver"), 653 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 654 llvm::cl::values( 655 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 656 "analysis", "Perform an analysis conversion"), 657 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 658 "Perform a full conversion"), 659 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 660 "partial", "Perform a partial conversion"))); 661 662 //===----------------------------------------------------------------------===// 663 // ConversionPatternRewriter::getRemappedValue testing. This method is used 664 // to get the remapped value of an original value that was replaced using 665 // ConversionPatternRewriter. 666 namespace { 667 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 668 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 669 /// operand twice. 670 /// 671 /// Example: 672 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 673 /// is replaced with: 674 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 675 struct OneVResOneVOperandOp1Converter 676 : public OpConversionPattern<OneVResOneVOperandOp1> { 677 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 678 679 LogicalResult 680 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 681 ConversionPatternRewriter &rewriter) const override { 682 auto origOps = op.getOperands(); 683 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 684 "One operand expected"); 685 Value origOp = *origOps.begin(); 686 SmallVector<Value, 2> remappedOperands; 687 // Replicate the remapped original operand twice. Note that we don't used 688 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 689 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 690 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 691 692 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 693 remappedOperands); 694 return success(); 695 } 696 }; 697 698 struct TestRemappedValue 699 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 700 void runOnFunction() override { 701 mlir::RewritePatternSet patterns(&getContext()); 702 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 703 704 mlir::ConversionTarget target(getContext()); 705 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 706 // We make OneVResOneVOperandOp1 legal only when it has more that one 707 // operand. This will trigger the conversion that will replace one-operand 708 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 709 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 710 [](Operation *op) -> bool { 711 return std::distance(op->operand_begin(), op->operand_end()) > 1; 712 }); 713 714 if (failed(mlir::applyFullConversion(getFunction(), target, 715 std::move(patterns)))) { 716 signalPassFailure(); 717 } 718 } 719 }; 720 } // end anonymous namespace 721 722 //===----------------------------------------------------------------------===// 723 // Test patterns without a specific root operation kind 724 //===----------------------------------------------------------------------===// 725 726 namespace { 727 /// This pattern matches and removes any operation in the test dialect. 728 struct RemoveTestDialectOps : public RewritePattern { 729 RemoveTestDialectOps(MLIRContext *context) 730 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 731 732 LogicalResult matchAndRewrite(Operation *op, 733 PatternRewriter &rewriter) const override { 734 if (!isa<TestDialect>(op->getDialect())) 735 return failure(); 736 rewriter.eraseOp(op); 737 return success(); 738 } 739 }; 740 741 struct TestUnknownRootOpDriver 742 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 743 void runOnFunction() override { 744 mlir::RewritePatternSet patterns(&getContext()); 745 patterns.add<RemoveTestDialectOps>(&getContext()); 746 747 mlir::ConversionTarget target(getContext()); 748 target.addIllegalDialect<TestDialect>(); 749 if (failed( 750 applyPartialConversion(getFunction(), target, std::move(patterns)))) 751 signalPassFailure(); 752 } 753 }; 754 } // end anonymous namespace 755 756 //===----------------------------------------------------------------------===// 757 // Test type conversions 758 //===----------------------------------------------------------------------===// 759 760 namespace { 761 struct TestTypeConversionProducer 762 : public OpConversionPattern<TestTypeProducerOp> { 763 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 764 LogicalResult 765 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 766 ConversionPatternRewriter &rewriter) const final { 767 Type resultType = op.getType(); 768 if (resultType.isa<FloatType>()) 769 resultType = rewriter.getF64Type(); 770 else if (resultType.isInteger(16)) 771 resultType = rewriter.getIntegerType(64); 772 else 773 return failure(); 774 775 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 776 return success(); 777 } 778 }; 779 780 /// Call signature conversion and then fail the rewrite to trigger the undo 781 /// mechanism. 782 struct TestSignatureConversionUndo 783 : public OpConversionPattern<TestSignatureConversionUndoOp> { 784 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 785 786 LogicalResult 787 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 788 ConversionPatternRewriter &rewriter) const final { 789 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 790 return failure(); 791 } 792 }; 793 794 /// Just forward the operands to the root op. This is essentially a no-op 795 /// pattern that is used to trigger target materialization. 796 struct TestTypeConsumerForward 797 : public OpConversionPattern<TestTypeConsumerOp> { 798 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 799 800 LogicalResult 801 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 802 ConversionPatternRewriter &rewriter) const final { 803 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 804 return success(); 805 } 806 }; 807 808 struct TestTypeConversionAnotherProducer 809 : public OpRewritePattern<TestAnotherTypeProducerOp> { 810 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 811 812 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 813 PatternRewriter &rewriter) const final { 814 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 815 return success(); 816 } 817 }; 818 819 struct TestTypeConversionDriver 820 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 821 void getDependentDialects(DialectRegistry ®istry) const override { 822 registry.insert<TestDialect>(); 823 } 824 825 void runOnOperation() override { 826 // Initialize the type converter. 827 TypeConverter converter; 828 829 /// Add the legal set of type conversions. 830 converter.addConversion([](Type type) -> Type { 831 // Treat F64 as legal. 832 if (type.isF64()) 833 return type; 834 // Allow converting BF16/F16/F32 to F64. 835 if (type.isBF16() || type.isF16() || type.isF32()) 836 return FloatType::getF64(type.getContext()); 837 // Otherwise, the type is illegal. 838 return nullptr; 839 }); 840 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 841 // Drop all integer types. 842 return success(); 843 }); 844 845 /// Add the legal set of type materializations. 846 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 847 ValueRange inputs, 848 Location loc) -> Value { 849 // Allow casting from F64 back to F32. 850 if (!resultType.isF16() && inputs.size() == 1 && 851 inputs[0].getType().isF64()) 852 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 853 // Allow producing an i32 or i64 from nothing. 854 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 855 inputs.empty()) 856 return builder.create<TestTypeProducerOp>(loc, resultType); 857 // Allow producing an i64 from an integer. 858 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 859 inputs[0].getType().isa<IntegerType>()) 860 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 861 // Otherwise, fail. 862 return nullptr; 863 }); 864 865 // Initialize the conversion target. 866 mlir::ConversionTarget target(getContext()); 867 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 868 return op.getType().isF64() || op.getType().isInteger(64); 869 }); 870 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 871 return converter.isSignatureLegal(op.getType()) && 872 converter.isLegal(&op.getBody()); 873 }); 874 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 875 // Allow casts from F64 to F32. 876 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 877 }); 878 879 // Initialize the set of rewrite patterns. 880 RewritePatternSet patterns(&getContext()); 881 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 882 TestSignatureConversionUndo>(converter, &getContext()); 883 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 884 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 885 886 if (failed(applyPartialConversion(getOperation(), target, 887 std::move(patterns)))) 888 signalPassFailure(); 889 } 890 }; 891 } // end anonymous namespace 892 893 //===----------------------------------------------------------------------===// 894 // Test Block Merging 895 //===----------------------------------------------------------------------===// 896 897 namespace { 898 /// A rewriter pattern that tests that blocks can be merged. 899 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 900 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 901 902 LogicalResult 903 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 904 ConversionPatternRewriter &rewriter) const final { 905 Block &firstBlock = op.body().front(); 906 Operation *branchOp = firstBlock.getTerminator(); 907 Block *secondBlock = &*(std::next(op.body().begin())); 908 auto succOperands = branchOp->getOperands(); 909 SmallVector<Value, 2> replacements(succOperands); 910 rewriter.eraseOp(branchOp); 911 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 912 rewriter.updateRootInPlace(op, [] {}); 913 return success(); 914 } 915 }; 916 917 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 918 struct TestUndoBlocksMerge : public ConversionPattern { 919 TestUndoBlocksMerge(MLIRContext *ctx) 920 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 921 LogicalResult 922 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 923 ConversionPatternRewriter &rewriter) const final { 924 Block &firstBlock = op->getRegion(0).front(); 925 Operation *branchOp = firstBlock.getTerminator(); 926 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 927 rewriter.setInsertionPointToStart(secondBlock); 928 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 929 auto succOperands = branchOp->getOperands(); 930 SmallVector<Value, 2> replacements(succOperands); 931 rewriter.eraseOp(branchOp); 932 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 933 rewriter.updateRootInPlace(op, [] {}); 934 return success(); 935 } 936 }; 937 938 /// A rewrite mechanism to inline the body of the op into its parent, when both 939 /// ops can have a single block. 940 struct TestMergeSingleBlockOps 941 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 942 using OpConversionPattern< 943 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 944 945 LogicalResult 946 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 947 ConversionPatternRewriter &rewriter) const final { 948 SingleBlockImplicitTerminatorOp parentOp = 949 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 950 if (!parentOp) 951 return failure(); 952 Block &innerBlock = op.region().front(); 953 TerminatorOp innerTerminator = 954 cast<TerminatorOp>(innerBlock.getTerminator()); 955 rewriter.mergeBlockBefore(&innerBlock, op); 956 rewriter.eraseOp(innerTerminator); 957 rewriter.eraseOp(op); 958 rewriter.updateRootInPlace(op, [] {}); 959 return success(); 960 } 961 }; 962 963 struct TestMergeBlocksPatternDriver 964 : public PassWrapper<TestMergeBlocksPatternDriver, 965 OperationPass<ModuleOp>> { 966 void runOnOperation() override { 967 MLIRContext *context = &getContext(); 968 mlir::RewritePatternSet patterns(context); 969 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 970 context); 971 ConversionTarget target(*context); 972 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 973 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 974 TestReturnOp>(); 975 target.addIllegalOp<ILLegalOpF>(); 976 977 /// Expect the op to have a single block after legalization. 978 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 979 [&](TestMergeBlocksOp op) -> bool { 980 return llvm::hasSingleElement(op.body()); 981 }); 982 983 /// Only allow `test.br` within test.merge_blocks op. 984 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 985 return op->getParentOfType<TestMergeBlocksOp>(); 986 }); 987 988 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 989 /// inlined. 990 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 991 [&](SingleBlockImplicitTerminatorOp op) -> bool { 992 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 993 }); 994 995 DenseSet<Operation *> unlegalizedOps; 996 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 997 &unlegalizedOps); 998 for (auto *op : unlegalizedOps) 999 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1000 } 1001 }; 1002 } // namespace 1003 1004 //===----------------------------------------------------------------------===// 1005 // Test Selective Replacement 1006 //===----------------------------------------------------------------------===// 1007 1008 namespace { 1009 /// A rewrite mechanism to inline the body of the op into its parent, when both 1010 /// ops can have a single block. 1011 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1012 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1013 1014 LogicalResult matchAndRewrite(TestCastOp op, 1015 PatternRewriter &rewriter) const final { 1016 if (op.getNumOperands() != 2) 1017 return failure(); 1018 OperandRange operands = op.getOperands(); 1019 1020 // Replace non-terminator uses with the first operand. 1021 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1022 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1023 }); 1024 // Replace everything else with the second operand if the operation isn't 1025 // dead. 1026 rewriter.replaceOp(op, op.getOperand(1)); 1027 return success(); 1028 } 1029 }; 1030 1031 struct TestSelectiveReplacementPatternDriver 1032 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1033 OperationPass<>> { 1034 void runOnOperation() override { 1035 MLIRContext *context = &getContext(); 1036 mlir::RewritePatternSet patterns(context); 1037 patterns.add<TestSelectiveOpReplacementPattern>(context); 1038 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1039 std::move(patterns)); 1040 } 1041 }; 1042 } // namespace 1043 1044 //===----------------------------------------------------------------------===// 1045 // PassRegistration 1046 //===----------------------------------------------------------------------===// 1047 1048 namespace mlir { 1049 namespace test { 1050 void registerPatternsTestPass() { 1051 PassRegistration<TestReturnTypeDriver>("test-return-type", 1052 "Run return type functions"); 1053 1054 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1055 "Run test derived attributes"); 1056 1057 PassRegistration<TestPatternDriver>("test-patterns", 1058 "Run test dialect patterns"); 1059 1060 PassRegistration<TestLegalizePatternDriver>( 1061 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1062 return std::make_unique<TestLegalizePatternDriver>( 1063 legalizerConversionMode); 1064 }); 1065 1066 PassRegistration<TestRemappedValue>( 1067 "test-remapped-value", 1068 "Test public remapped value mechanism in ConversionPatternRewriter"); 1069 1070 PassRegistration<TestUnknownRootOpDriver>( 1071 "test-legalize-unknown-root-patterns", 1072 "Test public remapped value mechanism in ConversionPatternRewriter"); 1073 1074 PassRegistration<TestTypeConversionDriver>( 1075 "test-legalize-type-conversion", 1076 "Test various type conversion functionalities in DialectConversion"); 1077 1078 PassRegistration<TestMergeBlocksPatternDriver>{ 1079 "test-merge-blocks", 1080 "Test Merging operation in ConversionPatternRewriter"}; 1081 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1082 "test-pattern-selective-replacement", 1083 "Test selective replacement in the PatternRewriter"}; 1084 } 1085 } // namespace test 1086 } // namespace mlir 1087