1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Conversion/StandardToStandard/StandardToStandard.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "mlir/Pass/Pass.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 using namespace mlir; 16 17 // Native function for testing NativeCodeCall 18 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 19 return choice.getValue() ? input1 : input2; 20 } 21 22 static void createOpI(PatternRewriter &rewriter, Value input) { 23 rewriter.create<OpI>(rewriter.getUnknownLoc(), input); 24 } 25 26 static void handleNoResultOp(PatternRewriter &rewriter, 27 OpSymbolBindingNoResult op) { 28 // Turn the no result op to a one-result op. 29 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 30 op.operand()); 31 } 32 33 namespace { 34 #include "TestPatterns.inc" 35 } // end anonymous namespace 36 37 //===----------------------------------------------------------------------===// 38 // Canonicalizer Driver. 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 43 void runOnFunction() override { 44 mlir::OwningRewritePatternList patterns; 45 populateWithGenerated(&getContext(), &patterns); 46 47 // Verify named pattern is generated with expected name. 48 patterns.insert<TestNamedPatternRule>(&getContext()); 49 50 applyPatternsAndFoldGreedily(getFunction(), patterns); 51 } 52 }; 53 } // end anonymous namespace 54 55 //===----------------------------------------------------------------------===// 56 // ReturnType Driver. 57 //===----------------------------------------------------------------------===// 58 59 namespace { 60 // Generate ops for each instance where the type can be successfully inferred. 61 template <typename OpTy> 62 static void invokeCreateWithInferredReturnType(Operation *op) { 63 auto *context = op->getContext(); 64 auto fop = op->getParentOfType<FuncOp>(); 65 auto location = UnknownLoc::get(context); 66 OpBuilder b(op); 67 b.setInsertionPointAfter(op); 68 69 // Use permutations of 2 args as operands. 70 assert(fop.getNumArguments() >= 2); 71 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 72 for (int j = 0; j < e; ++j) { 73 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 74 SmallVector<Type, 2> inferredReturnTypes; 75 if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, 76 op->getAttrs(), op->getRegions(), 77 inferredReturnTypes))) { 78 OperationState state(location, OpTy::getOperationName()); 79 // TODO(jpienaar): Expand to regions. 80 OpTy::build(b, state, values, op->getAttrs()); 81 (void)b.createOperation(state); 82 } 83 } 84 } 85 } 86 87 static void reifyReturnShape(Operation *op) { 88 OpBuilder b(op); 89 90 // Use permutations of 2 args as operands. 91 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 92 SmallVector<Value, 2> shapes; 93 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 94 return; 95 for (auto it : llvm::enumerate(shapes)) 96 op->emitRemark() << "value " << it.index() << ": " 97 << it.value().getDefiningOp(); 98 } 99 100 struct TestReturnTypeDriver 101 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 102 void runOnFunction() override { 103 if (getFunction().getName() == "testCreateFunctions") { 104 std::vector<Operation *> ops; 105 // Collect ops to avoid triggering on inserted ops. 106 for (auto &op : getFunction().getBody().front()) 107 ops.push_back(&op); 108 // Generate test patterns for each, but skip terminator. 109 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 110 // Test create method of each of the Op classes below. The resultant 111 // output would be in reverse order underneath `op` from which 112 // the attributes and regions are used. 113 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 114 invokeCreateWithInferredReturnType< 115 OpWithShapedTypeInferTypeInterfaceOp>(op); 116 }; 117 return; 118 } 119 if (getFunction().getName() == "testReifyFunctions") { 120 std::vector<Operation *> ops; 121 // Collect ops to avoid triggering on inserted ops. 122 for (auto &op : getFunction().getBody().front()) 123 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 124 ops.push_back(&op); 125 // Generate test patterns for each, but skip terminator. 126 for (auto *op : ops) 127 reifyReturnShape(op); 128 } 129 } 130 }; 131 } // end anonymous namespace 132 133 namespace { 134 struct TestDerivedAttributeDriver 135 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 136 void runOnFunction() override; 137 }; 138 } // end anonymous namespace 139 140 void TestDerivedAttributeDriver::runOnFunction() { 141 getFunction().walk([](DerivedAttributeOpInterface dOp) { 142 auto dAttr = dOp.materializeDerivedAttributes(); 143 if (!dAttr) 144 return; 145 for (auto d : dAttr) 146 dOp.emitRemark() << d.first << " = " << d.second; 147 }); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // Legalization Driver. 152 //===----------------------------------------------------------------------===// 153 154 namespace { 155 //===----------------------------------------------------------------------===// 156 // Region-Block Rewrite Testing 157 158 /// This pattern is a simple pattern that inlines the first region of a given 159 /// operation into the parent region. 160 struct TestRegionRewriteBlockMovement : public ConversionPattern { 161 TestRegionRewriteBlockMovement(MLIRContext *ctx) 162 : ConversionPattern("test.region", 1, ctx) {} 163 164 LogicalResult 165 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 166 ConversionPatternRewriter &rewriter) const final { 167 // Inline this region into the parent region. 168 auto &parentRegion = *op->getParentRegion(); 169 if (op->getAttr("legalizer.should_clone")) 170 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 171 parentRegion.end()); 172 else 173 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 174 parentRegion.end()); 175 176 // Drop this operation. 177 rewriter.eraseOp(op); 178 return success(); 179 } 180 }; 181 /// This pattern is a simple pattern that generates a region containing an 182 /// illegal operation. 183 struct TestRegionRewriteUndo : public RewritePattern { 184 TestRegionRewriteUndo(MLIRContext *ctx) 185 : RewritePattern("test.region_builder", 1, ctx) {} 186 187 LogicalResult matchAndRewrite(Operation *op, 188 PatternRewriter &rewriter) const final { 189 // Create the region operation with an entry block containing arguments. 190 OperationState newRegion(op->getLoc(), "test.region"); 191 newRegion.addRegion(); 192 auto *regionOp = rewriter.createOperation(newRegion); 193 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 194 entryBlock->addArgument(rewriter.getIntegerType(64)); 195 196 // Add an explicitly illegal operation to ensure the conversion fails. 197 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 198 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 199 200 // Drop this operation. 201 rewriter.eraseOp(op); 202 return success(); 203 } 204 }; 205 /// A simple pattern that creates a block at the end of the parent region of the 206 /// matched operation. 207 struct TestCreateBlock : public RewritePattern { 208 TestCreateBlock(MLIRContext *ctx) 209 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 210 211 LogicalResult matchAndRewrite(Operation *op, 212 PatternRewriter &rewriter) const final { 213 Region ®ion = *op->getParentRegion(); 214 Type i32Type = rewriter.getIntegerType(32); 215 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 216 rewriter.create<TerminatorOp>(op->getLoc()); 217 rewriter.replaceOp(op, {}); 218 return success(); 219 } 220 }; 221 222 /// A simple pattern that creates a block containing an invalid operaiton in 223 /// order to trigger the block creation undo mechanism. 224 struct TestCreateIllegalBlock : public RewritePattern { 225 TestCreateIllegalBlock(MLIRContext *ctx) 226 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 227 228 LogicalResult matchAndRewrite(Operation *op, 229 PatternRewriter &rewriter) const final { 230 Region ®ion = *op->getParentRegion(); 231 Type i32Type = rewriter.getIntegerType(32); 232 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 233 // Create an illegal op to ensure the conversion fails. 234 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 235 rewriter.create<TerminatorOp>(op->getLoc()); 236 rewriter.replaceOp(op, {}); 237 return success(); 238 } 239 }; 240 241 /// A simple pattern that tests the undo mechanism when replacing the uses of a 242 /// block argument. 243 struct TestUndoBlockArgReplace : public ConversionPattern { 244 TestUndoBlockArgReplace(MLIRContext *ctx) 245 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 246 247 LogicalResult 248 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 249 ConversionPatternRewriter &rewriter) const final { 250 auto illegalOp = 251 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 252 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0), 253 illegalOp); 254 rewriter.updateRootInPlace(op, [] {}); 255 return success(); 256 } 257 }; 258 259 //===----------------------------------------------------------------------===// 260 // Type-Conversion Rewrite Testing 261 262 /// This patterns erases a region operation that has had a type conversion. 263 struct TestDropOpSignatureConversion : public ConversionPattern { 264 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 265 : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { 266 } 267 LogicalResult 268 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 269 ConversionPatternRewriter &rewriter) const override { 270 Region ®ion = op->getRegion(0); 271 Block *entry = ®ion.front(); 272 273 // Convert the original entry arguments. 274 TypeConverter::SignatureConversion result(entry->getNumArguments()); 275 for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) 276 if (failed(converter.convertSignatureArg( 277 i, entry->getArgument(i).getType(), result))) 278 return failure(); 279 280 // Convert the region signature and just drop the operation. 281 rewriter.applySignatureConversion(®ion, result); 282 rewriter.eraseOp(op); 283 return success(); 284 } 285 286 /// The type converter to use when rewriting the signature. 287 TypeConverter &converter; 288 }; 289 /// This pattern simply updates the operands of the given operation. 290 struct TestPassthroughInvalidOp : public ConversionPattern { 291 TestPassthroughInvalidOp(MLIRContext *ctx) 292 : ConversionPattern("test.invalid", 1, ctx) {} 293 LogicalResult 294 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 295 ConversionPatternRewriter &rewriter) const final { 296 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 297 llvm::None); 298 return success(); 299 } 300 }; 301 /// This pattern handles the case of a split return value. 302 struct TestSplitReturnType : public ConversionPattern { 303 TestSplitReturnType(MLIRContext *ctx) 304 : ConversionPattern("test.return", 1, ctx) {} 305 LogicalResult 306 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 307 ConversionPatternRewriter &rewriter) const final { 308 // Check for a return of F32. 309 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 310 return failure(); 311 312 // Check if the first operation is a cast operation, if it is we use the 313 // results directly. 314 auto *defOp = operands[0].getDefiningOp(); 315 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 316 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 317 return success(); 318 } 319 320 // Otherwise, fail to match. 321 return failure(); 322 } 323 }; 324 325 //===----------------------------------------------------------------------===// 326 // Multi-Level Type-Conversion Rewrite Testing 327 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 328 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 329 : ConversionPattern("test.type_producer", 1, ctx) {} 330 LogicalResult 331 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 332 ConversionPatternRewriter &rewriter) const final { 333 // If the type is I32, change the type to F32. 334 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 335 return failure(); 336 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 337 return success(); 338 } 339 }; 340 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 341 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 342 : ConversionPattern("test.type_producer", 1, ctx) {} 343 LogicalResult 344 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const final { 346 // If the type is F32, change the type to F64. 347 if (!Type(*op->result_type_begin()).isF32()) 348 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 349 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 350 return success(); 351 } 352 }; 353 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 354 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 355 : ConversionPattern("test.type_producer", 10, ctx) {} 356 LogicalResult 357 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 358 ConversionPatternRewriter &rewriter) const final { 359 // Always convert to B16, even though it is not a legal type. This tests 360 // that values are unmapped correctly. 361 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 362 return success(); 363 } 364 }; 365 struct TestUpdateConsumerType : public ConversionPattern { 366 TestUpdateConsumerType(MLIRContext *ctx) 367 : ConversionPattern("test.type_consumer", 1, ctx) {} 368 LogicalResult 369 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 370 ConversionPatternRewriter &rewriter) const final { 371 // Verify that the incoming operand has been successfully remapped to F64. 372 if (!operands[0].getType().isF64()) 373 return failure(); 374 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 375 return success(); 376 } 377 }; 378 379 //===----------------------------------------------------------------------===// 380 // Non-Root Replacement Rewrite Testing 381 /// This pattern generates an invalid operation, but replaces it before the 382 /// pattern is finished. This checks that we don't need to legalize the 383 /// temporary op. 384 struct TestNonRootReplacement : public RewritePattern { 385 TestNonRootReplacement(MLIRContext *ctx) 386 : RewritePattern("test.replace_non_root", 1, ctx) {} 387 388 LogicalResult matchAndRewrite(Operation *op, 389 PatternRewriter &rewriter) const final { 390 auto resultType = *op->result_type_begin(); 391 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 392 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 393 394 rewriter.replaceOp(illegalOp, {legalOp}); 395 rewriter.replaceOp(op, {illegalOp}); 396 return success(); 397 } 398 }; 399 400 //===----------------------------------------------------------------------===// 401 // Recursive Rewrite Testing 402 /// This pattern is applied to the same operation multiple times, but has a 403 /// bounded recursion. 404 struct TestBoundedRecursiveRewrite 405 : public OpRewritePattern<TestRecursiveRewriteOp> { 406 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 407 408 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 409 PatternRewriter &rewriter) const final { 410 // Decrement the depth of the op in-place. 411 rewriter.updateRootInPlace(op, [&] { 412 op.setAttr("depth", 413 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1)); 414 }); 415 return success(); 416 } 417 418 /// The conversion target handles bounding the recursion of this pattern. 419 bool hasBoundedRewriteRecursion() const final { return true; } 420 }; 421 } // namespace 422 423 namespace { 424 struct TestTypeConverter : public TypeConverter { 425 using TypeConverter::TypeConverter; 426 TestTypeConverter() { addConversion(convertType); } 427 428 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 429 // Drop I16 types. 430 if (t.isSignlessInteger(16)) 431 return success(); 432 433 // Convert I64 to F64. 434 if (t.isSignlessInteger(64)) { 435 results.push_back(FloatType::getF64(t.getContext())); 436 return success(); 437 } 438 439 // Split F32 into F16,F16. 440 if (t.isF32()) { 441 results.assign(2, FloatType::getF16(t.getContext())); 442 return success(); 443 } 444 445 // Otherwise, convert the type directly. 446 results.push_back(t); 447 return success(); 448 } 449 450 /// Override the hook to materialize a conversion. This is necessary because 451 /// we generate 1->N type mappings. 452 Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, 453 ArrayRef<Value> inputs, 454 Location loc) override { 455 return rewriter.create<TestCastOp>(loc, resultType, inputs); 456 } 457 }; 458 459 struct TestLegalizePatternDriver 460 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 461 /// The mode of conversion to use with the driver. 462 enum class ConversionMode { Analysis, Full, Partial }; 463 464 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 465 466 void runOnOperation() override { 467 TestTypeConverter converter; 468 mlir::OwningRewritePatternList patterns; 469 populateWithGenerated(&getContext(), &patterns); 470 patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 471 TestCreateBlock, TestCreateIllegalBlock, 472 TestUndoBlockArgReplace, TestPassthroughInvalidOp, 473 TestSplitReturnType, TestChangeProducerTypeI32ToF32, 474 TestChangeProducerTypeF32ToF64, 475 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 476 TestNonRootReplacement, TestBoundedRecursiveRewrite>( 477 &getContext()); 478 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 479 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 480 converter); 481 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 482 converter); 483 484 // Define the conversion target used for the test. 485 ConversionTarget target(getContext()); 486 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 487 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 488 TerminatorOp>(); 489 target 490 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 491 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 492 // Don't allow F32 operands. 493 return llvm::none_of(op.getOperandTypes(), 494 [](Type type) { return type.isF32(); }); 495 }); 496 target.addDynamicallyLegalOp<FuncOp>( 497 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 498 499 // Expect the type_producer/type_consumer operations to only operate on f64. 500 target.addDynamicallyLegalOp<TestTypeProducerOp>( 501 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 502 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 503 return op.getOperand().getType().isF64(); 504 }); 505 506 // Check support for marking certain operations as recursively legal. 507 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 508 return static_cast<bool>( 509 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 510 }); 511 512 // Mark the bound recursion operation as dynamically legal. 513 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 514 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 515 516 // Handle a partial conversion. 517 if (mode == ConversionMode::Partial) { 518 (void)applyPartialConversion(getOperation(), target, patterns, 519 &converter); 520 return; 521 } 522 523 // Handle a full conversion. 524 if (mode == ConversionMode::Full) { 525 // Check support for marking unknown operations as dynamically legal. 526 target.markUnknownOpDynamicallyLegal([](Operation *op) { 527 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 528 }); 529 530 (void)applyFullConversion(getOperation(), target, patterns, &converter); 531 return; 532 } 533 534 // Otherwise, handle an analysis conversion. 535 assert(mode == ConversionMode::Analysis); 536 537 // Analyze the convertible operations. 538 DenseSet<Operation *> legalizedOps; 539 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 540 legalizedOps, &converter))) 541 return signalPassFailure(); 542 543 // Emit remarks for each legalizable operation. 544 for (auto *op : legalizedOps) 545 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 546 } 547 548 /// The mode of conversion to use. 549 ConversionMode mode; 550 }; 551 } // end anonymous namespace 552 553 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 554 legalizerConversionMode( 555 "test-legalize-mode", 556 llvm::cl::desc("The legalization mode to use with the test driver"), 557 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 558 llvm::cl::values( 559 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 560 "analysis", "Perform an analysis conversion"), 561 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 562 "Perform a full conversion"), 563 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 564 "partial", "Perform a partial conversion"))); 565 566 //===----------------------------------------------------------------------===// 567 // ConversionPatternRewriter::getRemappedValue testing. This method is used 568 // to get the remapped value of an original value that was replaced using 569 // ConversionPatternRewriter. 570 namespace { 571 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 572 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 573 /// operand twice. 574 /// 575 /// Example: 576 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 577 /// is replaced with: 578 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 579 struct OneVResOneVOperandOp1Converter 580 : public OpConversionPattern<OneVResOneVOperandOp1> { 581 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 582 583 LogicalResult 584 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 585 ConversionPatternRewriter &rewriter) const override { 586 auto origOps = op.getOperands(); 587 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 588 "One operand expected"); 589 Value origOp = *origOps.begin(); 590 SmallVector<Value, 2> remappedOperands; 591 // Replicate the remapped original operand twice. Note that we don't used 592 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 593 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 594 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 595 596 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 597 remappedOperands); 598 return success(); 599 } 600 }; 601 602 struct TestRemappedValue 603 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 604 void runOnFunction() override { 605 mlir::OwningRewritePatternList patterns; 606 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 607 608 mlir::ConversionTarget target(getContext()); 609 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 610 // We make OneVResOneVOperandOp1 legal only when it has more that one 611 // operand. This will trigger the conversion that will replace one-operand 612 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 613 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 614 [](Operation *op) -> bool { 615 return std::distance(op->operand_begin(), op->operand_end()) > 1; 616 }); 617 618 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 619 signalPassFailure(); 620 } 621 } 622 }; 623 } // end anonymous namespace 624 625 namespace mlir { 626 void registerPatternsTestPass() { 627 mlir::PassRegistration<TestReturnTypeDriver>("test-return-type", 628 "Run return type functions"); 629 630 mlir::PassRegistration<TestDerivedAttributeDriver>( 631 "test-derived-attr", "Run test derived attributes"); 632 633 mlir::PassRegistration<TestPatternDriver>("test-patterns", 634 "Run test dialect patterns"); 635 636 mlir::PassRegistration<TestLegalizePatternDriver>( 637 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 638 return std::make_unique<TestLegalizePatternDriver>( 639 legalizerConversionMode); 640 }); 641 642 PassRegistration<TestRemappedValue>( 643 "test-remapped-value", 644 "Test public remapped value mechanism in ConversionPatternRewriter"); 645 } 646 } // namespace mlir 647