1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 17 using namespace mlir; 18 19 // Native function for testing NativeCodeCall 20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 21 return choice.getValue() ? input1 : input2; 22 } 23 24 static void createOpI(PatternRewriter &rewriter, Value input) { 25 rewriter.create<OpI>(rewriter.getUnknownLoc(), input); 26 } 27 28 static void handleNoResultOp(PatternRewriter &rewriter, 29 OpSymbolBindingNoResult op) { 30 // Turn the no result op to a one-result op. 31 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 32 op.operand()); 33 } 34 35 namespace { 36 #include "TestPatterns.inc" 37 } // end anonymous namespace 38 39 //===----------------------------------------------------------------------===// 40 // Canonicalizer Driver. 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 struct FoldingPattern : public RewritePattern { 45 public: 46 FoldingPattern(MLIRContext *context) 47 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 48 /*benefit=*/1, context) {} 49 50 LogicalResult matchAndRewrite(Operation *op, 51 PatternRewriter &rewriter) const override { 52 // Exercice OperationFolder API for a single-result operation that is folded 53 // upon construction. The operation being created through the folder has an 54 // in-place folder, and it should be still present in the output. 55 // Furthermore, the folder should not crash when attempting to recover the 56 // (unchanged) opeation result. 57 OperationFolder folder(op->getContext()); 58 Value result = folder.create<TestOpInPlaceFold>( 59 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 60 rewriter.getI32IntegerAttr(0)); 61 assert(result); 62 rewriter.replaceOp(op, result); 63 return success(); 64 } 65 }; 66 67 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 68 void runOnFunction() override { 69 mlir::OwningRewritePatternList patterns; 70 populateWithGenerated(&getContext(), &patterns); 71 72 // Verify named pattern is generated with expected name. 73 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 74 75 applyPatternsAndFoldGreedily(getFunction(), patterns); 76 } 77 }; 78 } // end anonymous namespace 79 80 //===----------------------------------------------------------------------===// 81 // ReturnType Driver. 82 //===----------------------------------------------------------------------===// 83 84 namespace { 85 // Generate ops for each instance where the type can be successfully inferred. 86 template <typename OpTy> 87 static void invokeCreateWithInferredReturnType(Operation *op) { 88 auto *context = op->getContext(); 89 auto fop = op->getParentOfType<FuncOp>(); 90 auto location = UnknownLoc::get(context); 91 OpBuilder b(op); 92 b.setInsertionPointAfter(op); 93 94 // Use permutations of 2 args as operands. 95 assert(fop.getNumArguments() >= 2); 96 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 97 for (int j = 0; j < e; ++j) { 98 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 99 SmallVector<Type, 2> inferredReturnTypes; 100 if (succeeded(OpTy::inferReturnTypes( 101 context, llvm::None, values, op->getAttrDictionary(), 102 op->getRegions(), inferredReturnTypes))) { 103 OperationState state(location, OpTy::getOperationName()); 104 // TODO(jpienaar): Expand to regions. 105 OpTy::build(b, state, values, op->getAttrs()); 106 (void)b.createOperation(state); 107 } 108 } 109 } 110 } 111 112 static void reifyReturnShape(Operation *op) { 113 OpBuilder b(op); 114 115 // Use permutations of 2 args as operands. 116 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 117 SmallVector<Value, 2> shapes; 118 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 119 return; 120 for (auto it : llvm::enumerate(shapes)) 121 op->emitRemark() << "value " << it.index() << ": " 122 << it.value().getDefiningOp(); 123 } 124 125 struct TestReturnTypeDriver 126 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 127 void runOnFunction() override { 128 if (getFunction().getName() == "testCreateFunctions") { 129 std::vector<Operation *> ops; 130 // Collect ops to avoid triggering on inserted ops. 131 for (auto &op : getFunction().getBody().front()) 132 ops.push_back(&op); 133 // Generate test patterns for each, but skip terminator. 134 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 135 // Test create method of each of the Op classes below. The resultant 136 // output would be in reverse order underneath `op` from which 137 // the attributes and regions are used. 138 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 139 invokeCreateWithInferredReturnType< 140 OpWithShapedTypeInferTypeInterfaceOp>(op); 141 }; 142 return; 143 } 144 if (getFunction().getName() == "testReifyFunctions") { 145 std::vector<Operation *> ops; 146 // Collect ops to avoid triggering on inserted ops. 147 for (auto &op : getFunction().getBody().front()) 148 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 149 ops.push_back(&op); 150 // Generate test patterns for each, but skip terminator. 151 for (auto *op : ops) 152 reifyReturnShape(op); 153 } 154 } 155 }; 156 } // end anonymous namespace 157 158 namespace { 159 struct TestDerivedAttributeDriver 160 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 161 void runOnFunction() override; 162 }; 163 } // end anonymous namespace 164 165 void TestDerivedAttributeDriver::runOnFunction() { 166 getFunction().walk([](DerivedAttributeOpInterface dOp) { 167 auto dAttr = dOp.materializeDerivedAttributes(); 168 if (!dAttr) 169 return; 170 for (auto d : dAttr) 171 dOp.emitRemark() << d.first << " = " << d.second; 172 }); 173 } 174 175 //===----------------------------------------------------------------------===// 176 // Legalization Driver. 177 //===----------------------------------------------------------------------===// 178 179 namespace { 180 //===----------------------------------------------------------------------===// 181 // Region-Block Rewrite Testing 182 183 /// This pattern is a simple pattern that inlines the first region of a given 184 /// operation into the parent region. 185 struct TestRegionRewriteBlockMovement : public ConversionPattern { 186 TestRegionRewriteBlockMovement(MLIRContext *ctx) 187 : ConversionPattern("test.region", 1, ctx) {} 188 189 LogicalResult 190 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 191 ConversionPatternRewriter &rewriter) const final { 192 // Inline this region into the parent region. 193 auto &parentRegion = *op->getParentRegion(); 194 if (op->getAttr("legalizer.should_clone")) 195 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 196 parentRegion.end()); 197 else 198 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 199 parentRegion.end()); 200 201 // Drop this operation. 202 rewriter.eraseOp(op); 203 return success(); 204 } 205 }; 206 /// This pattern is a simple pattern that generates a region containing an 207 /// illegal operation. 208 struct TestRegionRewriteUndo : public RewritePattern { 209 TestRegionRewriteUndo(MLIRContext *ctx) 210 : RewritePattern("test.region_builder", 1, ctx) {} 211 212 LogicalResult matchAndRewrite(Operation *op, 213 PatternRewriter &rewriter) const final { 214 // Create the region operation with an entry block containing arguments. 215 OperationState newRegion(op->getLoc(), "test.region"); 216 newRegion.addRegion(); 217 auto *regionOp = rewriter.createOperation(newRegion); 218 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 219 entryBlock->addArgument(rewriter.getIntegerType(64)); 220 221 // Add an explicitly illegal operation to ensure the conversion fails. 222 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 223 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 224 225 // Drop this operation. 226 rewriter.eraseOp(op); 227 return success(); 228 } 229 }; 230 /// A simple pattern that creates a block at the end of the parent region of the 231 /// matched operation. 232 struct TestCreateBlock : public RewritePattern { 233 TestCreateBlock(MLIRContext *ctx) 234 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 235 236 LogicalResult matchAndRewrite(Operation *op, 237 PatternRewriter &rewriter) const final { 238 Region ®ion = *op->getParentRegion(); 239 Type i32Type = rewriter.getIntegerType(32); 240 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 241 rewriter.create<TerminatorOp>(op->getLoc()); 242 rewriter.replaceOp(op, {}); 243 return success(); 244 } 245 }; 246 247 /// A simple pattern that creates a block containing an invalid operaiton in 248 /// order to trigger the block creation undo mechanism. 249 struct TestCreateIllegalBlock : public RewritePattern { 250 TestCreateIllegalBlock(MLIRContext *ctx) 251 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 252 253 LogicalResult matchAndRewrite(Operation *op, 254 PatternRewriter &rewriter) const final { 255 Region ®ion = *op->getParentRegion(); 256 Type i32Type = rewriter.getIntegerType(32); 257 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 258 // Create an illegal op to ensure the conversion fails. 259 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 260 rewriter.create<TerminatorOp>(op->getLoc()); 261 rewriter.replaceOp(op, {}); 262 return success(); 263 } 264 }; 265 266 /// A simple pattern that tests the undo mechanism when replacing the uses of a 267 /// block argument. 268 struct TestUndoBlockArgReplace : public ConversionPattern { 269 TestUndoBlockArgReplace(MLIRContext *ctx) 270 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 271 272 LogicalResult 273 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 274 ConversionPatternRewriter &rewriter) const final { 275 auto illegalOp = 276 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 277 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0), 278 illegalOp); 279 rewriter.updateRootInPlace(op, [] {}); 280 return success(); 281 } 282 }; 283 284 /// A rewrite pattern that tests the undo mechanism when erasing a block. 285 struct TestUndoBlockErase : public ConversionPattern { 286 TestUndoBlockErase(MLIRContext *ctx) 287 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 288 289 LogicalResult 290 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 291 ConversionPatternRewriter &rewriter) const final { 292 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 293 rewriter.setInsertionPointToStart(secondBlock); 294 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 295 rewriter.eraseBlock(secondBlock); 296 rewriter.updateRootInPlace(op, [] {}); 297 return success(); 298 } 299 }; 300 301 //===----------------------------------------------------------------------===// 302 // Type-Conversion Rewrite Testing 303 304 /// This patterns erases a region operation that has had a type conversion. 305 struct TestDropOpSignatureConversion : public ConversionPattern { 306 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 307 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 308 LogicalResult 309 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 310 ConversionPatternRewriter &rewriter) const override { 311 Region ®ion = op->getRegion(0); 312 Block *entry = ®ion.front(); 313 314 // Convert the original entry arguments. 315 TypeConverter &converter = *getTypeConverter(); 316 TypeConverter::SignatureConversion result(entry->getNumArguments()); 317 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 318 result)) || 319 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 320 return failure(); 321 322 // Convert the region signature and just drop the operation. 323 rewriter.eraseOp(op); 324 return success(); 325 } 326 }; 327 /// This pattern simply updates the operands of the given operation. 328 struct TestPassthroughInvalidOp : public ConversionPattern { 329 TestPassthroughInvalidOp(MLIRContext *ctx) 330 : ConversionPattern("test.invalid", 1, ctx) {} 331 LogicalResult 332 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 333 ConversionPatternRewriter &rewriter) const final { 334 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 335 llvm::None); 336 return success(); 337 } 338 }; 339 /// This pattern handles the case of a split return value. 340 struct TestSplitReturnType : public ConversionPattern { 341 TestSplitReturnType(MLIRContext *ctx) 342 : ConversionPattern("test.return", 1, ctx) {} 343 LogicalResult 344 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const final { 346 // Check for a return of F32. 347 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 348 return failure(); 349 350 // Check if the first operation is a cast operation, if it is we use the 351 // results directly. 352 auto *defOp = operands[0].getDefiningOp(); 353 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 354 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 355 return success(); 356 } 357 358 // Otherwise, fail to match. 359 return failure(); 360 } 361 }; 362 363 //===----------------------------------------------------------------------===// 364 // Multi-Level Type-Conversion Rewrite Testing 365 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 366 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 367 : ConversionPattern("test.type_producer", 1, ctx) {} 368 LogicalResult 369 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 370 ConversionPatternRewriter &rewriter) const final { 371 // If the type is I32, change the type to F32. 372 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 373 return failure(); 374 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 375 return success(); 376 } 377 }; 378 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 379 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 380 : ConversionPattern("test.type_producer", 1, ctx) {} 381 LogicalResult 382 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 383 ConversionPatternRewriter &rewriter) const final { 384 // If the type is F32, change the type to F64. 385 if (!Type(*op->result_type_begin()).isF32()) 386 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 387 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 388 return success(); 389 } 390 }; 391 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 392 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 393 : ConversionPattern("test.type_producer", 10, ctx) {} 394 LogicalResult 395 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 396 ConversionPatternRewriter &rewriter) const final { 397 // Always convert to B16, even though it is not a legal type. This tests 398 // that values are unmapped correctly. 399 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 400 return success(); 401 } 402 }; 403 struct TestUpdateConsumerType : public ConversionPattern { 404 TestUpdateConsumerType(MLIRContext *ctx) 405 : ConversionPattern("test.type_consumer", 1, ctx) {} 406 LogicalResult 407 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 408 ConversionPatternRewriter &rewriter) const final { 409 // Verify that the incoming operand has been successfully remapped to F64. 410 if (!operands[0].getType().isF64()) 411 return failure(); 412 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 413 return success(); 414 } 415 }; 416 417 //===----------------------------------------------------------------------===// 418 // Non-Root Replacement Rewrite Testing 419 /// This pattern generates an invalid operation, but replaces it before the 420 /// pattern is finished. This checks that we don't need to legalize the 421 /// temporary op. 422 struct TestNonRootReplacement : public RewritePattern { 423 TestNonRootReplacement(MLIRContext *ctx) 424 : RewritePattern("test.replace_non_root", 1, ctx) {} 425 426 LogicalResult matchAndRewrite(Operation *op, 427 PatternRewriter &rewriter) const final { 428 auto resultType = *op->result_type_begin(); 429 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 430 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 431 432 rewriter.replaceOp(illegalOp, {legalOp}); 433 rewriter.replaceOp(op, {illegalOp}); 434 return success(); 435 } 436 }; 437 438 //===----------------------------------------------------------------------===// 439 // Recursive Rewrite Testing 440 /// This pattern is applied to the same operation multiple times, but has a 441 /// bounded recursion. 442 struct TestBoundedRecursiveRewrite 443 : public OpRewritePattern<TestRecursiveRewriteOp> { 444 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 445 446 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 447 PatternRewriter &rewriter) const final { 448 // Decrement the depth of the op in-place. 449 rewriter.updateRootInPlace(op, [&] { 450 op.setAttr("depth", 451 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1)); 452 }); 453 return success(); 454 } 455 456 /// The conversion target handles bounding the recursion of this pattern. 457 bool hasBoundedRewriteRecursion() const final { return true; } 458 }; 459 460 struct TestNestedOpCreationUndoRewrite 461 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 462 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 463 464 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 465 PatternRewriter &rewriter) const final { 466 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 467 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 468 return success(); 469 }; 470 }; 471 } // namespace 472 473 namespace { 474 struct TestTypeConverter : public TypeConverter { 475 using TypeConverter::TypeConverter; 476 TestTypeConverter() { 477 addConversion(convertType); 478 addMaterialization(materializeCast); 479 addMaterialization(materializeOneToOneCast); 480 } 481 482 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 483 // Drop I16 types. 484 if (t.isSignlessInteger(16)) 485 return success(); 486 487 // Convert I64 to F64. 488 if (t.isSignlessInteger(64)) { 489 results.push_back(FloatType::getF64(t.getContext())); 490 return success(); 491 } 492 493 // Convert I42 to I43. 494 if (t.isInteger(42)) { 495 results.push_back(IntegerType::get(43, t.getContext())); 496 return success(); 497 } 498 499 // Split F32 into F16,F16. 500 if (t.isF32()) { 501 results.assign(2, FloatType::getF16(t.getContext())); 502 return success(); 503 } 504 505 // Otherwise, convert the type directly. 506 results.push_back(t); 507 return success(); 508 } 509 510 /// Hook for materializing a conversion. This is necessary because we generate 511 /// 1->N type mappings. 512 static Optional<Value> materializeCast(PatternRewriter &rewriter, 513 Type resultType, ValueRange inputs, 514 Location loc) { 515 if (inputs.size() == 1) 516 return inputs[0]; 517 return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult(); 518 } 519 520 /// Materialize the cast for one-to-one conversion from i64 to f64. 521 static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter, 522 IntegerType resultType, 523 ValueRange inputs, 524 Location loc) { 525 if (resultType.getWidth() == 42 && inputs.size() == 1) 526 return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult(); 527 return llvm::None; 528 } 529 }; 530 531 struct TestLegalizePatternDriver 532 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 533 /// The mode of conversion to use with the driver. 534 enum class ConversionMode { Analysis, Full, Partial }; 535 536 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 537 538 void runOnOperation() override { 539 TestTypeConverter converter; 540 mlir::OwningRewritePatternList patterns; 541 populateWithGenerated(&getContext(), &patterns); 542 patterns.insert< 543 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 544 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 545 TestPassthroughInvalidOp, TestSplitReturnType, 546 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 547 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 548 TestNonRootReplacement, TestBoundedRecursiveRewrite, 549 TestNestedOpCreationUndoRewrite>(&getContext()); 550 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 551 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 552 converter); 553 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 554 converter); 555 556 // Define the conversion target used for the test. 557 ConversionTarget target(getContext()); 558 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 559 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 560 TerminatorOp>(); 561 target 562 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 563 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 564 // Don't allow F32 operands. 565 return llvm::none_of(op.getOperandTypes(), 566 [](Type type) { return type.isF32(); }); 567 }); 568 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 569 return converter.isSignatureLegal(op.getType()) && 570 converter.isLegal(&op.getBody()); 571 }); 572 573 // Expect the type_producer/type_consumer operations to only operate on f64. 574 target.addDynamicallyLegalOp<TestTypeProducerOp>( 575 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 576 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 577 return op.getOperand().getType().isF64(); 578 }); 579 580 // Check support for marking certain operations as recursively legal. 581 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 582 return static_cast<bool>( 583 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 584 }); 585 586 // Mark the bound recursion operation as dynamically legal. 587 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 588 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 589 590 // Handle a partial conversion. 591 if (mode == ConversionMode::Partial) { 592 DenseSet<Operation *> unlegalizedOps; 593 (void)applyPartialConversion(getOperation(), target, patterns, 594 &unlegalizedOps); 595 // Emit remarks for each legalizable operation. 596 for (auto *op : unlegalizedOps) 597 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 598 return; 599 } 600 601 // Handle a full conversion. 602 if (mode == ConversionMode::Full) { 603 // Check support for marking unknown operations as dynamically legal. 604 target.markUnknownOpDynamicallyLegal([](Operation *op) { 605 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 606 }); 607 608 (void)applyFullConversion(getOperation(), target, patterns); 609 return; 610 } 611 612 // Otherwise, handle an analysis conversion. 613 assert(mode == ConversionMode::Analysis); 614 615 // Analyze the convertible operations. 616 DenseSet<Operation *> legalizedOps; 617 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 618 legalizedOps))) 619 return signalPassFailure(); 620 621 // Emit remarks for each legalizable operation. 622 for (auto *op : legalizedOps) 623 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 624 } 625 626 /// The mode of conversion to use. 627 ConversionMode mode; 628 }; 629 } // end anonymous namespace 630 631 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 632 legalizerConversionMode( 633 "test-legalize-mode", 634 llvm::cl::desc("The legalization mode to use with the test driver"), 635 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 636 llvm::cl::values( 637 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 638 "analysis", "Perform an analysis conversion"), 639 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 640 "Perform a full conversion"), 641 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 642 "partial", "Perform a partial conversion"))); 643 644 //===----------------------------------------------------------------------===// 645 // ConversionPatternRewriter::getRemappedValue testing. This method is used 646 // to get the remapped value of an original value that was replaced using 647 // ConversionPatternRewriter. 648 namespace { 649 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 650 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 651 /// operand twice. 652 /// 653 /// Example: 654 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 655 /// is replaced with: 656 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 657 struct OneVResOneVOperandOp1Converter 658 : public OpConversionPattern<OneVResOneVOperandOp1> { 659 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 660 661 LogicalResult 662 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 663 ConversionPatternRewriter &rewriter) const override { 664 auto origOps = op.getOperands(); 665 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 666 "One operand expected"); 667 Value origOp = *origOps.begin(); 668 SmallVector<Value, 2> remappedOperands; 669 // Replicate the remapped original operand twice. Note that we don't used 670 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 671 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 672 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 673 674 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 675 remappedOperands); 676 return success(); 677 } 678 }; 679 680 struct TestRemappedValue 681 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 682 void runOnFunction() override { 683 mlir::OwningRewritePatternList patterns; 684 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 685 686 mlir::ConversionTarget target(getContext()); 687 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 688 // We make OneVResOneVOperandOp1 legal only when it has more that one 689 // operand. This will trigger the conversion that will replace one-operand 690 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 691 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 692 [](Operation *op) -> bool { 693 return std::distance(op->operand_begin(), op->operand_end()) > 1; 694 }); 695 696 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 697 signalPassFailure(); 698 } 699 } 700 }; 701 } // end anonymous namespace 702 703 //===----------------------------------------------------------------------===// 704 // Test patterns without a specific root operation kind 705 //===----------------------------------------------------------------------===// 706 707 namespace { 708 /// This pattern matches and removes any operation in the test dialect. 709 struct RemoveTestDialectOps : public RewritePattern { 710 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 711 712 LogicalResult matchAndRewrite(Operation *op, 713 PatternRewriter &rewriter) const override { 714 if (!isa<TestDialect>(op->getDialect())) 715 return failure(); 716 rewriter.eraseOp(op); 717 return success(); 718 } 719 }; 720 721 struct TestUnknownRootOpDriver 722 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 723 void runOnFunction() override { 724 mlir::OwningRewritePatternList patterns; 725 patterns.insert<RemoveTestDialectOps>(); 726 727 mlir::ConversionTarget target(getContext()); 728 target.addIllegalDialect<TestDialect>(); 729 if (failed(applyPartialConversion(getFunction(), target, patterns))) 730 signalPassFailure(); 731 } 732 }; 733 } // end anonymous namespace 734 735 namespace mlir { 736 void registerPatternsTestPass() { 737 PassRegistration<TestReturnTypeDriver>("test-return-type", 738 "Run return type functions"); 739 740 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 741 "Run test derived attributes"); 742 743 PassRegistration<TestPatternDriver>("test-patterns", 744 "Run test dialect patterns"); 745 746 PassRegistration<TestLegalizePatternDriver>( 747 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 748 return std::make_unique<TestLegalizePatternDriver>( 749 legalizerConversionMode); 750 }); 751 752 PassRegistration<TestRemappedValue>( 753 "test-remapped-value", 754 "Test public remapped value mechanism in ConversionPatternRewriter"); 755 756 PassRegistration<TestUnknownRootOpDriver>( 757 "test-legalize-unknown-root-patterns", 758 "Test public remapped value mechanism in ConversionPatternRewriter"); 759 } 760 } // namespace mlir 761