1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 20 // Native function for testing NativeCodeCall 21 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 22 return choice.getValue() ? input1 : input2; 23 } 24 25 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 26 rewriter.create<OpI>(loc, input); 27 } 28 29 static void handleNoResultOp(PatternRewriter &rewriter, 30 OpSymbolBindingNoResult op) { 31 // Turn the no result op to a one-result op. 32 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 33 op.operand()); 34 } 35 36 // Test that natives calls are only called once during rewrites. 37 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 38 // This let us check the number of times OpM_Test was called by inspecting 39 // the returned value in the MLIR output. 40 static int64_t opMIncreasingValue = 314159265; 41 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 42 int64_t i = opMIncreasingValue++; 43 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 44 } 45 46 namespace { 47 #include "TestPatterns.inc" 48 } // end anonymous namespace 49 50 //===----------------------------------------------------------------------===// 51 // Canonicalizer Driver. 52 //===----------------------------------------------------------------------===// 53 54 namespace { 55 struct FoldingPattern : public RewritePattern { 56 public: 57 FoldingPattern(MLIRContext *context) 58 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 59 /*benefit=*/1, context) {} 60 61 LogicalResult matchAndRewrite(Operation *op, 62 PatternRewriter &rewriter) const override { 63 // Exercice OperationFolder API for a single-result operation that is folded 64 // upon construction. The operation being created through the folder has an 65 // in-place folder, and it should be still present in the output. 66 // Furthermore, the folder should not crash when attempting to recover the 67 // (unchanged) operation result. 68 OperationFolder folder(op->getContext()); 69 Value result = folder.create<TestOpInPlaceFold>( 70 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 71 rewriter.getI32IntegerAttr(0)); 72 assert(result); 73 rewriter.replaceOp(op, result); 74 return success(); 75 } 76 }; 77 78 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 79 void runOnFunction() override { 80 mlir::OwningRewritePatternList patterns; 81 populateWithGenerated(&getContext(), patterns); 82 83 // Verify named pattern is generated with expected name. 84 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 85 86 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 87 } 88 }; 89 } // end anonymous namespace 90 91 //===----------------------------------------------------------------------===// 92 // ReturnType Driver. 93 //===----------------------------------------------------------------------===// 94 95 namespace { 96 // Generate ops for each instance where the type can be successfully inferred. 97 template <typename OpTy> 98 static void invokeCreateWithInferredReturnType(Operation *op) { 99 auto *context = op->getContext(); 100 auto fop = op->getParentOfType<FuncOp>(); 101 auto location = UnknownLoc::get(context); 102 OpBuilder b(op); 103 b.setInsertionPointAfter(op); 104 105 // Use permutations of 2 args as operands. 106 assert(fop.getNumArguments() >= 2); 107 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 108 for (int j = 0; j < e; ++j) { 109 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 110 SmallVector<Type, 2> inferredReturnTypes; 111 if (succeeded(OpTy::inferReturnTypes( 112 context, llvm::None, values, op->getAttrDictionary(), 113 op->getRegions(), inferredReturnTypes))) { 114 OperationState state(location, OpTy::getOperationName()); 115 // TODO: Expand to regions. 116 OpTy::build(b, state, values, op->getAttrs()); 117 (void)b.createOperation(state); 118 } 119 } 120 } 121 } 122 123 static void reifyReturnShape(Operation *op) { 124 OpBuilder b(op); 125 126 // Use permutations of 2 args as operands. 127 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 128 SmallVector<Value, 2> shapes; 129 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 130 return; 131 for (auto it : llvm::enumerate(shapes)) 132 op->emitRemark() << "value " << it.index() << ": " 133 << it.value().getDefiningOp(); 134 } 135 136 struct TestReturnTypeDriver 137 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 138 void runOnFunction() override { 139 if (getFunction().getName() == "testCreateFunctions") { 140 std::vector<Operation *> ops; 141 // Collect ops to avoid triggering on inserted ops. 142 for (auto &op : getFunction().getBody().front()) 143 ops.push_back(&op); 144 // Generate test patterns for each, but skip terminator. 145 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 146 // Test create method of each of the Op classes below. The resultant 147 // output would be in reverse order underneath `op` from which 148 // the attributes and regions are used. 149 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 150 invokeCreateWithInferredReturnType< 151 OpWithShapedTypeInferTypeInterfaceOp>(op); 152 }; 153 return; 154 } 155 if (getFunction().getName() == "testReifyFunctions") { 156 std::vector<Operation *> ops; 157 // Collect ops to avoid triggering on inserted ops. 158 for (auto &op : getFunction().getBody().front()) 159 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 160 ops.push_back(&op); 161 // Generate test patterns for each, but skip terminator. 162 for (auto *op : ops) 163 reifyReturnShape(op); 164 } 165 } 166 }; 167 } // end anonymous namespace 168 169 namespace { 170 struct TestDerivedAttributeDriver 171 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 172 void runOnFunction() override; 173 }; 174 } // end anonymous namespace 175 176 void TestDerivedAttributeDriver::runOnFunction() { 177 getFunction().walk([](DerivedAttributeOpInterface dOp) { 178 auto dAttr = dOp.materializeDerivedAttributes(); 179 if (!dAttr) 180 return; 181 for (auto d : dAttr) 182 dOp.emitRemark() << d.first << " = " << d.second; 183 }); 184 } 185 186 //===----------------------------------------------------------------------===// 187 // Legalization Driver. 188 //===----------------------------------------------------------------------===// 189 190 namespace { 191 //===----------------------------------------------------------------------===// 192 // Region-Block Rewrite Testing 193 194 /// This pattern is a simple pattern that inlines the first region of a given 195 /// operation into the parent region. 196 struct TestRegionRewriteBlockMovement : public ConversionPattern { 197 TestRegionRewriteBlockMovement(MLIRContext *ctx) 198 : ConversionPattern("test.region", 1, ctx) {} 199 200 LogicalResult 201 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 202 ConversionPatternRewriter &rewriter) const final { 203 // Inline this region into the parent region. 204 auto &parentRegion = *op->getParentRegion(); 205 if (op->getAttr("legalizer.should_clone")) 206 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 207 parentRegion.end()); 208 else 209 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 210 parentRegion.end()); 211 212 // Drop this operation. 213 rewriter.eraseOp(op); 214 return success(); 215 } 216 }; 217 /// This pattern is a simple pattern that generates a region containing an 218 /// illegal operation. 219 struct TestRegionRewriteUndo : public RewritePattern { 220 TestRegionRewriteUndo(MLIRContext *ctx) 221 : RewritePattern("test.region_builder", 1, ctx) {} 222 223 LogicalResult matchAndRewrite(Operation *op, 224 PatternRewriter &rewriter) const final { 225 // Create the region operation with an entry block containing arguments. 226 OperationState newRegion(op->getLoc(), "test.region"); 227 newRegion.addRegion(); 228 auto *regionOp = rewriter.createOperation(newRegion); 229 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 230 entryBlock->addArgument(rewriter.getIntegerType(64)); 231 232 // Add an explicitly illegal operation to ensure the conversion fails. 233 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 234 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 235 236 // Drop this operation. 237 rewriter.eraseOp(op); 238 return success(); 239 } 240 }; 241 /// A simple pattern that creates a block at the end of the parent region of the 242 /// matched operation. 243 struct TestCreateBlock : public RewritePattern { 244 TestCreateBlock(MLIRContext *ctx) 245 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 246 247 LogicalResult matchAndRewrite(Operation *op, 248 PatternRewriter &rewriter) const final { 249 Region ®ion = *op->getParentRegion(); 250 Type i32Type = rewriter.getIntegerType(32); 251 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 252 rewriter.create<TerminatorOp>(op->getLoc()); 253 rewriter.replaceOp(op, {}); 254 return success(); 255 } 256 }; 257 258 /// A simple pattern that creates a block containing an invalid operation in 259 /// order to trigger the block creation undo mechanism. 260 struct TestCreateIllegalBlock : public RewritePattern { 261 TestCreateIllegalBlock(MLIRContext *ctx) 262 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 263 264 LogicalResult matchAndRewrite(Operation *op, 265 PatternRewriter &rewriter) const final { 266 Region ®ion = *op->getParentRegion(); 267 Type i32Type = rewriter.getIntegerType(32); 268 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 269 // Create an illegal op to ensure the conversion fails. 270 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 271 rewriter.create<TerminatorOp>(op->getLoc()); 272 rewriter.replaceOp(op, {}); 273 return success(); 274 } 275 }; 276 277 /// A simple pattern that tests the undo mechanism when replacing the uses of a 278 /// block argument. 279 struct TestUndoBlockArgReplace : public ConversionPattern { 280 TestUndoBlockArgReplace(MLIRContext *ctx) 281 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 282 283 LogicalResult 284 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 285 ConversionPatternRewriter &rewriter) const final { 286 auto illegalOp = 287 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 288 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 289 illegalOp); 290 rewriter.updateRootInPlace(op, [] {}); 291 return success(); 292 } 293 }; 294 295 /// A rewrite pattern that tests the undo mechanism when erasing a block. 296 struct TestUndoBlockErase : public ConversionPattern { 297 TestUndoBlockErase(MLIRContext *ctx) 298 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 299 300 LogicalResult 301 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 302 ConversionPatternRewriter &rewriter) const final { 303 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 304 rewriter.setInsertionPointToStart(secondBlock); 305 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 306 rewriter.eraseBlock(secondBlock); 307 rewriter.updateRootInPlace(op, [] {}); 308 return success(); 309 } 310 }; 311 312 //===----------------------------------------------------------------------===// 313 // Type-Conversion Rewrite Testing 314 315 /// This patterns erases a region operation that has had a type conversion. 316 struct TestDropOpSignatureConversion : public ConversionPattern { 317 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 318 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 319 LogicalResult 320 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 321 ConversionPatternRewriter &rewriter) const override { 322 Region ®ion = op->getRegion(0); 323 Block *entry = ®ion.front(); 324 325 // Convert the original entry arguments. 326 TypeConverter &converter = *getTypeConverter(); 327 TypeConverter::SignatureConversion result(entry->getNumArguments()); 328 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 329 result)) || 330 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 331 return failure(); 332 333 // Convert the region signature and just drop the operation. 334 rewriter.eraseOp(op); 335 return success(); 336 } 337 }; 338 /// This pattern simply updates the operands of the given operation. 339 struct TestPassthroughInvalidOp : public ConversionPattern { 340 TestPassthroughInvalidOp(MLIRContext *ctx) 341 : ConversionPattern("test.invalid", 1, ctx) {} 342 LogicalResult 343 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 344 ConversionPatternRewriter &rewriter) const final { 345 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 346 llvm::None); 347 return success(); 348 } 349 }; 350 /// This pattern handles the case of a split return value. 351 struct TestSplitReturnType : public ConversionPattern { 352 TestSplitReturnType(MLIRContext *ctx) 353 : ConversionPattern("test.return", 1, ctx) {} 354 LogicalResult 355 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const final { 357 // Check for a return of F32. 358 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 359 return failure(); 360 361 // Check if the first operation is a cast operation, if it is we use the 362 // results directly. 363 auto *defOp = operands[0].getDefiningOp(); 364 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 365 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 366 return success(); 367 } 368 369 // Otherwise, fail to match. 370 return failure(); 371 } 372 }; 373 374 //===----------------------------------------------------------------------===// 375 // Multi-Level Type-Conversion Rewrite Testing 376 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 377 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 378 : ConversionPattern("test.type_producer", 1, ctx) {} 379 LogicalResult 380 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 381 ConversionPatternRewriter &rewriter) const final { 382 // If the type is I32, change the type to F32. 383 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 384 return failure(); 385 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 386 return success(); 387 } 388 }; 389 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 390 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 391 : ConversionPattern("test.type_producer", 1, ctx) {} 392 LogicalResult 393 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 394 ConversionPatternRewriter &rewriter) const final { 395 // If the type is F32, change the type to F64. 396 if (!Type(*op->result_type_begin()).isF32()) 397 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 398 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 399 return success(); 400 } 401 }; 402 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 403 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 404 : ConversionPattern("test.type_producer", 10, ctx) {} 405 LogicalResult 406 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 407 ConversionPatternRewriter &rewriter) const final { 408 // Always convert to B16, even though it is not a legal type. This tests 409 // that values are unmapped correctly. 410 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 411 return success(); 412 } 413 }; 414 struct TestUpdateConsumerType : public ConversionPattern { 415 TestUpdateConsumerType(MLIRContext *ctx) 416 : ConversionPattern("test.type_consumer", 1, ctx) {} 417 LogicalResult 418 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 419 ConversionPatternRewriter &rewriter) const final { 420 // Verify that the incoming operand has been successfully remapped to F64. 421 if (!operands[0].getType().isF64()) 422 return failure(); 423 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 424 return success(); 425 } 426 }; 427 428 //===----------------------------------------------------------------------===// 429 // Non-Root Replacement Rewrite Testing 430 /// This pattern generates an invalid operation, but replaces it before the 431 /// pattern is finished. This checks that we don't need to legalize the 432 /// temporary op. 433 struct TestNonRootReplacement : public RewritePattern { 434 TestNonRootReplacement(MLIRContext *ctx) 435 : RewritePattern("test.replace_non_root", 1, ctx) {} 436 437 LogicalResult matchAndRewrite(Operation *op, 438 PatternRewriter &rewriter) const final { 439 auto resultType = *op->result_type_begin(); 440 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 441 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 442 443 rewriter.replaceOp(illegalOp, {legalOp}); 444 rewriter.replaceOp(op, {illegalOp}); 445 return success(); 446 } 447 }; 448 449 //===----------------------------------------------------------------------===// 450 // Recursive Rewrite Testing 451 /// This pattern is applied to the same operation multiple times, but has a 452 /// bounded recursion. 453 struct TestBoundedRecursiveRewrite 454 : public OpRewritePattern<TestRecursiveRewriteOp> { 455 TestBoundedRecursiveRewrite(MLIRContext *ctx) 456 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) { 457 // The conversion target handles bounding the recursion of this pattern. 458 setHasBoundedRewriteRecursion(); 459 } 460 461 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 462 PatternRewriter &rewriter) const final { 463 // Decrement the depth of the op in-place. 464 rewriter.updateRootInPlace(op, [&] { 465 op.setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 466 }); 467 return success(); 468 } 469 }; 470 471 struct TestNestedOpCreationUndoRewrite 472 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 473 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 474 475 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 476 PatternRewriter &rewriter) const final { 477 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 478 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 479 return success(); 480 }; 481 }; 482 } // namespace 483 484 namespace { 485 struct TestTypeConverter : public TypeConverter { 486 using TypeConverter::TypeConverter; 487 TestTypeConverter() { 488 addConversion(convertType); 489 addArgumentMaterialization(materializeCast); 490 addArgumentMaterialization(materializeOneToOneCast); 491 addSourceMaterialization(materializeCast); 492 } 493 494 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 495 // Drop I16 types. 496 if (t.isSignlessInteger(16)) 497 return success(); 498 499 // Convert I64 to F64. 500 if (t.isSignlessInteger(64)) { 501 results.push_back(FloatType::getF64(t.getContext())); 502 return success(); 503 } 504 505 // Convert I42 to I43. 506 if (t.isInteger(42)) { 507 results.push_back(IntegerType::get(43, t.getContext())); 508 return success(); 509 } 510 511 // Split F32 into F16,F16. 512 if (t.isF32()) { 513 results.assign(2, FloatType::getF16(t.getContext())); 514 return success(); 515 } 516 517 // Otherwise, convert the type directly. 518 results.push_back(t); 519 return success(); 520 } 521 522 /// Hook for materializing a conversion. This is necessary because we generate 523 /// 1->N type mappings. 524 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 525 ValueRange inputs, Location loc) { 526 if (inputs.size() == 1) 527 return inputs[0]; 528 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 529 } 530 531 /// Materialize the cast for one-to-one conversion from i64 to f64. 532 static Optional<Value> materializeOneToOneCast(OpBuilder &builder, 533 IntegerType resultType, 534 ValueRange inputs, 535 Location loc) { 536 if (resultType.getWidth() == 42 && inputs.size() == 1) 537 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 538 return llvm::None; 539 } 540 }; 541 542 struct TestLegalizePatternDriver 543 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 544 /// The mode of conversion to use with the driver. 545 enum class ConversionMode { Analysis, Full, Partial }; 546 547 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 548 549 void runOnOperation() override { 550 TestTypeConverter converter; 551 mlir::OwningRewritePatternList patterns; 552 populateWithGenerated(&getContext(), patterns); 553 patterns.insert< 554 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 555 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 556 TestPassthroughInvalidOp, TestSplitReturnType, 557 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 558 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 559 TestNonRootReplacement, TestBoundedRecursiveRewrite, 560 TestNestedOpCreationUndoRewrite>(&getContext()); 561 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 562 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 563 converter); 564 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 565 converter); 566 567 // Define the conversion target used for the test. 568 ConversionTarget target(getContext()); 569 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 570 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 571 TerminatorOp>(); 572 target 573 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 574 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 575 // Don't allow F32 operands. 576 return llvm::none_of(op.getOperandTypes(), 577 [](Type type) { return type.isF32(); }); 578 }); 579 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 580 return converter.isSignatureLegal(op.getType()) && 581 converter.isLegal(&op.getBody()); 582 }); 583 584 // Expect the type_producer/type_consumer operations to only operate on f64. 585 target.addDynamicallyLegalOp<TestTypeProducerOp>( 586 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 587 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 588 return op.getOperand().getType().isF64(); 589 }); 590 591 // Check support for marking certain operations as recursively legal. 592 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 593 return static_cast<bool>( 594 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 595 }); 596 597 // Mark the bound recursion operation as dynamically legal. 598 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 599 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 600 601 // Handle a partial conversion. 602 if (mode == ConversionMode::Partial) { 603 DenseSet<Operation *> unlegalizedOps; 604 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 605 &unlegalizedOps); 606 // Emit remarks for each legalizable operation. 607 for (auto *op : unlegalizedOps) 608 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 609 return; 610 } 611 612 // Handle a full conversion. 613 if (mode == ConversionMode::Full) { 614 // Check support for marking unknown operations as dynamically legal. 615 target.markUnknownOpDynamicallyLegal([](Operation *op) { 616 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 617 }); 618 619 (void)applyFullConversion(getOperation(), target, std::move(patterns)); 620 return; 621 } 622 623 // Otherwise, handle an analysis conversion. 624 assert(mode == ConversionMode::Analysis); 625 626 // Analyze the convertible operations. 627 DenseSet<Operation *> legalizedOps; 628 if (failed(applyAnalysisConversion(getOperation(), target, 629 std::move(patterns), legalizedOps))) 630 return signalPassFailure(); 631 632 // Emit remarks for each legalizable operation. 633 for (auto *op : legalizedOps) 634 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 635 } 636 637 /// The mode of conversion to use. 638 ConversionMode mode; 639 }; 640 } // end anonymous namespace 641 642 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 643 legalizerConversionMode( 644 "test-legalize-mode", 645 llvm::cl::desc("The legalization mode to use with the test driver"), 646 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 647 llvm::cl::values( 648 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 649 "analysis", "Perform an analysis conversion"), 650 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 651 "Perform a full conversion"), 652 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 653 "partial", "Perform a partial conversion"))); 654 655 //===----------------------------------------------------------------------===// 656 // ConversionPatternRewriter::getRemappedValue testing. This method is used 657 // to get the remapped value of an original value that was replaced using 658 // ConversionPatternRewriter. 659 namespace { 660 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 661 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 662 /// operand twice. 663 /// 664 /// Example: 665 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 666 /// is replaced with: 667 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 668 struct OneVResOneVOperandOp1Converter 669 : public OpConversionPattern<OneVResOneVOperandOp1> { 670 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 671 672 LogicalResult 673 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 674 ConversionPatternRewriter &rewriter) const override { 675 auto origOps = op.getOperands(); 676 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 677 "One operand expected"); 678 Value origOp = *origOps.begin(); 679 SmallVector<Value, 2> remappedOperands; 680 // Replicate the remapped original operand twice. Note that we don't used 681 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 682 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 683 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 684 685 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 686 remappedOperands); 687 return success(); 688 } 689 }; 690 691 struct TestRemappedValue 692 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 693 void runOnFunction() override { 694 mlir::OwningRewritePatternList patterns; 695 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 696 697 mlir::ConversionTarget target(getContext()); 698 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 699 // We make OneVResOneVOperandOp1 legal only when it has more that one 700 // operand. This will trigger the conversion that will replace one-operand 701 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 702 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 703 [](Operation *op) -> bool { 704 return std::distance(op->operand_begin(), op->operand_end()) > 1; 705 }); 706 707 if (failed(mlir::applyFullConversion(getFunction(), target, 708 std::move(patterns)))) { 709 signalPassFailure(); 710 } 711 } 712 }; 713 } // end anonymous namespace 714 715 //===----------------------------------------------------------------------===// 716 // Test patterns without a specific root operation kind 717 //===----------------------------------------------------------------------===// 718 719 namespace { 720 /// This pattern matches and removes any operation in the test dialect. 721 struct RemoveTestDialectOps : public RewritePattern { 722 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 723 724 LogicalResult matchAndRewrite(Operation *op, 725 PatternRewriter &rewriter) const override { 726 if (!isa<TestDialect>(op->getDialect())) 727 return failure(); 728 rewriter.eraseOp(op); 729 return success(); 730 } 731 }; 732 733 struct TestUnknownRootOpDriver 734 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 735 void runOnFunction() override { 736 mlir::OwningRewritePatternList patterns; 737 patterns.insert<RemoveTestDialectOps>(); 738 739 mlir::ConversionTarget target(getContext()); 740 target.addIllegalDialect<TestDialect>(); 741 if (failed( 742 applyPartialConversion(getFunction(), target, std::move(patterns)))) 743 signalPassFailure(); 744 } 745 }; 746 } // end anonymous namespace 747 748 //===----------------------------------------------------------------------===// 749 // Test type conversions 750 //===----------------------------------------------------------------------===// 751 752 namespace { 753 struct TestTypeConversionProducer 754 : public OpConversionPattern<TestTypeProducerOp> { 755 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 756 LogicalResult 757 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 758 ConversionPatternRewriter &rewriter) const final { 759 Type resultType = op.getType(); 760 if (resultType.isa<FloatType>()) 761 resultType = rewriter.getF64Type(); 762 else if (resultType.isInteger(16)) 763 resultType = rewriter.getIntegerType(64); 764 else 765 return failure(); 766 767 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 768 return success(); 769 } 770 }; 771 772 struct TestTypeConversionDriver 773 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 774 void getDependentDialects(DialectRegistry ®istry) const override { 775 registry.insert<TestDialect>(); 776 } 777 778 void runOnOperation() override { 779 // Initialize the type converter. 780 TypeConverter converter; 781 782 /// Add the legal set of type conversions. 783 converter.addConversion([](Type type) -> Type { 784 // Treat F64 as legal. 785 if (type.isF64()) 786 return type; 787 // Allow converting BF16/F16/F32 to F64. 788 if (type.isBF16() || type.isF16() || type.isF32()) 789 return FloatType::getF64(type.getContext()); 790 // Otherwise, the type is illegal. 791 return nullptr; 792 }); 793 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 794 // Drop all integer types. 795 return success(); 796 }); 797 798 /// Add the legal set of type materializations. 799 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 800 ValueRange inputs, 801 Location loc) -> Value { 802 // Allow casting from F64 back to F32. 803 if (!resultType.isF16() && inputs.size() == 1 && 804 inputs[0].getType().isF64()) 805 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 806 // Allow producing an i32 or i64 from nothing. 807 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 808 inputs.empty()) 809 return builder.create<TestTypeProducerOp>(loc, resultType); 810 // Allow producing an i64 from an integer. 811 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 812 inputs[0].getType().isa<IntegerType>()) 813 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 814 // Otherwise, fail. 815 return nullptr; 816 }); 817 818 // Initialize the conversion target. 819 mlir::ConversionTarget target(getContext()); 820 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 821 return op.getType().isF64() || op.getType().isInteger(64); 822 }); 823 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 824 return converter.isSignatureLegal(op.getType()) && 825 converter.isLegal(&op.getBody()); 826 }); 827 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 828 // Allow casts from F64 to F32. 829 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 830 }); 831 832 // Initialize the set of rewrite patterns. 833 OwningRewritePatternList patterns; 834 patterns.insert<TestTypeConversionProducer>(converter, &getContext()); 835 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 836 converter); 837 838 if (failed(applyPartialConversion(getOperation(), target, 839 std::move(patterns)))) 840 signalPassFailure(); 841 } 842 }; 843 } // end anonymous namespace 844 845 namespace { 846 /// A rewriter pattern that tests that blocks can be merged. 847 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 848 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 849 850 LogicalResult 851 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 852 ConversionPatternRewriter &rewriter) const final { 853 Block &firstBlock = op.body().front(); 854 Operation *branchOp = firstBlock.getTerminator(); 855 Block *secondBlock = &*(std::next(op.body().begin())); 856 auto succOperands = branchOp->getOperands(); 857 SmallVector<Value, 2> replacements(succOperands); 858 rewriter.eraseOp(branchOp); 859 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 860 rewriter.updateRootInPlace(op, [] {}); 861 return success(); 862 } 863 }; 864 865 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 866 struct TestUndoBlocksMerge : public ConversionPattern { 867 TestUndoBlocksMerge(MLIRContext *ctx) 868 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 869 LogicalResult 870 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 871 ConversionPatternRewriter &rewriter) const final { 872 Block &firstBlock = op->getRegion(0).front(); 873 Operation *branchOp = firstBlock.getTerminator(); 874 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 875 rewriter.setInsertionPointToStart(secondBlock); 876 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 877 auto succOperands = branchOp->getOperands(); 878 SmallVector<Value, 2> replacements(succOperands); 879 rewriter.eraseOp(branchOp); 880 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 881 rewriter.updateRootInPlace(op, [] {}); 882 return success(); 883 } 884 }; 885 886 /// A rewrite mechanism to inline the body of the op into its parent, when both 887 /// ops can have a single block. 888 struct TestMergeSingleBlockOps 889 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 890 using OpConversionPattern< 891 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 892 893 LogicalResult 894 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 895 ConversionPatternRewriter &rewriter) const final { 896 SingleBlockImplicitTerminatorOp parentOp = 897 op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 898 if (!parentOp) 899 return failure(); 900 Block &innerBlock = op.region().front(); 901 TerminatorOp innerTerminator = 902 cast<TerminatorOp>(innerBlock.getTerminator()); 903 rewriter.mergeBlockBefore(&innerBlock, op); 904 rewriter.eraseOp(innerTerminator); 905 rewriter.eraseOp(op); 906 rewriter.updateRootInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 struct TestMergeBlocksPatternDriver 912 : public PassWrapper<TestMergeBlocksPatternDriver, 913 OperationPass<ModuleOp>> { 914 void runOnOperation() override { 915 mlir::OwningRewritePatternList patterns; 916 MLIRContext *context = &getContext(); 917 patterns 918 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 919 context); 920 ConversionTarget target(*context); 921 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 922 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 923 TestReturnOp>(); 924 target.addIllegalOp<ILLegalOpF>(); 925 926 /// Expect the op to have a single block after legalization. 927 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 928 [&](TestMergeBlocksOp op) -> bool { 929 return llvm::hasSingleElement(op.body()); 930 }); 931 932 /// Only allow `test.br` within test.merge_blocks op. 933 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 934 return op.getParentOfType<TestMergeBlocksOp>(); 935 }); 936 937 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 938 /// inlined. 939 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 940 [&](SingleBlockImplicitTerminatorOp op) -> bool { 941 return !op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 942 }); 943 944 DenseSet<Operation *> unlegalizedOps; 945 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 946 &unlegalizedOps); 947 for (auto *op : unlegalizedOps) 948 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 949 } 950 }; 951 } // namespace 952 953 //===----------------------------------------------------------------------===// 954 // PassRegistration 955 //===----------------------------------------------------------------------===// 956 957 namespace mlir { 958 void registerPatternsTestPass() { 959 PassRegistration<TestReturnTypeDriver>("test-return-type", 960 "Run return type functions"); 961 962 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 963 "Run test derived attributes"); 964 965 PassRegistration<TestPatternDriver>("test-patterns", 966 "Run test dialect patterns"); 967 968 PassRegistration<TestLegalizePatternDriver>( 969 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 970 return std::make_unique<TestLegalizePatternDriver>( 971 legalizerConversionMode); 972 }); 973 974 PassRegistration<TestRemappedValue>( 975 "test-remapped-value", 976 "Test public remapped value mechanism in ConversionPatternRewriter"); 977 978 PassRegistration<TestUnknownRootOpDriver>( 979 "test-legalize-unknown-root-patterns", 980 "Test public remapped value mechanism in ConversionPatternRewriter"); 981 982 PassRegistration<TestTypeConversionDriver>( 983 "test-legalize-type-conversion", 984 "Test various type conversion functionalities in DialectConversion"); 985 986 PassRegistration<TestMergeBlocksPatternDriver>{ 987 "test-merge-blocks", 988 "Test Merging operation in ConversionPatternRewriter"}; 989 } 990 } // namespace mlir 991