1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::test; 20 21 // Native function for testing NativeCodeCall 22 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 23 return choice.getValue() ? input1 : input2; 24 } 25 26 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 27 rewriter.create<OpI>(loc, input); 28 } 29 30 static void handleNoResultOp(PatternRewriter &rewriter, 31 OpSymbolBindingNoResult op) { 32 // Turn the no result op to a one-result op. 33 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 34 op.operand()); 35 } 36 37 // Test that natives calls are only called once during rewrites. 38 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 39 // This let us check the number of times OpM_Test was called by inspecting 40 // the returned value in the MLIR output. 41 static int64_t opMIncreasingValue = 314159265; 42 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 43 int64_t i = opMIncreasingValue++; 44 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 45 } 46 47 namespace { 48 #include "TestPatterns.inc" 49 } // end anonymous namespace 50 51 //===----------------------------------------------------------------------===// 52 // Canonicalizer Driver. 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 struct FoldingPattern : public RewritePattern { 57 public: 58 FoldingPattern(MLIRContext *context) 59 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 60 /*benefit=*/1, context) {} 61 62 LogicalResult matchAndRewrite(Operation *op, 63 PatternRewriter &rewriter) const override { 64 // Exercice OperationFolder API for a single-result operation that is folded 65 // upon construction. The operation being created through the folder has an 66 // in-place folder, and it should be still present in the output. 67 // Furthermore, the folder should not crash when attempting to recover the 68 // (unchanged) operation result. 69 OperationFolder folder(op->getContext()); 70 Value result = folder.create<TestOpInPlaceFold>( 71 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 72 rewriter.getI32IntegerAttr(0)); 73 assert(result); 74 rewriter.replaceOp(op, result); 75 return success(); 76 } 77 }; 78 79 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 80 void runOnFunction() override { 81 mlir::OwningRewritePatternList patterns; 82 populateWithGenerated(&getContext(), patterns); 83 84 // Verify named pattern is generated with expected name. 85 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 86 87 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 88 } 89 }; 90 } // end anonymous namespace 91 92 //===----------------------------------------------------------------------===// 93 // ReturnType Driver. 94 //===----------------------------------------------------------------------===// 95 96 namespace { 97 // Generate ops for each instance where the type can be successfully inferred. 98 template <typename OpTy> 99 static void invokeCreateWithInferredReturnType(Operation *op) { 100 auto *context = op->getContext(); 101 auto fop = op->getParentOfType<FuncOp>(); 102 auto location = UnknownLoc::get(context); 103 OpBuilder b(op); 104 b.setInsertionPointAfter(op); 105 106 // Use permutations of 2 args as operands. 107 assert(fop.getNumArguments() >= 2); 108 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 109 for (int j = 0; j < e; ++j) { 110 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 111 SmallVector<Type, 2> inferredReturnTypes; 112 if (succeeded(OpTy::inferReturnTypes( 113 context, llvm::None, values, op->getAttrDictionary(), 114 op->getRegions(), inferredReturnTypes))) { 115 OperationState state(location, OpTy::getOperationName()); 116 // TODO: Expand to regions. 117 OpTy::build(b, state, values, op->getAttrs()); 118 (void)b.createOperation(state); 119 } 120 } 121 } 122 } 123 124 static void reifyReturnShape(Operation *op) { 125 OpBuilder b(op); 126 127 // Use permutations of 2 args as operands. 128 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 129 SmallVector<Value, 2> shapes; 130 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 131 return; 132 for (auto it : llvm::enumerate(shapes)) 133 op->emitRemark() << "value " << it.index() << ": " 134 << it.value().getDefiningOp(); 135 } 136 137 struct TestReturnTypeDriver 138 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 139 void runOnFunction() override { 140 if (getFunction().getName() == "testCreateFunctions") { 141 std::vector<Operation *> ops; 142 // Collect ops to avoid triggering on inserted ops. 143 for (auto &op : getFunction().getBody().front()) 144 ops.push_back(&op); 145 // Generate test patterns for each, but skip terminator. 146 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 147 // Test create method of each of the Op classes below. The resultant 148 // output would be in reverse order underneath `op` from which 149 // the attributes and regions are used. 150 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 151 invokeCreateWithInferredReturnType< 152 OpWithShapedTypeInferTypeInterfaceOp>(op); 153 }; 154 return; 155 } 156 if (getFunction().getName() == "testReifyFunctions") { 157 std::vector<Operation *> ops; 158 // Collect ops to avoid triggering on inserted ops. 159 for (auto &op : getFunction().getBody().front()) 160 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 161 ops.push_back(&op); 162 // Generate test patterns for each, but skip terminator. 163 for (auto *op : ops) 164 reifyReturnShape(op); 165 } 166 } 167 }; 168 } // end anonymous namespace 169 170 namespace { 171 struct TestDerivedAttributeDriver 172 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 173 void runOnFunction() override; 174 }; 175 } // end anonymous namespace 176 177 void TestDerivedAttributeDriver::runOnFunction() { 178 getFunction().walk([](DerivedAttributeOpInterface dOp) { 179 auto dAttr = dOp.materializeDerivedAttributes(); 180 if (!dAttr) 181 return; 182 for (auto d : dAttr) 183 dOp.emitRemark() << d.first << " = " << d.second; 184 }); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // Legalization Driver. 189 //===----------------------------------------------------------------------===// 190 191 namespace { 192 //===----------------------------------------------------------------------===// 193 // Region-Block Rewrite Testing 194 195 /// This pattern is a simple pattern that inlines the first region of a given 196 /// operation into the parent region. 197 struct TestRegionRewriteBlockMovement : public ConversionPattern { 198 TestRegionRewriteBlockMovement(MLIRContext *ctx) 199 : ConversionPattern("test.region", 1, ctx) {} 200 201 LogicalResult 202 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 203 ConversionPatternRewriter &rewriter) const final { 204 // Inline this region into the parent region. 205 auto &parentRegion = *op->getParentRegion(); 206 if (op->getAttr("legalizer.should_clone")) 207 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 208 parentRegion.end()); 209 else 210 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 211 parentRegion.end()); 212 213 // Drop this operation. 214 rewriter.eraseOp(op); 215 return success(); 216 } 217 }; 218 /// This pattern is a simple pattern that generates a region containing an 219 /// illegal operation. 220 struct TestRegionRewriteUndo : public RewritePattern { 221 TestRegionRewriteUndo(MLIRContext *ctx) 222 : RewritePattern("test.region_builder", 1, ctx) {} 223 224 LogicalResult matchAndRewrite(Operation *op, 225 PatternRewriter &rewriter) const final { 226 // Create the region operation with an entry block containing arguments. 227 OperationState newRegion(op->getLoc(), "test.region"); 228 newRegion.addRegion(); 229 auto *regionOp = rewriter.createOperation(newRegion); 230 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 231 entryBlock->addArgument(rewriter.getIntegerType(64)); 232 233 // Add an explicitly illegal operation to ensure the conversion fails. 234 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 235 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 236 237 // Drop this operation. 238 rewriter.eraseOp(op); 239 return success(); 240 } 241 }; 242 /// A simple pattern that creates a block at the end of the parent region of the 243 /// matched operation. 244 struct TestCreateBlock : public RewritePattern { 245 TestCreateBlock(MLIRContext *ctx) 246 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 247 248 LogicalResult matchAndRewrite(Operation *op, 249 PatternRewriter &rewriter) const final { 250 Region ®ion = *op->getParentRegion(); 251 Type i32Type = rewriter.getIntegerType(32); 252 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 253 rewriter.create<TerminatorOp>(op->getLoc()); 254 rewriter.replaceOp(op, {}); 255 return success(); 256 } 257 }; 258 259 /// A simple pattern that creates a block containing an invalid operation in 260 /// order to trigger the block creation undo mechanism. 261 struct TestCreateIllegalBlock : public RewritePattern { 262 TestCreateIllegalBlock(MLIRContext *ctx) 263 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 264 265 LogicalResult matchAndRewrite(Operation *op, 266 PatternRewriter &rewriter) const final { 267 Region ®ion = *op->getParentRegion(); 268 Type i32Type = rewriter.getIntegerType(32); 269 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 270 // Create an illegal op to ensure the conversion fails. 271 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 272 rewriter.create<TerminatorOp>(op->getLoc()); 273 rewriter.replaceOp(op, {}); 274 return success(); 275 } 276 }; 277 278 /// A simple pattern that tests the undo mechanism when replacing the uses of a 279 /// block argument. 280 struct TestUndoBlockArgReplace : public ConversionPattern { 281 TestUndoBlockArgReplace(MLIRContext *ctx) 282 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 283 284 LogicalResult 285 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 286 ConversionPatternRewriter &rewriter) const final { 287 auto illegalOp = 288 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 289 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 290 illegalOp); 291 rewriter.updateRootInPlace(op, [] {}); 292 return success(); 293 } 294 }; 295 296 /// A rewrite pattern that tests the undo mechanism when erasing a block. 297 struct TestUndoBlockErase : public ConversionPattern { 298 TestUndoBlockErase(MLIRContext *ctx) 299 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 300 301 LogicalResult 302 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 303 ConversionPatternRewriter &rewriter) const final { 304 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 305 rewriter.setInsertionPointToStart(secondBlock); 306 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 307 rewriter.eraseBlock(secondBlock); 308 rewriter.updateRootInPlace(op, [] {}); 309 return success(); 310 } 311 }; 312 313 //===----------------------------------------------------------------------===// 314 // Type-Conversion Rewrite Testing 315 316 /// This patterns erases a region operation that has had a type conversion. 317 struct TestDropOpSignatureConversion : public ConversionPattern { 318 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 319 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 320 LogicalResult 321 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 322 ConversionPatternRewriter &rewriter) const override { 323 Region ®ion = op->getRegion(0); 324 Block *entry = ®ion.front(); 325 326 // Convert the original entry arguments. 327 TypeConverter &converter = *getTypeConverter(); 328 TypeConverter::SignatureConversion result(entry->getNumArguments()); 329 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 330 result)) || 331 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 332 return failure(); 333 334 // Convert the region signature and just drop the operation. 335 rewriter.eraseOp(op); 336 return success(); 337 } 338 }; 339 /// This pattern simply updates the operands of the given operation. 340 struct TestPassthroughInvalidOp : public ConversionPattern { 341 TestPassthroughInvalidOp(MLIRContext *ctx) 342 : ConversionPattern("test.invalid", 1, ctx) {} 343 LogicalResult 344 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const final { 346 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 347 llvm::None); 348 return success(); 349 } 350 }; 351 /// This pattern handles the case of a split return value. 352 struct TestSplitReturnType : public ConversionPattern { 353 TestSplitReturnType(MLIRContext *ctx) 354 : ConversionPattern("test.return", 1, ctx) {} 355 LogicalResult 356 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 357 ConversionPatternRewriter &rewriter) const final { 358 // Check for a return of F32. 359 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 360 return failure(); 361 362 // Check if the first operation is a cast operation, if it is we use the 363 // results directly. 364 auto *defOp = operands[0].getDefiningOp(); 365 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 366 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 367 return success(); 368 } 369 370 // Otherwise, fail to match. 371 return failure(); 372 } 373 }; 374 375 //===----------------------------------------------------------------------===// 376 // Multi-Level Type-Conversion Rewrite Testing 377 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 378 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 379 : ConversionPattern("test.type_producer", 1, ctx) {} 380 LogicalResult 381 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 382 ConversionPatternRewriter &rewriter) const final { 383 // If the type is I32, change the type to F32. 384 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 385 return failure(); 386 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 387 return success(); 388 } 389 }; 390 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 391 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 392 : ConversionPattern("test.type_producer", 1, ctx) {} 393 LogicalResult 394 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 395 ConversionPatternRewriter &rewriter) const final { 396 // If the type is F32, change the type to F64. 397 if (!Type(*op->result_type_begin()).isF32()) 398 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 399 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 400 return success(); 401 } 402 }; 403 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 404 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 405 : ConversionPattern("test.type_producer", 10, ctx) {} 406 LogicalResult 407 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 408 ConversionPatternRewriter &rewriter) const final { 409 // Always convert to B16, even though it is not a legal type. This tests 410 // that values are unmapped correctly. 411 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 412 return success(); 413 } 414 }; 415 struct TestUpdateConsumerType : public ConversionPattern { 416 TestUpdateConsumerType(MLIRContext *ctx) 417 : ConversionPattern("test.type_consumer", 1, ctx) {} 418 LogicalResult 419 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 420 ConversionPatternRewriter &rewriter) const final { 421 // Verify that the incoming operand has been successfully remapped to F64. 422 if (!operands[0].getType().isF64()) 423 return failure(); 424 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 425 return success(); 426 } 427 }; 428 429 //===----------------------------------------------------------------------===// 430 // Non-Root Replacement Rewrite Testing 431 /// This pattern generates an invalid operation, but replaces it before the 432 /// pattern is finished. This checks that we don't need to legalize the 433 /// temporary op. 434 struct TestNonRootReplacement : public RewritePattern { 435 TestNonRootReplacement(MLIRContext *ctx) 436 : RewritePattern("test.replace_non_root", 1, ctx) {} 437 438 LogicalResult matchAndRewrite(Operation *op, 439 PatternRewriter &rewriter) const final { 440 auto resultType = *op->result_type_begin(); 441 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 442 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 443 444 rewriter.replaceOp(illegalOp, {legalOp}); 445 rewriter.replaceOp(op, {illegalOp}); 446 return success(); 447 } 448 }; 449 450 //===----------------------------------------------------------------------===// 451 // Recursive Rewrite Testing 452 /// This pattern is applied to the same operation multiple times, but has a 453 /// bounded recursion. 454 struct TestBoundedRecursiveRewrite 455 : public OpRewritePattern<TestRecursiveRewriteOp> { 456 TestBoundedRecursiveRewrite(MLIRContext *ctx) 457 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 458 // The conversion target handles bounding the recursion of this pattern. 459 setHasBoundedRewriteRecursion(); 460 } 461 462 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 463 PatternRewriter &rewriter) const final { 464 // Decrement the depth of the op in-place. 465 rewriter.updateRootInPlace(op, [&] { 466 op.setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 467 }); 468 return success(); 469 } 470 }; 471 472 struct TestNestedOpCreationUndoRewrite 473 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 474 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 475 476 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 477 PatternRewriter &rewriter) const final { 478 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 479 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 480 return success(); 481 }; 482 }; 483 } // namespace 484 485 namespace { 486 struct TestTypeConverter : public TypeConverter { 487 using TypeConverter::TypeConverter; 488 TestTypeConverter() { 489 addConversion(convertType); 490 addArgumentMaterialization(materializeCast); 491 addArgumentMaterialization(materializeOneToOneCast); 492 addSourceMaterialization(materializeCast); 493 } 494 495 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 496 // Drop I16 types. 497 if (t.isSignlessInteger(16)) 498 return success(); 499 500 // Convert I64 to F64. 501 if (t.isSignlessInteger(64)) { 502 results.push_back(FloatType::getF64(t.getContext())); 503 return success(); 504 } 505 506 // Convert I42 to I43. 507 if (t.isInteger(42)) { 508 results.push_back(IntegerType::get(43, t.getContext())); 509 return success(); 510 } 511 512 // Split F32 into F16,F16. 513 if (t.isF32()) { 514 results.assign(2, FloatType::getF16(t.getContext())); 515 return success(); 516 } 517 518 // Otherwise, convert the type directly. 519 results.push_back(t); 520 return success(); 521 } 522 523 /// Hook for materializing a conversion. This is necessary because we generate 524 /// 1->N type mappings. 525 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 526 ValueRange inputs, Location loc) { 527 if (inputs.size() == 1) 528 return inputs[0]; 529 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 530 } 531 532 /// Materialize the cast for one-to-one conversion from i64 to f64. 533 static Optional<Value> materializeOneToOneCast(OpBuilder &builder, 534 IntegerType resultType, 535 ValueRange inputs, 536 Location loc) { 537 if (resultType.getWidth() == 42 && inputs.size() == 1) 538 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 539 return llvm::None; 540 } 541 }; 542 543 struct TestLegalizePatternDriver 544 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 545 /// The mode of conversion to use with the driver. 546 enum class ConversionMode { Analysis, Full, Partial }; 547 548 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 549 550 void runOnOperation() override { 551 TestTypeConverter converter; 552 mlir::OwningRewritePatternList patterns; 553 populateWithGenerated(&getContext(), patterns); 554 patterns.insert< 555 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 556 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 557 TestPassthroughInvalidOp, TestSplitReturnType, 558 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 559 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 560 TestNonRootReplacement, TestBoundedRecursiveRewrite, 561 TestNestedOpCreationUndoRewrite>(&getContext()); 562 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 563 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 564 converter); 565 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 566 converter); 567 568 // Define the conversion target used for the test. 569 ConversionTarget target(getContext()); 570 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 571 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 572 TerminatorOp>(); 573 target 574 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 575 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 576 // Don't allow F32 operands. 577 return llvm::none_of(op.getOperandTypes(), 578 [](Type type) { return type.isF32(); }); 579 }); 580 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 581 return converter.isSignatureLegal(op.getType()) && 582 converter.isLegal(&op.getBody()); 583 }); 584 585 // Expect the type_producer/type_consumer operations to only operate on f64. 586 target.addDynamicallyLegalOp<TestTypeProducerOp>( 587 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 588 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 589 return op.getOperand().getType().isF64(); 590 }); 591 592 // Check support for marking certain operations as recursively legal. 593 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 594 return static_cast<bool>( 595 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 596 }); 597 598 // Mark the bound recursion operation as dynamically legal. 599 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 600 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 601 602 // Handle a partial conversion. 603 if (mode == ConversionMode::Partial) { 604 DenseSet<Operation *> unlegalizedOps; 605 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 606 &unlegalizedOps); 607 // Emit remarks for each legalizable operation. 608 for (auto *op : unlegalizedOps) 609 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 610 return; 611 } 612 613 // Handle a full conversion. 614 if (mode == ConversionMode::Full) { 615 // Check support for marking unknown operations as dynamically legal. 616 target.markUnknownOpDynamicallyLegal([](Operation *op) { 617 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 618 }); 619 620 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 621 return; 622 } 623 624 // Otherwise, handle an analysis conversion. 625 assert(mode == ConversionMode::Analysis); 626 627 // Analyze the convertible operations. 628 DenseSet<Operation *> legalizedOps; 629 if (failed(applyAnalysisConversion(getOperation(), target, 630 std::move(patterns), legalizedOps))) 631 return signalPassFailure(); 632 633 // Emit remarks for each legalizable operation. 634 for (auto *op : legalizedOps) 635 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 636 } 637 638 /// The mode of conversion to use. 639 ConversionMode mode; 640 }; 641 } // end anonymous namespace 642 643 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 644 legalizerConversionMode( 645 "test-legalize-mode", 646 llvm::cl::desc("The legalization mode to use with the test driver"), 647 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 648 llvm::cl::values( 649 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 650 "analysis", "Perform an analysis conversion"), 651 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 652 "Perform a full conversion"), 653 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 654 "partial", "Perform a partial conversion"))); 655 656 //===----------------------------------------------------------------------===// 657 // ConversionPatternRewriter::getRemappedValue testing. This method is used 658 // to get the remapped value of an original value that was replaced using 659 // ConversionPatternRewriter. 660 namespace { 661 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 662 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 663 /// operand twice. 664 /// 665 /// Example: 666 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 667 /// is replaced with: 668 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 669 struct OneVResOneVOperandOp1Converter 670 : public OpConversionPattern<OneVResOneVOperandOp1> { 671 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 672 673 LogicalResult 674 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 675 ConversionPatternRewriter &rewriter) const override { 676 auto origOps = op.getOperands(); 677 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 678 "One operand expected"); 679 Value origOp = *origOps.begin(); 680 SmallVector<Value, 2> remappedOperands; 681 // Replicate the remapped original operand twice. Note that we don't used 682 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 683 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 684 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 685 686 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 687 remappedOperands); 688 return success(); 689 } 690 }; 691 692 struct TestRemappedValue 693 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 694 void runOnFunction() override { 695 mlir::OwningRewritePatternList patterns; 696 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 697 698 mlir::ConversionTarget target(getContext()); 699 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 700 // We make OneVResOneVOperandOp1 legal only when it has more that one 701 // operand. This will trigger the conversion that will replace one-operand 702 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 703 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 704 [](Operation *op) -> bool { 705 return std::distance(op->operand_begin(), op->operand_end()) > 1; 706 }); 707 708 if (failed(mlir::applyFullConversion(getFunction(), target, 709 std::move(patterns)))) { 710 signalPassFailure(); 711 } 712 } 713 }; 714 } // end anonymous namespace 715 716 //===----------------------------------------------------------------------===// 717 // Test patterns without a specific root operation kind 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 /// This pattern matches and removes any operation in the test dialect. 722 struct RemoveTestDialectOps : public RewritePattern { 723 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 724 725 LogicalResult matchAndRewrite(Operation *op, 726 PatternRewriter &rewriter) const override { 727 if (!isa<TestDialect>(op->getDialect())) 728 return failure(); 729 rewriter.eraseOp(op); 730 return success(); 731 } 732 }; 733 734 struct TestUnknownRootOpDriver 735 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 736 void runOnFunction() override { 737 mlir::OwningRewritePatternList patterns; 738 patterns.insert<RemoveTestDialectOps>(); 739 740 mlir::ConversionTarget target(getContext()); 741 target.addIllegalDialect<TestDialect>(); 742 if (failed( 743 applyPartialConversion(getFunction(), target, std::move(patterns)))) 744 signalPassFailure(); 745 } 746 }; 747 } // end anonymous namespace 748 749 //===----------------------------------------------------------------------===// 750 // Test type conversions 751 //===----------------------------------------------------------------------===// 752 753 namespace { 754 struct TestTypeConversionProducer 755 : public OpConversionPattern<TestTypeProducerOp> { 756 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 757 LogicalResult 758 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 759 ConversionPatternRewriter &rewriter) const final { 760 Type resultType = op.getType(); 761 if (resultType.isa<FloatType>()) 762 resultType = rewriter.getF64Type(); 763 else if (resultType.isInteger(16)) 764 resultType = rewriter.getIntegerType(64); 765 else 766 return failure(); 767 768 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 769 return success(); 770 } 771 }; 772 773 struct TestTypeConversionDriver 774 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 775 void getDependentDialects(DialectRegistry ®istry) const override { 776 registry.insert<TestDialect>(); 777 } 778 779 void runOnOperation() override { 780 // Initialize the type converter. 781 TypeConverter converter; 782 783 /// Add the legal set of type conversions. 784 converter.addConversion([](Type type) -> Type { 785 // Treat F64 as legal. 786 if (type.isF64()) 787 return type; 788 // Allow converting BF16/F16/F32 to F64. 789 if (type.isBF16() || type.isF16() || type.isF32()) 790 return FloatType::getF64(type.getContext()); 791 // Otherwise, the type is illegal. 792 return nullptr; 793 }); 794 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 795 // Drop all integer types. 796 return success(); 797 }); 798 799 /// Add the legal set of type materializations. 800 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 801 ValueRange inputs, 802 Location loc) -> Value { 803 // Allow casting from F64 back to F32. 804 if (!resultType.isF16() && inputs.size() == 1 && 805 inputs[0].getType().isF64()) 806 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 807 // Allow producing an i32 or i64 from nothing. 808 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 809 inputs.empty()) 810 return builder.create<TestTypeProducerOp>(loc, resultType); 811 // Allow producing an i64 from an integer. 812 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 813 inputs[0].getType().isa<IntegerType>()) 814 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 815 // Otherwise, fail. 816 return nullptr; 817 }); 818 819 // Initialize the conversion target. 820 mlir::ConversionTarget target(getContext()); 821 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 822 return op.getType().isF64() || op.getType().isInteger(64); 823 }); 824 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 825 return converter.isSignatureLegal(op.getType()) && 826 converter.isLegal(&op.getBody()); 827 }); 828 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 829 // Allow casts from F64 to F32. 830 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 831 }); 832 833 // Initialize the set of rewrite patterns. 834 OwningRewritePatternList patterns; 835 patterns.insert<TestTypeConversionProducer>(converter, &getContext()); 836 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 837 converter); 838 839 if (failed(applyPartialConversion(getOperation(), target, 840 std::move(patterns)))) 841 signalPassFailure(); 842 } 843 }; 844 } // end anonymous namespace 845 846 namespace { 847 /// A rewriter pattern that tests that blocks can be merged. 848 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 849 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 850 851 LogicalResult 852 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 853 ConversionPatternRewriter &rewriter) const final { 854 Block &firstBlock = op.body().front(); 855 Operation *branchOp = firstBlock.getTerminator(); 856 Block *secondBlock = &*(std::next(op.body().begin())); 857 auto succOperands = branchOp->getOperands(); 858 SmallVector<Value, 2> replacements(succOperands); 859 rewriter.eraseOp(branchOp); 860 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 861 rewriter.updateRootInPlace(op, [] {}); 862 return success(); 863 } 864 }; 865 866 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 867 struct TestUndoBlocksMerge : public ConversionPattern { 868 TestUndoBlocksMerge(MLIRContext *ctx) 869 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 870 LogicalResult 871 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 872 ConversionPatternRewriter &rewriter) const final { 873 Block &firstBlock = op->getRegion(0).front(); 874 Operation *branchOp = firstBlock.getTerminator(); 875 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 876 rewriter.setInsertionPointToStart(secondBlock); 877 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 878 auto succOperands = branchOp->getOperands(); 879 SmallVector<Value, 2> replacements(succOperands); 880 rewriter.eraseOp(branchOp); 881 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 882 rewriter.updateRootInPlace(op, [] {}); 883 return success(); 884 } 885 }; 886 887 /// A rewrite mechanism to inline the body of the op into its parent, when both 888 /// ops can have a single block. 889 struct TestMergeSingleBlockOps 890 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 891 using OpConversionPattern< 892 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 893 894 LogicalResult 895 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 896 ConversionPatternRewriter &rewriter) const final { 897 SingleBlockImplicitTerminatorOp parentOp = 898 op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 899 if (!parentOp) 900 return failure(); 901 Block &innerBlock = op.region().front(); 902 TerminatorOp innerTerminator = 903 cast<TerminatorOp>(innerBlock.getTerminator()); 904 rewriter.mergeBlockBefore(&innerBlock, op); 905 rewriter.eraseOp(innerTerminator); 906 rewriter.eraseOp(op); 907 rewriter.updateRootInPlace(op, [] {}); 908 return success(); 909 } 910 }; 911 912 struct TestMergeBlocksPatternDriver 913 : public PassWrapper<TestMergeBlocksPatternDriver, 914 OperationPass<ModuleOp>> { 915 void runOnOperation() override { 916 mlir::OwningRewritePatternList patterns; 917 MLIRContext *context = &getContext(); 918 patterns 919 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 920 context); 921 ConversionTarget target(*context); 922 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 923 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 924 TestReturnOp>(); 925 target.addIllegalOp<ILLegalOpF>(); 926 927 /// Expect the op to have a single block after legalization. 928 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 929 [&](TestMergeBlocksOp op) -> bool { 930 return llvm::hasSingleElement(op.body()); 931 }); 932 933 /// Only allow `test.br` within test.merge_blocks op. 934 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 935 return op.getParentOfType<TestMergeBlocksOp>(); 936 }); 937 938 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 939 /// inlined. 940 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 941 [&](SingleBlockImplicitTerminatorOp op) -> bool { 942 return !op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 943 }); 944 945 DenseSet<Operation *> unlegalizedOps; 946 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 947 &unlegalizedOps); 948 for (auto *op : unlegalizedOps) 949 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 950 } 951 }; 952 } // namespace 953 954 //===----------------------------------------------------------------------===// 955 // PassRegistration 956 //===----------------------------------------------------------------------===// 957 958 namespace mlir { 959 namespace test { 960 void registerPatternsTestPass() { 961 PassRegistration<TestReturnTypeDriver>("test-return-type", 962 "Run return type functions"); 963 964 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 965 "Run test derived attributes"); 966 967 PassRegistration<TestPatternDriver>("test-patterns", 968 "Run test dialect patterns"); 969 970 PassRegistration<TestLegalizePatternDriver>( 971 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 972 return std::make_unique<TestLegalizePatternDriver>( 973 legalizerConversionMode); 974 }); 975 976 PassRegistration<TestRemappedValue>( 977 "test-remapped-value", 978 "Test public remapped value mechanism in ConversionPatternRewriter"); 979 980 PassRegistration<TestUnknownRootOpDriver>( 981 "test-legalize-unknown-root-patterns", 982 "Test public remapped value mechanism in ConversionPatternRewriter"); 983 984 PassRegistration<TestTypeConversionDriver>( 985 "test-legalize-type-conversion", 986 "Test various type conversion functionalities in DialectConversion"); 987 988 PassRegistration<TestMergeBlocksPatternDriver>{ 989 "test-merge-blocks", 990 "Test Merging operation in ConversionPatternRewriter"}; 991 } 992 } // namespace test 993 } // namespace mlir 994