1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::test; 20 21 // Native function for testing NativeCodeCall 22 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 23 return choice.getValue() ? input1 : input2; 24 } 25 26 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 27 rewriter.create<OpI>(loc, input); 28 } 29 30 static void handleNoResultOp(PatternRewriter &rewriter, 31 OpSymbolBindingNoResult op) { 32 // Turn the no result op to a one-result op. 33 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 34 op.operand()); 35 } 36 37 // Test that natives calls are only called once during rewrites. 38 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 39 // This let us check the number of times OpM_Test was called by inspecting 40 // the returned value in the MLIR output. 41 static int64_t opMIncreasingValue = 314159265; 42 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 43 int64_t i = opMIncreasingValue++; 44 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 45 } 46 47 namespace { 48 #include "TestPatterns.inc" 49 } // end anonymous namespace 50 51 //===----------------------------------------------------------------------===// 52 // Canonicalizer Driver. 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 struct FoldingPattern : public RewritePattern { 57 public: 58 FoldingPattern(MLIRContext *context) 59 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 60 /*benefit=*/1, context) {} 61 62 LogicalResult matchAndRewrite(Operation *op, 63 PatternRewriter &rewriter) const override { 64 // Exercise OperationFolder API for a single-result operation that is folded 65 // upon construction. The operation being created through the folder has an 66 // in-place folder, and it should be still present in the output. 67 // Furthermore, the folder should not crash when attempting to recover the 68 // (unchanged) operation result. 69 OperationFolder folder(op->getContext()); 70 Value result = folder.create<TestOpInPlaceFold>( 71 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 72 rewriter.getI32IntegerAttr(0)); 73 assert(result); 74 rewriter.replaceOp(op, result); 75 return success(); 76 } 77 }; 78 79 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 80 void runOnFunction() override { 81 mlir::OwningRewritePatternList patterns; 82 populateWithGenerated(&getContext(), patterns); 83 84 // Verify named pattern is generated with expected name. 85 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 86 87 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 88 } 89 }; 90 } // end anonymous namespace 91 92 //===----------------------------------------------------------------------===// 93 // ReturnType Driver. 94 //===----------------------------------------------------------------------===// 95 96 namespace { 97 // Generate ops for each instance where the type can be successfully inferred. 98 template <typename OpTy> 99 static void invokeCreateWithInferredReturnType(Operation *op) { 100 auto *context = op->getContext(); 101 auto fop = op->getParentOfType<FuncOp>(); 102 auto location = UnknownLoc::get(context); 103 OpBuilder b(op); 104 b.setInsertionPointAfter(op); 105 106 // Use permutations of 2 args as operands. 107 assert(fop.getNumArguments() >= 2); 108 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 109 for (int j = 0; j < e; ++j) { 110 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 111 SmallVector<Type, 2> inferredReturnTypes; 112 if (succeeded(OpTy::inferReturnTypes( 113 context, llvm::None, values, op->getAttrDictionary(), 114 op->getRegions(), inferredReturnTypes))) { 115 OperationState state(location, OpTy::getOperationName()); 116 // TODO: Expand to regions. 117 OpTy::build(b, state, values, op->getAttrs()); 118 (void)b.createOperation(state); 119 } 120 } 121 } 122 } 123 124 static void reifyReturnShape(Operation *op) { 125 OpBuilder b(op); 126 127 // Use permutations of 2 args as operands. 128 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 129 SmallVector<Value, 2> shapes; 130 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 131 return; 132 for (auto it : llvm::enumerate(shapes)) 133 op->emitRemark() << "value " << it.index() << ": " 134 << it.value().getDefiningOp(); 135 } 136 137 struct TestReturnTypeDriver 138 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 139 void runOnFunction() override { 140 if (getFunction().getName() == "testCreateFunctions") { 141 std::vector<Operation *> ops; 142 // Collect ops to avoid triggering on inserted ops. 143 for (auto &op : getFunction().getBody().front()) 144 ops.push_back(&op); 145 // Generate test patterns for each, but skip terminator. 146 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 147 // Test create method of each of the Op classes below. The resultant 148 // output would be in reverse order underneath `op` from which 149 // the attributes and regions are used. 150 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 151 invokeCreateWithInferredReturnType< 152 OpWithShapedTypeInferTypeInterfaceOp>(op); 153 }; 154 return; 155 } 156 if (getFunction().getName() == "testReifyFunctions") { 157 std::vector<Operation *> ops; 158 // Collect ops to avoid triggering on inserted ops. 159 for (auto &op : getFunction().getBody().front()) 160 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 161 ops.push_back(&op); 162 // Generate test patterns for each, but skip terminator. 163 for (auto *op : ops) 164 reifyReturnShape(op); 165 } 166 } 167 }; 168 } // end anonymous namespace 169 170 namespace { 171 struct TestDerivedAttributeDriver 172 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 173 void runOnFunction() override; 174 }; 175 } // end anonymous namespace 176 177 void TestDerivedAttributeDriver::runOnFunction() { 178 getFunction().walk([](DerivedAttributeOpInterface dOp) { 179 auto dAttr = dOp.materializeDerivedAttributes(); 180 if (!dAttr) 181 return; 182 for (auto d : dAttr) 183 dOp.emitRemark() << d.first << " = " << d.second; 184 }); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // Legalization Driver. 189 //===----------------------------------------------------------------------===// 190 191 namespace { 192 //===----------------------------------------------------------------------===// 193 // Region-Block Rewrite Testing 194 195 /// This pattern is a simple pattern that inlines the first region of a given 196 /// operation into the parent region. 197 struct TestRegionRewriteBlockMovement : public ConversionPattern { 198 TestRegionRewriteBlockMovement(MLIRContext *ctx) 199 : ConversionPattern("test.region", 1, ctx) {} 200 201 LogicalResult 202 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 203 ConversionPatternRewriter &rewriter) const final { 204 // Inline this region into the parent region. 205 auto &parentRegion = *op->getParentRegion(); 206 auto &opRegion = op->getRegion(0); 207 if (op->getAttr("legalizer.should_clone")) 208 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 209 else 210 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 211 212 if (op->getAttr("legalizer.erase_old_blocks")) { 213 while (!opRegion.empty()) 214 rewriter.eraseBlock(&opRegion.front()); 215 } 216 217 // Drop this operation. 218 rewriter.eraseOp(op); 219 return success(); 220 } 221 }; 222 /// This pattern is a simple pattern that generates a region containing an 223 /// illegal operation. 224 struct TestRegionRewriteUndo : public RewritePattern { 225 TestRegionRewriteUndo(MLIRContext *ctx) 226 : RewritePattern("test.region_builder", 1, ctx) {} 227 228 LogicalResult matchAndRewrite(Operation *op, 229 PatternRewriter &rewriter) const final { 230 // Create the region operation with an entry block containing arguments. 231 OperationState newRegion(op->getLoc(), "test.region"); 232 newRegion.addRegion(); 233 auto *regionOp = rewriter.createOperation(newRegion); 234 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 235 entryBlock->addArgument(rewriter.getIntegerType(64)); 236 237 // Add an explicitly illegal operation to ensure the conversion fails. 238 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 239 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 240 241 // Drop this operation. 242 rewriter.eraseOp(op); 243 return success(); 244 } 245 }; 246 /// A simple pattern that creates a block at the end of the parent region of the 247 /// matched operation. 248 struct TestCreateBlock : public RewritePattern { 249 TestCreateBlock(MLIRContext *ctx) 250 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 251 252 LogicalResult matchAndRewrite(Operation *op, 253 PatternRewriter &rewriter) const final { 254 Region ®ion = *op->getParentRegion(); 255 Type i32Type = rewriter.getIntegerType(32); 256 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 257 rewriter.create<TerminatorOp>(op->getLoc()); 258 rewriter.replaceOp(op, {}); 259 return success(); 260 } 261 }; 262 263 /// A simple pattern that creates a block containing an invalid operation in 264 /// order to trigger the block creation undo mechanism. 265 struct TestCreateIllegalBlock : public RewritePattern { 266 TestCreateIllegalBlock(MLIRContext *ctx) 267 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 268 269 LogicalResult matchAndRewrite(Operation *op, 270 PatternRewriter &rewriter) const final { 271 Region ®ion = *op->getParentRegion(); 272 Type i32Type = rewriter.getIntegerType(32); 273 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 274 // Create an illegal op to ensure the conversion fails. 275 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 276 rewriter.create<TerminatorOp>(op->getLoc()); 277 rewriter.replaceOp(op, {}); 278 return success(); 279 } 280 }; 281 282 /// A simple pattern that tests the undo mechanism when replacing the uses of a 283 /// block argument. 284 struct TestUndoBlockArgReplace : public ConversionPattern { 285 TestUndoBlockArgReplace(MLIRContext *ctx) 286 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 287 288 LogicalResult 289 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 290 ConversionPatternRewriter &rewriter) const final { 291 auto illegalOp = 292 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 293 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 294 illegalOp); 295 rewriter.updateRootInPlace(op, [] {}); 296 return success(); 297 } 298 }; 299 300 /// A rewrite pattern that tests the undo mechanism when erasing a block. 301 struct TestUndoBlockErase : public ConversionPattern { 302 TestUndoBlockErase(MLIRContext *ctx) 303 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 304 305 LogicalResult 306 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 307 ConversionPatternRewriter &rewriter) const final { 308 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 309 rewriter.setInsertionPointToStart(secondBlock); 310 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 311 rewriter.eraseBlock(secondBlock); 312 rewriter.updateRootInPlace(op, [] {}); 313 return success(); 314 } 315 }; 316 317 //===----------------------------------------------------------------------===// 318 // Type-Conversion Rewrite Testing 319 320 /// This patterns erases a region operation that has had a type conversion. 321 struct TestDropOpSignatureConversion : public ConversionPattern { 322 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 323 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 324 LogicalResult 325 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 326 ConversionPatternRewriter &rewriter) const override { 327 Region ®ion = op->getRegion(0); 328 Block *entry = ®ion.front(); 329 330 // Convert the original entry arguments. 331 TypeConverter &converter = *getTypeConverter(); 332 TypeConverter::SignatureConversion result(entry->getNumArguments()); 333 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 334 result)) || 335 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 336 return failure(); 337 338 // Convert the region signature and just drop the operation. 339 rewriter.eraseOp(op); 340 return success(); 341 } 342 }; 343 /// This pattern simply updates the operands of the given operation. 344 struct TestPassthroughInvalidOp : public ConversionPattern { 345 TestPassthroughInvalidOp(MLIRContext *ctx) 346 : ConversionPattern("test.invalid", 1, ctx) {} 347 LogicalResult 348 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 349 ConversionPatternRewriter &rewriter) const final { 350 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 351 llvm::None); 352 return success(); 353 } 354 }; 355 /// This pattern handles the case of a split return value. 356 struct TestSplitReturnType : public ConversionPattern { 357 TestSplitReturnType(MLIRContext *ctx) 358 : ConversionPattern("test.return", 1, ctx) {} 359 LogicalResult 360 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 361 ConversionPatternRewriter &rewriter) const final { 362 // Check for a return of F32. 363 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 364 return failure(); 365 366 // Check if the first operation is a cast operation, if it is we use the 367 // results directly. 368 auto *defOp = operands[0].getDefiningOp(); 369 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 370 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 371 return success(); 372 } 373 374 // Otherwise, fail to match. 375 return failure(); 376 } 377 }; 378 379 //===----------------------------------------------------------------------===// 380 // Multi-Level Type-Conversion Rewrite Testing 381 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 382 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 383 : ConversionPattern("test.type_producer", 1, ctx) {} 384 LogicalResult 385 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 386 ConversionPatternRewriter &rewriter) const final { 387 // If the type is I32, change the type to F32. 388 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 389 return failure(); 390 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 391 return success(); 392 } 393 }; 394 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 395 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 396 : ConversionPattern("test.type_producer", 1, ctx) {} 397 LogicalResult 398 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 399 ConversionPatternRewriter &rewriter) const final { 400 // If the type is F32, change the type to F64. 401 if (!Type(*op->result_type_begin()).isF32()) 402 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 403 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 404 return success(); 405 } 406 }; 407 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 408 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 409 : ConversionPattern("test.type_producer", 10, ctx) {} 410 LogicalResult 411 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 412 ConversionPatternRewriter &rewriter) const final { 413 // Always convert to B16, even though it is not a legal type. This tests 414 // that values are unmapped correctly. 415 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 416 return success(); 417 } 418 }; 419 struct TestUpdateConsumerType : public ConversionPattern { 420 TestUpdateConsumerType(MLIRContext *ctx) 421 : ConversionPattern("test.type_consumer", 1, ctx) {} 422 LogicalResult 423 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 424 ConversionPatternRewriter &rewriter) const final { 425 // Verify that the incoming operand has been successfully remapped to F64. 426 if (!operands[0].getType().isF64()) 427 return failure(); 428 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 429 return success(); 430 } 431 }; 432 433 //===----------------------------------------------------------------------===// 434 // Non-Root Replacement Rewrite Testing 435 /// This pattern generates an invalid operation, but replaces it before the 436 /// pattern is finished. This checks that we don't need to legalize the 437 /// temporary op. 438 struct TestNonRootReplacement : public RewritePattern { 439 TestNonRootReplacement(MLIRContext *ctx) 440 : RewritePattern("test.replace_non_root", 1, ctx) {} 441 442 LogicalResult matchAndRewrite(Operation *op, 443 PatternRewriter &rewriter) const final { 444 auto resultType = *op->result_type_begin(); 445 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 446 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 447 448 rewriter.replaceOp(illegalOp, {legalOp}); 449 rewriter.replaceOp(op, {illegalOp}); 450 return success(); 451 } 452 }; 453 454 //===----------------------------------------------------------------------===// 455 // Recursive Rewrite Testing 456 /// This pattern is applied to the same operation multiple times, but has a 457 /// bounded recursion. 458 struct TestBoundedRecursiveRewrite 459 : public OpRewritePattern<TestRecursiveRewriteOp> { 460 TestBoundedRecursiveRewrite(MLIRContext *ctx) 461 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 462 // The conversion target handles bounding the recursion of this pattern. 463 setHasBoundedRewriteRecursion(); 464 } 465 466 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 467 PatternRewriter &rewriter) const final { 468 // Decrement the depth of the op in-place. 469 rewriter.updateRootInPlace(op, [&] { 470 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 471 }); 472 return success(); 473 } 474 }; 475 476 struct TestNestedOpCreationUndoRewrite 477 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 478 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 479 480 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 481 PatternRewriter &rewriter) const final { 482 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 483 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 484 return success(); 485 }; 486 }; 487 } // namespace 488 489 namespace { 490 struct TestTypeConverter : public TypeConverter { 491 using TypeConverter::TypeConverter; 492 TestTypeConverter() { 493 addConversion(convertType); 494 addArgumentMaterialization(materializeCast); 495 addSourceMaterialization(materializeCast); 496 497 /// Materialize the cast for one-to-one conversion from i64 to f64. 498 const auto materializeOneToOneCast = 499 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 500 Location loc) -> Optional<Value> { 501 if (resultType.getWidth() == 42 && inputs.size() == 1) 502 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 503 return llvm::None; 504 }; 505 addArgumentMaterialization(materializeOneToOneCast); 506 } 507 508 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 509 // Drop I16 types. 510 if (t.isSignlessInteger(16)) 511 return success(); 512 513 // Convert I64 to F64. 514 if (t.isSignlessInteger(64)) { 515 results.push_back(FloatType::getF64(t.getContext())); 516 return success(); 517 } 518 519 // Convert I42 to I43. 520 if (t.isInteger(42)) { 521 results.push_back(IntegerType::get(t.getContext(), 43)); 522 return success(); 523 } 524 525 // Split F32 into F16,F16. 526 if (t.isF32()) { 527 results.assign(2, FloatType::getF16(t.getContext())); 528 return success(); 529 } 530 531 // Otherwise, convert the type directly. 532 results.push_back(t); 533 return success(); 534 } 535 536 /// Hook for materializing a conversion. This is necessary because we generate 537 /// 1->N type mappings. 538 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 539 ValueRange inputs, Location loc) { 540 if (inputs.size() == 1) 541 return inputs[0]; 542 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 543 } 544 }; 545 546 struct TestLegalizePatternDriver 547 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 548 /// The mode of conversion to use with the driver. 549 enum class ConversionMode { Analysis, Full, Partial }; 550 551 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 552 553 void runOnOperation() override { 554 TestTypeConverter converter; 555 mlir::OwningRewritePatternList patterns; 556 populateWithGenerated(&getContext(), patterns); 557 patterns.insert< 558 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 559 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 560 TestPassthroughInvalidOp, TestSplitReturnType, 561 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 562 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 563 TestNonRootReplacement, TestBoundedRecursiveRewrite, 564 TestNestedOpCreationUndoRewrite>(&getContext()); 565 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 566 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 567 converter); 568 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 569 converter); 570 571 // Define the conversion target used for the test. 572 ConversionTarget target(getContext()); 573 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 574 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 575 TerminatorOp>(); 576 target 577 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 578 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 579 // Don't allow F32 operands. 580 return llvm::none_of(op.getOperandTypes(), 581 [](Type type) { return type.isF32(); }); 582 }); 583 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 584 return converter.isSignatureLegal(op.getType()) && 585 converter.isLegal(&op.getBody()); 586 }); 587 588 // Expect the type_producer/type_consumer operations to only operate on f64. 589 target.addDynamicallyLegalOp<TestTypeProducerOp>( 590 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 591 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 592 return op.getOperand().getType().isF64(); 593 }); 594 595 // Check support for marking certain operations as recursively legal. 596 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 597 return static_cast<bool>( 598 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 599 }); 600 601 // Mark the bound recursion operation as dynamically legal. 602 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 603 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 604 605 // Handle a partial conversion. 606 if (mode == ConversionMode::Partial) { 607 DenseSet<Operation *> unlegalizedOps; 608 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 609 &unlegalizedOps); 610 // Emit remarks for each legalizable operation. 611 for (auto *op : unlegalizedOps) 612 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 613 return; 614 } 615 616 // Handle a full conversion. 617 if (mode == ConversionMode::Full) { 618 // Check support for marking unknown operations as dynamically legal. 619 target.markUnknownOpDynamicallyLegal([](Operation *op) { 620 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 621 }); 622 623 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 624 return; 625 } 626 627 // Otherwise, handle an analysis conversion. 628 assert(mode == ConversionMode::Analysis); 629 630 // Analyze the convertible operations. 631 DenseSet<Operation *> legalizedOps; 632 if (failed(applyAnalysisConversion(getOperation(), target, 633 std::move(patterns), legalizedOps))) 634 return signalPassFailure(); 635 636 // Emit remarks for each legalizable operation. 637 for (auto *op : legalizedOps) 638 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 639 } 640 641 /// The mode of conversion to use. 642 ConversionMode mode; 643 }; 644 } // end anonymous namespace 645 646 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 647 legalizerConversionMode( 648 "test-legalize-mode", 649 llvm::cl::desc("The legalization mode to use with the test driver"), 650 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 651 llvm::cl::values( 652 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 653 "analysis", "Perform an analysis conversion"), 654 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 655 "Perform a full conversion"), 656 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 657 "partial", "Perform a partial conversion"))); 658 659 //===----------------------------------------------------------------------===// 660 // ConversionPatternRewriter::getRemappedValue testing. This method is used 661 // to get the remapped value of an original value that was replaced using 662 // ConversionPatternRewriter. 663 namespace { 664 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 665 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 666 /// operand twice. 667 /// 668 /// Example: 669 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 670 /// is replaced with: 671 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 672 struct OneVResOneVOperandOp1Converter 673 : public OpConversionPattern<OneVResOneVOperandOp1> { 674 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 675 676 LogicalResult 677 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 678 ConversionPatternRewriter &rewriter) const override { 679 auto origOps = op.getOperands(); 680 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 681 "One operand expected"); 682 Value origOp = *origOps.begin(); 683 SmallVector<Value, 2> remappedOperands; 684 // Replicate the remapped original operand twice. Note that we don't used 685 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 686 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 687 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 688 689 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 690 remappedOperands); 691 return success(); 692 } 693 }; 694 695 struct TestRemappedValue 696 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 697 void runOnFunction() override { 698 mlir::OwningRewritePatternList patterns; 699 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 700 701 mlir::ConversionTarget target(getContext()); 702 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 703 // We make OneVResOneVOperandOp1 legal only when it has more that one 704 // operand. This will trigger the conversion that will replace one-operand 705 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 706 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 707 [](Operation *op) -> bool { 708 return std::distance(op->operand_begin(), op->operand_end()) > 1; 709 }); 710 711 if (failed(mlir::applyFullConversion(getFunction(), target, 712 std::move(patterns)))) { 713 signalPassFailure(); 714 } 715 } 716 }; 717 } // end anonymous namespace 718 719 //===----------------------------------------------------------------------===// 720 // Test patterns without a specific root operation kind 721 //===----------------------------------------------------------------------===// 722 723 namespace { 724 /// This pattern matches and removes any operation in the test dialect. 725 struct RemoveTestDialectOps : public RewritePattern { 726 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 727 728 LogicalResult matchAndRewrite(Operation *op, 729 PatternRewriter &rewriter) const override { 730 if (!isa<TestDialect>(op->getDialect())) 731 return failure(); 732 rewriter.eraseOp(op); 733 return success(); 734 } 735 }; 736 737 struct TestUnknownRootOpDriver 738 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 739 void runOnFunction() override { 740 mlir::OwningRewritePatternList patterns; 741 patterns.insert<RemoveTestDialectOps>(); 742 743 mlir::ConversionTarget target(getContext()); 744 target.addIllegalDialect<TestDialect>(); 745 if (failed( 746 applyPartialConversion(getFunction(), target, std::move(patterns)))) 747 signalPassFailure(); 748 } 749 }; 750 } // end anonymous namespace 751 752 //===----------------------------------------------------------------------===// 753 // Test type conversions 754 //===----------------------------------------------------------------------===// 755 756 namespace { 757 struct TestTypeConversionProducer 758 : public OpConversionPattern<TestTypeProducerOp> { 759 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 760 LogicalResult 761 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 762 ConversionPatternRewriter &rewriter) const final { 763 Type resultType = op.getType(); 764 if (resultType.isa<FloatType>()) 765 resultType = rewriter.getF64Type(); 766 else if (resultType.isInteger(16)) 767 resultType = rewriter.getIntegerType(64); 768 else 769 return failure(); 770 771 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 772 return success(); 773 } 774 }; 775 776 /// Call signature conversion and then fail the rewrite to trigger the undo 777 /// mechanism. 778 struct TestSignatureConversionUndo 779 : public OpConversionPattern<TestSignatureConversionUndoOp> { 780 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 781 782 LogicalResult 783 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 784 ConversionPatternRewriter &rewriter) const final { 785 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 786 return failure(); 787 } 788 }; 789 790 /// Just forward the operands to the root op. This is essentially a no-op 791 /// pattern that is used to trigger target materialization. 792 struct TestTypeConsumerForward 793 : public OpConversionPattern<TestTypeConsumerOp> { 794 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 795 796 LogicalResult 797 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 798 ConversionPatternRewriter &rewriter) const final { 799 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 800 return success(); 801 } 802 }; 803 804 struct TestTypeConversionAnotherProducer 805 : public OpRewritePattern<TestAnotherTypeProducerOp> { 806 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 807 808 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 809 PatternRewriter &rewriter) const final { 810 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 811 return success(); 812 } 813 }; 814 815 struct TestTypeConversionDriver 816 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 817 void getDependentDialects(DialectRegistry ®istry) const override { 818 registry.insert<TestDialect>(); 819 } 820 821 void runOnOperation() override { 822 // Initialize the type converter. 823 TypeConverter converter; 824 825 /// Add the legal set of type conversions. 826 converter.addConversion([](Type type) -> Type { 827 // Treat F64 as legal. 828 if (type.isF64()) 829 return type; 830 // Allow converting BF16/F16/F32 to F64. 831 if (type.isBF16() || type.isF16() || type.isF32()) 832 return FloatType::getF64(type.getContext()); 833 // Otherwise, the type is illegal. 834 return nullptr; 835 }); 836 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 837 // Drop all integer types. 838 return success(); 839 }); 840 841 /// Add the legal set of type materializations. 842 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 843 ValueRange inputs, 844 Location loc) -> Value { 845 // Allow casting from F64 back to F32. 846 if (!resultType.isF16() && inputs.size() == 1 && 847 inputs[0].getType().isF64()) 848 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 849 // Allow producing an i32 or i64 from nothing. 850 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 851 inputs.empty()) 852 return builder.create<TestTypeProducerOp>(loc, resultType); 853 // Allow producing an i64 from an integer. 854 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 855 inputs[0].getType().isa<IntegerType>()) 856 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 857 // Otherwise, fail. 858 return nullptr; 859 }); 860 861 // Initialize the conversion target. 862 mlir::ConversionTarget target(getContext()); 863 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 864 return op.getType().isF64() || op.getType().isInteger(64); 865 }); 866 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 867 return converter.isSignatureLegal(op.getType()) && 868 converter.isLegal(&op.getBody()); 869 }); 870 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 871 // Allow casts from F64 to F32. 872 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 873 }); 874 875 // Initialize the set of rewrite patterns. 876 OwningRewritePatternList patterns; 877 patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer, 878 TestSignatureConversionUndo>(converter, &getContext()); 879 patterns.insert<TestTypeConversionAnotherProducer>(&getContext()); 880 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 881 converter); 882 883 if (failed(applyPartialConversion(getOperation(), target, 884 std::move(patterns)))) 885 signalPassFailure(); 886 } 887 }; 888 } // end anonymous namespace 889 890 //===----------------------------------------------------------------------===// 891 // Test Block Merging 892 //===----------------------------------------------------------------------===// 893 894 namespace { 895 /// A rewriter pattern that tests that blocks can be merged. 896 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 897 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 898 899 LogicalResult 900 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 901 ConversionPatternRewriter &rewriter) const final { 902 Block &firstBlock = op.body().front(); 903 Operation *branchOp = firstBlock.getTerminator(); 904 Block *secondBlock = &*(std::next(op.body().begin())); 905 auto succOperands = branchOp->getOperands(); 906 SmallVector<Value, 2> replacements(succOperands); 907 rewriter.eraseOp(branchOp); 908 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 909 rewriter.updateRootInPlace(op, [] {}); 910 return success(); 911 } 912 }; 913 914 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 915 struct TestUndoBlocksMerge : public ConversionPattern { 916 TestUndoBlocksMerge(MLIRContext *ctx) 917 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 918 LogicalResult 919 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 920 ConversionPatternRewriter &rewriter) const final { 921 Block &firstBlock = op->getRegion(0).front(); 922 Operation *branchOp = firstBlock.getTerminator(); 923 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 924 rewriter.setInsertionPointToStart(secondBlock); 925 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 926 auto succOperands = branchOp->getOperands(); 927 SmallVector<Value, 2> replacements(succOperands); 928 rewriter.eraseOp(branchOp); 929 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 930 rewriter.updateRootInPlace(op, [] {}); 931 return success(); 932 } 933 }; 934 935 /// A rewrite mechanism to inline the body of the op into its parent, when both 936 /// ops can have a single block. 937 struct TestMergeSingleBlockOps 938 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 939 using OpConversionPattern< 940 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 941 942 LogicalResult 943 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 944 ConversionPatternRewriter &rewriter) const final { 945 SingleBlockImplicitTerminatorOp parentOp = 946 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 947 if (!parentOp) 948 return failure(); 949 Block &innerBlock = op.region().front(); 950 TerminatorOp innerTerminator = 951 cast<TerminatorOp>(innerBlock.getTerminator()); 952 rewriter.mergeBlockBefore(&innerBlock, op); 953 rewriter.eraseOp(innerTerminator); 954 rewriter.eraseOp(op); 955 rewriter.updateRootInPlace(op, [] {}); 956 return success(); 957 } 958 }; 959 960 struct TestMergeBlocksPatternDriver 961 : public PassWrapper<TestMergeBlocksPatternDriver, 962 OperationPass<ModuleOp>> { 963 void runOnOperation() override { 964 mlir::OwningRewritePatternList patterns; 965 MLIRContext *context = &getContext(); 966 patterns 967 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 968 context); 969 ConversionTarget target(*context); 970 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 971 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 972 TestReturnOp>(); 973 target.addIllegalOp<ILLegalOpF>(); 974 975 /// Expect the op to have a single block after legalization. 976 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 977 [&](TestMergeBlocksOp op) -> bool { 978 return llvm::hasSingleElement(op.body()); 979 }); 980 981 /// Only allow `test.br` within test.merge_blocks op. 982 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 983 return op->getParentOfType<TestMergeBlocksOp>(); 984 }); 985 986 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 987 /// inlined. 988 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 989 [&](SingleBlockImplicitTerminatorOp op) -> bool { 990 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 991 }); 992 993 DenseSet<Operation *> unlegalizedOps; 994 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 995 &unlegalizedOps); 996 for (auto *op : unlegalizedOps) 997 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 998 } 999 }; 1000 } // namespace 1001 1002 //===----------------------------------------------------------------------===// 1003 // Test Selective Replacement 1004 //===----------------------------------------------------------------------===// 1005 1006 namespace { 1007 /// A rewrite mechanism to inline the body of the op into its parent, when both 1008 /// ops can have a single block. 1009 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1010 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1011 1012 LogicalResult matchAndRewrite(TestCastOp op, 1013 PatternRewriter &rewriter) const final { 1014 if (op.getNumOperands() != 2) 1015 return failure(); 1016 OperandRange operands = op.getOperands(); 1017 1018 // Replace non-terminator uses with the first operand. 1019 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1020 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1021 }); 1022 // Replace everything else with the second operand if the operation isn't 1023 // dead. 1024 rewriter.replaceOp(op, op.getOperand(1)); 1025 return success(); 1026 } 1027 }; 1028 1029 struct TestSelectiveReplacementPatternDriver 1030 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1031 OperationPass<>> { 1032 void runOnOperation() override { 1033 mlir::OwningRewritePatternList patterns; 1034 MLIRContext *context = &getContext(); 1035 patterns.insert<TestSelectiveOpReplacementPattern>(context); 1036 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1037 std::move(patterns)); 1038 } 1039 }; 1040 } // namespace 1041 1042 //===----------------------------------------------------------------------===// 1043 // PassRegistration 1044 //===----------------------------------------------------------------------===// 1045 1046 namespace mlir { 1047 namespace test { 1048 void registerPatternsTestPass() { 1049 PassRegistration<TestReturnTypeDriver>("test-return-type", 1050 "Run return type functions"); 1051 1052 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 1053 "Run test derived attributes"); 1054 1055 PassRegistration<TestPatternDriver>("test-patterns", 1056 "Run test dialect patterns"); 1057 1058 PassRegistration<TestLegalizePatternDriver>( 1059 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 1060 return std::make_unique<TestLegalizePatternDriver>( 1061 legalizerConversionMode); 1062 }); 1063 1064 PassRegistration<TestRemappedValue>( 1065 "test-remapped-value", 1066 "Test public remapped value mechanism in ConversionPatternRewriter"); 1067 1068 PassRegistration<TestUnknownRootOpDriver>( 1069 "test-legalize-unknown-root-patterns", 1070 "Test public remapped value mechanism in ConversionPatternRewriter"); 1071 1072 PassRegistration<TestTypeConversionDriver>( 1073 "test-legalize-type-conversion", 1074 "Test various type conversion functionalities in DialectConversion"); 1075 1076 PassRegistration<TestMergeBlocksPatternDriver>{ 1077 "test-merge-blocks", 1078 "Test Merging operation in ConversionPatternRewriter"}; 1079 PassRegistration<TestSelectiveReplacementPatternDriver>{ 1080 "test-pattern-selective-replacement", 1081 "Test selective replacement in the PatternRewriter"}; 1082 } 1083 } // namespace test 1084 } // namespace mlir 1085