1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Conversion/StandardToStandard/StandardToStandard.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "mlir/Pass/Pass.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 using namespace mlir; 16 17 // Native function for testing NativeCodeCall 18 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 19 return choice.getValue() ? input1 : input2; 20 } 21 22 static void createOpI(PatternRewriter &rewriter, Value input) { 23 rewriter.create<OpI>(rewriter.getUnknownLoc(), input); 24 } 25 26 static void handleNoResultOp(PatternRewriter &rewriter, 27 OpSymbolBindingNoResult op) { 28 // Turn the no result op to a one-result op. 29 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 30 op.operand()); 31 } 32 33 namespace { 34 #include "TestPatterns.inc" 35 } // end anonymous namespace 36 37 //===----------------------------------------------------------------------===// 38 // Canonicalizer Driver. 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 43 void runOnFunction() override { 44 mlir::OwningRewritePatternList patterns; 45 populateWithGenerated(&getContext(), &patterns); 46 47 // Verify named pattern is generated with expected name. 48 patterns.insert<TestNamedPatternRule>(&getContext()); 49 50 applyPatternsAndFoldGreedily(getFunction(), patterns); 51 } 52 }; 53 } // end anonymous namespace 54 55 //===----------------------------------------------------------------------===// 56 // ReturnType Driver. 57 //===----------------------------------------------------------------------===// 58 59 namespace { 60 // Generate ops for each instance where the type can be successfully inferred. 61 template <typename OpTy> 62 static void invokeCreateWithInferredReturnType(Operation *op) { 63 auto *context = op->getContext(); 64 auto fop = op->getParentOfType<FuncOp>(); 65 auto location = UnknownLoc::get(context); 66 OpBuilder b(op); 67 b.setInsertionPointAfter(op); 68 69 // Use permutations of 2 args as operands. 70 assert(fop.getNumArguments() >= 2); 71 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 72 for (int j = 0; j < e; ++j) { 73 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 74 SmallVector<Type, 2> inferredReturnTypes; 75 if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, 76 op->getAttrs(), op->getRegions(), 77 inferredReturnTypes))) { 78 OperationState state(location, OpTy::getOperationName()); 79 // TODO(jpienaar): Expand to regions. 80 OpTy::build(&b, state, values, op->getAttrs()); 81 (void)b.createOperation(state); 82 } 83 } 84 } 85 } 86 87 static void reifyReturnShape(Operation *op) { 88 OpBuilder b(op); 89 90 // Use permutations of 2 args as operands. 91 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 92 SmallVector<Value, 2> shapes; 93 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 94 return; 95 for (auto it : llvm::enumerate(shapes)) 96 op->emitRemark() << "value " << it.index() << ": " 97 << it.value().getDefiningOp(); 98 } 99 100 struct TestReturnTypeDriver 101 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 102 void runOnFunction() override { 103 if (getFunction().getName() == "testCreateFunctions") { 104 std::vector<Operation *> ops; 105 // Collect ops to avoid triggering on inserted ops. 106 for (auto &op : getFunction().getBody().front()) 107 ops.push_back(&op); 108 // Generate test patterns for each, but skip terminator. 109 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 110 // Test create method of each of the Op classes below. The resultant 111 // output would be in reverse order underneath `op` from which 112 // the attributes and regions are used. 113 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 114 invokeCreateWithInferredReturnType< 115 OpWithShapedTypeInferTypeInterfaceOp>(op); 116 }; 117 return; 118 } 119 if (getFunction().getName() == "testReifyFunctions") { 120 std::vector<Operation *> ops; 121 // Collect ops to avoid triggering on inserted ops. 122 for (auto &op : getFunction().getBody().front()) 123 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 124 ops.push_back(&op); 125 // Generate test patterns for each, but skip terminator. 126 for (auto *op : ops) 127 reifyReturnShape(op); 128 } 129 } 130 }; 131 } // end anonymous namespace 132 133 namespace { 134 struct TestDerivedAttributeDriver 135 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 136 void runOnFunction() override; 137 }; 138 } // end anonymous namespace 139 140 void TestDerivedAttributeDriver::runOnFunction() { 141 getFunction().walk([](DerivedAttributeOpInterface dOp) { 142 auto dAttr = dOp.materializeDerivedAttributes(); 143 if (!dAttr) 144 return; 145 for (auto d : dAttr) 146 dOp.emitRemark() << d.first << " = " << d.second; 147 }); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // Legalization Driver. 152 //===----------------------------------------------------------------------===// 153 154 namespace { 155 //===----------------------------------------------------------------------===// 156 // Region-Block Rewrite Testing 157 158 /// This pattern is a simple pattern that inlines the first region of a given 159 /// operation into the parent region. 160 struct TestRegionRewriteBlockMovement : public ConversionPattern { 161 TestRegionRewriteBlockMovement(MLIRContext *ctx) 162 : ConversionPattern("test.region", 1, ctx) {} 163 164 LogicalResult 165 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 166 ConversionPatternRewriter &rewriter) const final { 167 // Inline this region into the parent region. 168 auto &parentRegion = *op->getParentRegion(); 169 if (op->getAttr("legalizer.should_clone")) 170 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 171 parentRegion.end()); 172 else 173 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 174 parentRegion.end()); 175 176 // Drop this operation. 177 rewriter.eraseOp(op); 178 return success(); 179 } 180 }; 181 /// This pattern is a simple pattern that generates a region containing an 182 /// illegal operation. 183 struct TestRegionRewriteUndo : public RewritePattern { 184 TestRegionRewriteUndo(MLIRContext *ctx) 185 : RewritePattern("test.region_builder", 1, ctx) {} 186 187 LogicalResult matchAndRewrite(Operation *op, 188 PatternRewriter &rewriter) const final { 189 // Create the region operation with an entry block containing arguments. 190 OperationState newRegion(op->getLoc(), "test.region"); 191 newRegion.addRegion(); 192 auto *regionOp = rewriter.createOperation(newRegion); 193 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 194 entryBlock->addArgument(rewriter.getIntegerType(64)); 195 196 // Add an explicitly illegal operation to ensure the conversion fails. 197 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 198 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 199 200 // Drop this operation. 201 rewriter.eraseOp(op); 202 return success(); 203 } 204 }; 205 /// A simple pattern that creates a block at the end of the parent region of the 206 /// matched operation. 207 struct TestCreateBlock : public RewritePattern { 208 TestCreateBlock(MLIRContext *ctx) 209 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const final { 213 Region ®ion = *op->getParentRegion(); 214 Type i32Type = rewriter.getIntegerType(32); 215 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 216 rewriter.create<TerminatorOp>(op->getLoc()); 217 rewriter.replaceOp(op, {}); 218 return success(); 219 } 220 }; 221 222 /// A simple pattern that creates a block containing an invalid operaiton in 223 /// order to trigger the block creation undo mechanism. 224 struct TestCreateIllegalBlock : public RewritePattern { 225 TestCreateIllegalBlock(MLIRContext *ctx) 226 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 227 228 LogicalResult matchAndRewrite(Operation *op, 229 PatternRewriter &rewriter) const final { 230 Region ®ion = *op->getParentRegion(); 231 Type i32Type = rewriter.getIntegerType(32); 232 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 233 // Create an illegal op to ensure the conversion fails. 234 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 235 rewriter.create<TerminatorOp>(op->getLoc()); 236 rewriter.replaceOp(op, {}); 237 return success(); 238 } 239 }; 240 241 //===----------------------------------------------------------------------===// 242 // Type-Conversion Rewrite Testing 243 244 /// This patterns erases a region operation that has had a type conversion. 245 struct TestDropOpSignatureConversion : public ConversionPattern { 246 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 247 : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { 248 } 249 LogicalResult 250 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 251 ConversionPatternRewriter &rewriter) const override { 252 Region ®ion = op->getRegion(0); 253 Block *entry = ®ion.front(); 254 255 // Convert the original entry arguments. 256 TypeConverter::SignatureConversion result(entry->getNumArguments()); 257 for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) 258 if (failed(converter.convertSignatureArg( 259 i, entry->getArgument(i).getType(), result))) 260 return failure(); 261 262 // Convert the region signature and just drop the operation. 263 rewriter.applySignatureConversion(®ion, result); 264 rewriter.eraseOp(op); 265 return success(); 266 } 267 268 /// The type converter to use when rewriting the signature. 269 TypeConverter &converter; 270 }; 271 /// This pattern simply updates the operands of the given operation. 272 struct TestPassthroughInvalidOp : public ConversionPattern { 273 TestPassthroughInvalidOp(MLIRContext *ctx) 274 : ConversionPattern("test.invalid", 1, ctx) {} 275 LogicalResult 276 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 277 ConversionPatternRewriter &rewriter) const final { 278 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 279 llvm::None); 280 return success(); 281 } 282 }; 283 /// This pattern handles the case of a split return value. 284 struct TestSplitReturnType : public ConversionPattern { 285 TestSplitReturnType(MLIRContext *ctx) 286 : ConversionPattern("test.return", 1, ctx) {} 287 LogicalResult 288 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 289 ConversionPatternRewriter &rewriter) const final { 290 // Check for a return of F32. 291 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 292 return failure(); 293 294 // Check if the first operation is a cast operation, if it is we use the 295 // results directly. 296 auto *defOp = operands[0].getDefiningOp(); 297 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 298 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 299 return success(); 300 } 301 302 // Otherwise, fail to match. 303 return failure(); 304 } 305 }; 306 307 //===----------------------------------------------------------------------===// 308 // Multi-Level Type-Conversion Rewrite Testing 309 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 310 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 311 : ConversionPattern("test.type_producer", 1, ctx) {} 312 LogicalResult 313 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 314 ConversionPatternRewriter &rewriter) const final { 315 // If the type is I32, change the type to F32. 316 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 317 return failure(); 318 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 319 return success(); 320 } 321 }; 322 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 323 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 324 : ConversionPattern("test.type_producer", 1, ctx) {} 325 LogicalResult 326 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 327 ConversionPatternRewriter &rewriter) const final { 328 // If the type is F32, change the type to F64. 329 if (!Type(*op->result_type_begin()).isF32()) 330 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 331 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 332 return success(); 333 } 334 }; 335 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 336 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 337 : ConversionPattern("test.type_producer", 10, ctx) {} 338 LogicalResult 339 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 340 ConversionPatternRewriter &rewriter) const final { 341 // Always convert to B16, even though it is not a legal type. This tests 342 // that values are unmapped correctly. 343 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 344 return success(); 345 } 346 }; 347 struct TestUpdateConsumerType : public ConversionPattern { 348 TestUpdateConsumerType(MLIRContext *ctx) 349 : ConversionPattern("test.type_consumer", 1, ctx) {} 350 LogicalResult 351 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 352 ConversionPatternRewriter &rewriter) const final { 353 // Verify that the incoming operand has been successfully remapped to F64. 354 if (!operands[0].getType().isF64()) 355 return failure(); 356 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 357 return success(); 358 } 359 }; 360 361 //===----------------------------------------------------------------------===// 362 // Non-Root Replacement Rewrite Testing 363 /// This pattern generates an invalid operation, but replaces it before the 364 /// pattern is finished. This checks that we don't need to legalize the 365 /// temporary op. 366 struct TestNonRootReplacement : public RewritePattern { 367 TestNonRootReplacement(MLIRContext *ctx) 368 : RewritePattern("test.replace_non_root", 1, ctx) {} 369 370 LogicalResult matchAndRewrite(Operation *op, 371 PatternRewriter &rewriter) const final { 372 auto resultType = *op->result_type_begin(); 373 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 374 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 375 376 rewriter.replaceOp(illegalOp, {legalOp}); 377 rewriter.replaceOp(op, {illegalOp}); 378 return success(); 379 } 380 }; 381 382 //===----------------------------------------------------------------------===// 383 // Recursive Rewrite Testing 384 /// This pattern is applied to the same operation multiple times, but has a 385 /// bounded recursion. 386 struct TestBoundedRecursiveRewrite 387 : public OpRewritePattern<TestRecursiveRewriteOp> { 388 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 389 390 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 391 PatternRewriter &rewriter) const final { 392 // Decrement the depth of the op in-place. 393 rewriter.updateRootInPlace(op, [&] { 394 op.setAttr("depth", 395 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1)); 396 }); 397 return success(); 398 } 399 400 /// The conversion target handles bounding the recursion of this pattern. 401 bool hasBoundedRewriteRecursion() const final { return true; } 402 }; 403 } // namespace 404 405 namespace { 406 struct TestTypeConverter : public TypeConverter { 407 using TypeConverter::TypeConverter; 408 TestTypeConverter() { addConversion(convertType); } 409 410 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 411 // Drop I16 types. 412 if (t.isSignlessInteger(16)) 413 return success(); 414 415 // Convert I64 to F64. 416 if (t.isSignlessInteger(64)) { 417 results.push_back(FloatType::getF64(t.getContext())); 418 return success(); 419 } 420 421 // Split F32 into F16,F16. 422 if (t.isF32()) { 423 results.assign(2, FloatType::getF16(t.getContext())); 424 return success(); 425 } 426 427 // Otherwise, convert the type directly. 428 results.push_back(t); 429 return success(); 430 } 431 432 /// Override the hook to materialize a conversion. This is necessary because 433 /// we generate 1->N type mappings. 434 Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, 435 ArrayRef<Value> inputs, 436 Location loc) override { 437 return rewriter.create<TestCastOp>(loc, resultType, inputs); 438 } 439 }; 440 441 struct TestLegalizePatternDriver 442 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 443 /// The mode of conversion to use with the driver. 444 enum class ConversionMode { Analysis, Full, Partial }; 445 446 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 447 448 void runOnOperation() override { 449 TestTypeConverter converter; 450 mlir::OwningRewritePatternList patterns; 451 populateWithGenerated(&getContext(), &patterns); 452 patterns.insert< 453 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 454 TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType, 455 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 456 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 457 TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext()); 458 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 459 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 460 converter); 461 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 462 converter); 463 464 // Define the conversion target used for the test. 465 ConversionTarget target(getContext()); 466 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 467 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 468 TerminatorOp>(); 469 target 470 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 471 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 472 // Don't allow F32 operands. 473 return llvm::none_of(op.getOperandTypes(), 474 [](Type type) { return type.isF32(); }); 475 }); 476 target.addDynamicallyLegalOp<FuncOp>( 477 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 478 479 // Expect the type_producer/type_consumer operations to only operate on f64. 480 target.addDynamicallyLegalOp<TestTypeProducerOp>( 481 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 482 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 483 return op.getOperand().getType().isF64(); 484 }); 485 486 // Check support for marking certain operations as recursively legal. 487 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 488 return static_cast<bool>( 489 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 490 }); 491 492 // Mark the bound recursion operation as dynamically legal. 493 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 494 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 495 496 // Handle a partial conversion. 497 if (mode == ConversionMode::Partial) { 498 (void)applyPartialConversion(getOperation(), target, patterns, 499 &converter); 500 return; 501 } 502 503 // Handle a full conversion. 504 if (mode == ConversionMode::Full) { 505 // Check support for marking unknown operations as dynamically legal. 506 target.markUnknownOpDynamicallyLegal([](Operation *op) { 507 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 508 }); 509 510 (void)applyFullConversion(getOperation(), target, patterns, &converter); 511 return; 512 } 513 514 // Otherwise, handle an analysis conversion. 515 assert(mode == ConversionMode::Analysis); 516 517 // Analyze the convertible operations. 518 DenseSet<Operation *> legalizedOps; 519 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 520 legalizedOps, &converter))) 521 return signalPassFailure(); 522 523 // Emit remarks for each legalizable operation. 524 for (auto *op : legalizedOps) 525 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 526 } 527 528 /// The mode of conversion to use. 529 ConversionMode mode; 530 }; 531 } // end anonymous namespace 532 533 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 534 legalizerConversionMode( 535 "test-legalize-mode", 536 llvm::cl::desc("The legalization mode to use with the test driver"), 537 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 538 llvm::cl::values( 539 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 540 "analysis", "Perform an analysis conversion"), 541 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 542 "Perform a full conversion"), 543 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 544 "partial", "Perform a partial conversion"))); 545 546 //===----------------------------------------------------------------------===// 547 // ConversionPatternRewriter::getRemappedValue testing. This method is used 548 // to get the remapped value of an original value that was replaced using 549 // ConversionPatternRewriter. 550 namespace { 551 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 552 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 553 /// operand twice. 554 /// 555 /// Example: 556 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 557 /// is replaced with: 558 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 559 struct OneVResOneVOperandOp1Converter 560 : public OpConversionPattern<OneVResOneVOperandOp1> { 561 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 562 563 LogicalResult 564 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 565 ConversionPatternRewriter &rewriter) const override { 566 auto origOps = op.getOperands(); 567 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 568 "One operand expected"); 569 Value origOp = *origOps.begin(); 570 SmallVector<Value, 2> remappedOperands; 571 // Replicate the remapped original operand twice. Note that we don't used 572 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 573 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 574 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 575 576 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 577 remappedOperands); 578 return success(); 579 } 580 }; 581 582 struct TestRemappedValue 583 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 584 void runOnFunction() override { 585 mlir::OwningRewritePatternList patterns; 586 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 587 588 mlir::ConversionTarget target(getContext()); 589 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 590 // We make OneVResOneVOperandOp1 legal only when it has more that one 591 // operand. This will trigger the conversion that will replace one-operand 592 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 593 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 594 [](Operation *op) -> bool { 595 return std::distance(op->operand_begin(), op->operand_end()) > 1; 596 }); 597 598 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 599 signalPassFailure(); 600 } 601 } 602 }; 603 } // end anonymous namespace 604 605 namespace mlir { 606 void registerPatternsTestPass() { 607 mlir::PassRegistration<TestReturnTypeDriver>("test-return-type", 608 "Run return type functions"); 609 610 mlir::PassRegistration<TestDerivedAttributeDriver>( 611 "test-derived-attr", "Run test derived attributes"); 612 613 mlir::PassRegistration<TestPatternDriver>("test-patterns", 614 "Run test dialect patterns"); 615 616 mlir::PassRegistration<TestLegalizePatternDriver>( 617 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 618 return std::make_unique<TestLegalizePatternDriver>( 619 legalizerConversionMode); 620 }); 621 622 PassRegistration<TestRemappedValue>( 623 "test-remapped-value", 624 "Test public remapped value mechanism in ConversionPatternRewriter"); 625 } 626 } // namespace mlir 627