1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 17 using namespace mlir; 18 19 // Native function for testing NativeCodeCall 20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 21 return choice.getValue() ? input1 : input2; 22 } 23 24 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 25 rewriter.create<OpI>(loc, input); 26 } 27 28 static void handleNoResultOp(PatternRewriter &rewriter, 29 OpSymbolBindingNoResult op) { 30 // Turn the no result op to a one-result op. 31 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 32 op.operand()); 33 } 34 35 // Test that natives calls are only called once during rewrites. 36 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 37 // This let us check the number of times OpM_Test was called by inspecting 38 // the returned value in the MLIR output. 39 static int64_t opMIncreasingValue = 314159265; 40 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 41 int64_t i = opMIncreasingValue++; 42 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 43 } 44 45 namespace { 46 #include "TestPatterns.inc" 47 } // end anonymous namespace 48 49 //===----------------------------------------------------------------------===// 50 // Canonicalizer Driver. 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 struct FoldingPattern : public RewritePattern { 55 public: 56 FoldingPattern(MLIRContext *context) 57 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 58 /*benefit=*/1, context) {} 59 60 LogicalResult matchAndRewrite(Operation *op, 61 PatternRewriter &rewriter) const override { 62 // Exercice OperationFolder API for a single-result operation that is folded 63 // upon construction. The operation being created through the folder has an 64 // in-place folder, and it should be still present in the output. 65 // Furthermore, the folder should not crash when attempting to recover the 66 // (unchanged) opeation result. 67 OperationFolder folder(op->getContext()); 68 Value result = folder.create<TestOpInPlaceFold>( 69 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 70 rewriter.getI32IntegerAttr(0)); 71 assert(result); 72 rewriter.replaceOp(op, result); 73 return success(); 74 } 75 }; 76 77 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 78 void runOnFunction() override { 79 mlir::OwningRewritePatternList patterns; 80 populateWithGenerated(&getContext(), &patterns); 81 82 // Verify named pattern is generated with expected name. 83 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 84 85 applyPatternsAndFoldGreedily(getFunction(), patterns); 86 } 87 }; 88 } // end anonymous namespace 89 90 //===----------------------------------------------------------------------===// 91 // ReturnType Driver. 92 //===----------------------------------------------------------------------===// 93 94 namespace { 95 // Generate ops for each instance where the type can be successfully inferred. 96 template <typename OpTy> 97 static void invokeCreateWithInferredReturnType(Operation *op) { 98 auto *context = op->getContext(); 99 auto fop = op->getParentOfType<FuncOp>(); 100 auto location = UnknownLoc::get(context); 101 OpBuilder b(op); 102 b.setInsertionPointAfter(op); 103 104 // Use permutations of 2 args as operands. 105 assert(fop.getNumArguments() >= 2); 106 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 107 for (int j = 0; j < e; ++j) { 108 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 109 SmallVector<Type, 2> inferredReturnTypes; 110 if (succeeded(OpTy::inferReturnTypes( 111 context, llvm::None, values, op->getAttrDictionary(), 112 op->getRegions(), inferredReturnTypes))) { 113 OperationState state(location, OpTy::getOperationName()); 114 // TODO: Expand to regions. 115 OpTy::build(b, state, values, op->getAttrs()); 116 (void)b.createOperation(state); 117 } 118 } 119 } 120 } 121 122 static void reifyReturnShape(Operation *op) { 123 OpBuilder b(op); 124 125 // Use permutations of 2 args as operands. 126 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 127 SmallVector<Value, 2> shapes; 128 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 129 return; 130 for (auto it : llvm::enumerate(shapes)) 131 op->emitRemark() << "value " << it.index() << ": " 132 << it.value().getDefiningOp(); 133 } 134 135 struct TestReturnTypeDriver 136 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 137 void runOnFunction() override { 138 if (getFunction().getName() == "testCreateFunctions") { 139 std::vector<Operation *> ops; 140 // Collect ops to avoid triggering on inserted ops. 141 for (auto &op : getFunction().getBody().front()) 142 ops.push_back(&op); 143 // Generate test patterns for each, but skip terminator. 144 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 145 // Test create method of each of the Op classes below. The resultant 146 // output would be in reverse order underneath `op` from which 147 // the attributes and regions are used. 148 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 149 invokeCreateWithInferredReturnType< 150 OpWithShapedTypeInferTypeInterfaceOp>(op); 151 }; 152 return; 153 } 154 if (getFunction().getName() == "testReifyFunctions") { 155 std::vector<Operation *> ops; 156 // Collect ops to avoid triggering on inserted ops. 157 for (auto &op : getFunction().getBody().front()) 158 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 159 ops.push_back(&op); 160 // Generate test patterns for each, but skip terminator. 161 for (auto *op : ops) 162 reifyReturnShape(op); 163 } 164 } 165 }; 166 } // end anonymous namespace 167 168 namespace { 169 struct TestDerivedAttributeDriver 170 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 171 void runOnFunction() override; 172 }; 173 } // end anonymous namespace 174 175 void TestDerivedAttributeDriver::runOnFunction() { 176 getFunction().walk([](DerivedAttributeOpInterface dOp) { 177 auto dAttr = dOp.materializeDerivedAttributes(); 178 if (!dAttr) 179 return; 180 for (auto d : dAttr) 181 dOp.emitRemark() << d.first << " = " << d.second; 182 }); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // Legalization Driver. 187 //===----------------------------------------------------------------------===// 188 189 namespace { 190 //===----------------------------------------------------------------------===// 191 // Region-Block Rewrite Testing 192 193 /// This pattern is a simple pattern that inlines the first region of a given 194 /// operation into the parent region. 195 struct TestRegionRewriteBlockMovement : public ConversionPattern { 196 TestRegionRewriteBlockMovement(MLIRContext *ctx) 197 : ConversionPattern("test.region", 1, ctx) {} 198 199 LogicalResult 200 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 201 ConversionPatternRewriter &rewriter) const final { 202 // Inline this region into the parent region. 203 auto &parentRegion = *op->getParentRegion(); 204 if (op->getAttr("legalizer.should_clone")) 205 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 206 parentRegion.end()); 207 else 208 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 209 parentRegion.end()); 210 211 // Drop this operation. 212 rewriter.eraseOp(op); 213 return success(); 214 } 215 }; 216 /// This pattern is a simple pattern that generates a region containing an 217 /// illegal operation. 218 struct TestRegionRewriteUndo : public RewritePattern { 219 TestRegionRewriteUndo(MLIRContext *ctx) 220 : RewritePattern("test.region_builder", 1, ctx) {} 221 222 LogicalResult matchAndRewrite(Operation *op, 223 PatternRewriter &rewriter) const final { 224 // Create the region operation with an entry block containing arguments. 225 OperationState newRegion(op->getLoc(), "test.region"); 226 newRegion.addRegion(); 227 auto *regionOp = rewriter.createOperation(newRegion); 228 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 229 entryBlock->addArgument(rewriter.getIntegerType(64)); 230 231 // Add an explicitly illegal operation to ensure the conversion fails. 232 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 233 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 234 235 // Drop this operation. 236 rewriter.eraseOp(op); 237 return success(); 238 } 239 }; 240 /// A simple pattern that creates a block at the end of the parent region of the 241 /// matched operation. 242 struct TestCreateBlock : public RewritePattern { 243 TestCreateBlock(MLIRContext *ctx) 244 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 245 246 LogicalResult matchAndRewrite(Operation *op, 247 PatternRewriter &rewriter) const final { 248 Region ®ion = *op->getParentRegion(); 249 Type i32Type = rewriter.getIntegerType(32); 250 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 251 rewriter.create<TerminatorOp>(op->getLoc()); 252 rewriter.replaceOp(op, {}); 253 return success(); 254 } 255 }; 256 257 /// A simple pattern that creates a block containing an invalid operaiton in 258 /// order to trigger the block creation undo mechanism. 259 struct TestCreateIllegalBlock : public RewritePattern { 260 TestCreateIllegalBlock(MLIRContext *ctx) 261 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 262 263 LogicalResult matchAndRewrite(Operation *op, 264 PatternRewriter &rewriter) const final { 265 Region ®ion = *op->getParentRegion(); 266 Type i32Type = rewriter.getIntegerType(32); 267 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 268 // Create an illegal op to ensure the conversion fails. 269 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 270 rewriter.create<TerminatorOp>(op->getLoc()); 271 rewriter.replaceOp(op, {}); 272 return success(); 273 } 274 }; 275 276 /// A simple pattern that tests the undo mechanism when replacing the uses of a 277 /// block argument. 278 struct TestUndoBlockArgReplace : public ConversionPattern { 279 TestUndoBlockArgReplace(MLIRContext *ctx) 280 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 281 282 LogicalResult 283 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 284 ConversionPatternRewriter &rewriter) const final { 285 auto illegalOp = 286 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 287 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 288 illegalOp); 289 rewriter.updateRootInPlace(op, [] {}); 290 return success(); 291 } 292 }; 293 294 /// A rewrite pattern that tests the undo mechanism when erasing a block. 295 struct TestUndoBlockErase : public ConversionPattern { 296 TestUndoBlockErase(MLIRContext *ctx) 297 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 298 299 LogicalResult 300 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 301 ConversionPatternRewriter &rewriter) const final { 302 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 303 rewriter.setInsertionPointToStart(secondBlock); 304 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 305 rewriter.eraseBlock(secondBlock); 306 rewriter.updateRootInPlace(op, [] {}); 307 return success(); 308 } 309 }; 310 311 //===----------------------------------------------------------------------===// 312 // Type-Conversion Rewrite Testing 313 314 /// This patterns erases a region operation that has had a type conversion. 315 struct TestDropOpSignatureConversion : public ConversionPattern { 316 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 317 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 318 LogicalResult 319 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 320 ConversionPatternRewriter &rewriter) const override { 321 Region ®ion = op->getRegion(0); 322 Block *entry = ®ion.front(); 323 324 // Convert the original entry arguments. 325 TypeConverter &converter = *getTypeConverter(); 326 TypeConverter::SignatureConversion result(entry->getNumArguments()); 327 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 328 result)) || 329 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 330 return failure(); 331 332 // Convert the region signature and just drop the operation. 333 rewriter.eraseOp(op); 334 return success(); 335 } 336 }; 337 /// This pattern simply updates the operands of the given operation. 338 struct TestPassthroughInvalidOp : public ConversionPattern { 339 TestPassthroughInvalidOp(MLIRContext *ctx) 340 : ConversionPattern("test.invalid", 1, ctx) {} 341 LogicalResult 342 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 343 ConversionPatternRewriter &rewriter) const final { 344 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 345 llvm::None); 346 return success(); 347 } 348 }; 349 /// This pattern handles the case of a split return value. 350 struct TestSplitReturnType : public ConversionPattern { 351 TestSplitReturnType(MLIRContext *ctx) 352 : ConversionPattern("test.return", 1, ctx) {} 353 LogicalResult 354 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 355 ConversionPatternRewriter &rewriter) const final { 356 // Check for a return of F32. 357 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 358 return failure(); 359 360 // Check if the first operation is a cast operation, if it is we use the 361 // results directly. 362 auto *defOp = operands[0].getDefiningOp(); 363 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 364 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 365 return success(); 366 } 367 368 // Otherwise, fail to match. 369 return failure(); 370 } 371 }; 372 373 //===----------------------------------------------------------------------===// 374 // Multi-Level Type-Conversion Rewrite Testing 375 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 376 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 377 : ConversionPattern("test.type_producer", 1, ctx) {} 378 LogicalResult 379 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 380 ConversionPatternRewriter &rewriter) const final { 381 // If the type is I32, change the type to F32. 382 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 383 return failure(); 384 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 385 return success(); 386 } 387 }; 388 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 389 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 390 : ConversionPattern("test.type_producer", 1, ctx) {} 391 LogicalResult 392 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 393 ConversionPatternRewriter &rewriter) const final { 394 // If the type is F32, change the type to F64. 395 if (!Type(*op->result_type_begin()).isF32()) 396 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 397 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 398 return success(); 399 } 400 }; 401 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 402 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 403 : ConversionPattern("test.type_producer", 10, ctx) {} 404 LogicalResult 405 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 406 ConversionPatternRewriter &rewriter) const final { 407 // Always convert to B16, even though it is not a legal type. This tests 408 // that values are unmapped correctly. 409 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 410 return success(); 411 } 412 }; 413 struct TestUpdateConsumerType : public ConversionPattern { 414 TestUpdateConsumerType(MLIRContext *ctx) 415 : ConversionPattern("test.type_consumer", 1, ctx) {} 416 LogicalResult 417 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 418 ConversionPatternRewriter &rewriter) const final { 419 // Verify that the incoming operand has been successfully remapped to F64. 420 if (!operands[0].getType().isF64()) 421 return failure(); 422 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 423 return success(); 424 } 425 }; 426 427 //===----------------------------------------------------------------------===// 428 // Non-Root Replacement Rewrite Testing 429 /// This pattern generates an invalid operation, but replaces it before the 430 /// pattern is finished. This checks that we don't need to legalize the 431 /// temporary op. 432 struct TestNonRootReplacement : public RewritePattern { 433 TestNonRootReplacement(MLIRContext *ctx) 434 : RewritePattern("test.replace_non_root", 1, ctx) {} 435 436 LogicalResult matchAndRewrite(Operation *op, 437 PatternRewriter &rewriter) const final { 438 auto resultType = *op->result_type_begin(); 439 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 440 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 441 442 rewriter.replaceOp(illegalOp, {legalOp}); 443 rewriter.replaceOp(op, {illegalOp}); 444 return success(); 445 } 446 }; 447 448 //===----------------------------------------------------------------------===// 449 // Recursive Rewrite Testing 450 /// This pattern is applied to the same operation multiple times, but has a 451 /// bounded recursion. 452 struct TestBoundedRecursiveRewrite 453 : public OpRewritePattern<TestRecursiveRewriteOp> { 454 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 455 456 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 457 PatternRewriter &rewriter) const final { 458 // Decrement the depth of the op in-place. 459 rewriter.updateRootInPlace(op, [&] { 460 op.setAttr("depth", 461 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1)); 462 }); 463 return success(); 464 } 465 466 /// The conversion target handles bounding the recursion of this pattern. 467 bool hasBoundedRewriteRecursion() const final { return true; } 468 }; 469 470 struct TestNestedOpCreationUndoRewrite 471 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 472 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 473 474 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 475 PatternRewriter &rewriter) const final { 476 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 477 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 478 return success(); 479 }; 480 }; 481 } // namespace 482 483 namespace { 484 struct TestTypeConverter : public TypeConverter { 485 using TypeConverter::TypeConverter; 486 TestTypeConverter() { 487 addConversion(convertType); 488 addArgumentMaterialization(materializeCast); 489 addArgumentMaterialization(materializeOneToOneCast); 490 addSourceMaterialization(materializeCast); 491 } 492 493 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 494 // Drop I16 types. 495 if (t.isSignlessInteger(16)) 496 return success(); 497 498 // Convert I64 to F64. 499 if (t.isSignlessInteger(64)) { 500 results.push_back(FloatType::getF64(t.getContext())); 501 return success(); 502 } 503 504 // Convert I42 to I43. 505 if (t.isInteger(42)) { 506 results.push_back(IntegerType::get(43, t.getContext())); 507 return success(); 508 } 509 510 // Split F32 into F16,F16. 511 if (t.isF32()) { 512 results.assign(2, FloatType::getF16(t.getContext())); 513 return success(); 514 } 515 516 // Otherwise, convert the type directly. 517 results.push_back(t); 518 return success(); 519 } 520 521 /// Hook for materializing a conversion. This is necessary because we generate 522 /// 1->N type mappings. 523 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 524 ValueRange inputs, Location loc) { 525 if (inputs.size() == 1) 526 return inputs[0]; 527 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 528 } 529 530 /// Materialize the cast for one-to-one conversion from i64 to f64. 531 static Optional<Value> materializeOneToOneCast(OpBuilder &builder, 532 IntegerType resultType, 533 ValueRange inputs, 534 Location loc) { 535 if (resultType.getWidth() == 42 && inputs.size() == 1) 536 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 537 return llvm::None; 538 } 539 }; 540 541 struct TestLegalizePatternDriver 542 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 543 /// The mode of conversion to use with the driver. 544 enum class ConversionMode { Analysis, Full, Partial }; 545 546 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 547 548 void runOnOperation() override { 549 TestTypeConverter converter; 550 mlir::OwningRewritePatternList patterns; 551 populateWithGenerated(&getContext(), &patterns); 552 patterns.insert< 553 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 554 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 555 TestPassthroughInvalidOp, TestSplitReturnType, 556 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 557 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 558 TestNonRootReplacement, TestBoundedRecursiveRewrite, 559 TestNestedOpCreationUndoRewrite>(&getContext()); 560 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 561 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 562 converter); 563 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 564 converter); 565 566 // Define the conversion target used for the test. 567 ConversionTarget target(getContext()); 568 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 569 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 570 TerminatorOp>(); 571 target 572 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 573 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 574 // Don't allow F32 operands. 575 return llvm::none_of(op.getOperandTypes(), 576 [](Type type) { return type.isF32(); }); 577 }); 578 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 579 return converter.isSignatureLegal(op.getType()) && 580 converter.isLegal(&op.getBody()); 581 }); 582 583 // Expect the type_producer/type_consumer operations to only operate on f64. 584 target.addDynamicallyLegalOp<TestTypeProducerOp>( 585 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 586 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 587 return op.getOperand().getType().isF64(); 588 }); 589 590 // Check support for marking certain operations as recursively legal. 591 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 592 return static_cast<bool>( 593 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 594 }); 595 596 // Mark the bound recursion operation as dynamically legal. 597 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 598 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 599 600 // Handle a partial conversion. 601 if (mode == ConversionMode::Partial) { 602 DenseSet<Operation *> unlegalizedOps; 603 (void)applyPartialConversion(getOperation(), target, patterns, 604 &unlegalizedOps); 605 // Emit remarks for each legalizable operation. 606 for (auto *op : unlegalizedOps) 607 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 608 return; 609 } 610 611 // Handle a full conversion. 612 if (mode == ConversionMode::Full) { 613 // Check support for marking unknown operations as dynamically legal. 614 target.markUnknownOpDynamicallyLegal([](Operation *op) { 615 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 616 }); 617 618 (void)applyFullConversion(getOperation(), target, patterns); 619 return; 620 } 621 622 // Otherwise, handle an analysis conversion. 623 assert(mode == ConversionMode::Analysis); 624 625 // Analyze the convertible operations. 626 DenseSet<Operation *> legalizedOps; 627 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 628 legalizedOps))) 629 return signalPassFailure(); 630 631 // Emit remarks for each legalizable operation. 632 for (auto *op : legalizedOps) 633 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 634 } 635 636 /// The mode of conversion to use. 637 ConversionMode mode; 638 }; 639 } // end anonymous namespace 640 641 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 642 legalizerConversionMode( 643 "test-legalize-mode", 644 llvm::cl::desc("The legalization mode to use with the test driver"), 645 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 646 llvm::cl::values( 647 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 648 "analysis", "Perform an analysis conversion"), 649 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 650 "Perform a full conversion"), 651 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 652 "partial", "Perform a partial conversion"))); 653 654 //===----------------------------------------------------------------------===// 655 // ConversionPatternRewriter::getRemappedValue testing. This method is used 656 // to get the remapped value of an original value that was replaced using 657 // ConversionPatternRewriter. 658 namespace { 659 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 660 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 661 /// operand twice. 662 /// 663 /// Example: 664 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 665 /// is replaced with: 666 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 667 struct OneVResOneVOperandOp1Converter 668 : public OpConversionPattern<OneVResOneVOperandOp1> { 669 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 670 671 LogicalResult 672 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 673 ConversionPatternRewriter &rewriter) const override { 674 auto origOps = op.getOperands(); 675 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 676 "One operand expected"); 677 Value origOp = *origOps.begin(); 678 SmallVector<Value, 2> remappedOperands; 679 // Replicate the remapped original operand twice. Note that we don't used 680 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 681 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 682 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 683 684 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 685 remappedOperands); 686 return success(); 687 } 688 }; 689 690 struct TestRemappedValue 691 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 692 void runOnFunction() override { 693 mlir::OwningRewritePatternList patterns; 694 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 695 696 mlir::ConversionTarget target(getContext()); 697 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 698 // We make OneVResOneVOperandOp1 legal only when it has more that one 699 // operand. This will trigger the conversion that will replace one-operand 700 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 701 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 702 [](Operation *op) -> bool { 703 return std::distance(op->operand_begin(), op->operand_end()) > 1; 704 }); 705 706 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 707 signalPassFailure(); 708 } 709 } 710 }; 711 } // end anonymous namespace 712 713 //===----------------------------------------------------------------------===// 714 // Test patterns without a specific root operation kind 715 //===----------------------------------------------------------------------===// 716 717 namespace { 718 /// This pattern matches and removes any operation in the test dialect. 719 struct RemoveTestDialectOps : public RewritePattern { 720 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 721 722 LogicalResult matchAndRewrite(Operation *op, 723 PatternRewriter &rewriter) const override { 724 if (!isa<TestDialect>(op->getDialect())) 725 return failure(); 726 rewriter.eraseOp(op); 727 return success(); 728 } 729 }; 730 731 struct TestUnknownRootOpDriver 732 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 733 void runOnFunction() override { 734 mlir::OwningRewritePatternList patterns; 735 patterns.insert<RemoveTestDialectOps>(); 736 737 mlir::ConversionTarget target(getContext()); 738 target.addIllegalDialect<TestDialect>(); 739 if (failed(applyPartialConversion(getFunction(), target, patterns))) 740 signalPassFailure(); 741 } 742 }; 743 } // end anonymous namespace 744 745 //===----------------------------------------------------------------------===// 746 // Test type conversions 747 //===----------------------------------------------------------------------===// 748 749 namespace { 750 struct TestTypeConversionProducer 751 : public OpConversionPattern<TestTypeProducerOp> { 752 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 753 LogicalResult 754 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 755 ConversionPatternRewriter &rewriter) const final { 756 Type resultType = op.getType(); 757 if (resultType.isa<FloatType>()) 758 resultType = rewriter.getF64Type(); 759 else if (resultType.isInteger(16)) 760 resultType = rewriter.getIntegerType(64); 761 else 762 return failure(); 763 764 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 765 return success(); 766 } 767 }; 768 769 struct TestTypeConversionDriver 770 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 771 void getDependentDialects(DialectRegistry ®istry) const override { 772 registry.insert<TestDialect>(); 773 } 774 775 void runOnOperation() override { 776 // Initialize the type converter. 777 TypeConverter converter; 778 779 /// Add the legal set of type conversions. 780 converter.addConversion([](Type type) -> Type { 781 // Treat F64 as legal. 782 if (type.isF64()) 783 return type; 784 // Allow converting BF16/F16/F32 to F64. 785 if (type.isBF16() || type.isF16() || type.isF32()) 786 return FloatType::getF64(type.getContext()); 787 // Otherwise, the type is illegal. 788 return nullptr; 789 }); 790 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 791 // Drop all integer types. 792 return success(); 793 }); 794 795 /// Add the legal set of type materializations. 796 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 797 ValueRange inputs, 798 Location loc) -> Value { 799 // Allow casting from F64 back to F32. 800 if (!resultType.isF16() && inputs.size() == 1 && 801 inputs[0].getType().isF64()) 802 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 803 // Allow producing an i32 or i64 from nothing. 804 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 805 inputs.empty()) 806 return builder.create<TestTypeProducerOp>(loc, resultType); 807 // Allow producing an i64 from an integer. 808 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 809 inputs[0].getType().isa<IntegerType>()) 810 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 811 // Otherwise, fail. 812 return nullptr; 813 }); 814 815 // Initialize the conversion target. 816 mlir::ConversionTarget target(getContext()); 817 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 818 return op.getType().isF64() || op.getType().isInteger(64); 819 }); 820 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 821 return converter.isSignatureLegal(op.getType()) && 822 converter.isLegal(&op.getBody()); 823 }); 824 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 825 // Allow casts from F64 to F32. 826 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 827 }); 828 829 // Initialize the set of rewrite patterns. 830 OwningRewritePatternList patterns; 831 patterns.insert<TestTypeConversionProducer>(converter, &getContext()); 832 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 833 converter); 834 835 if (failed(applyPartialConversion(getOperation(), target, patterns))) 836 signalPassFailure(); 837 } 838 }; 839 } // end anonymous namespace 840 841 namespace { 842 /// A rewriter pattern that tests that blocks can be merged. 843 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 844 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 845 846 LogicalResult 847 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 848 ConversionPatternRewriter &rewriter) const final { 849 Block &firstBlock = op.body().front(); 850 Operation *branchOp = firstBlock.getTerminator(); 851 Block *secondBlock = &*(std::next(op.body().begin())); 852 auto succOperands = branchOp->getOperands(); 853 SmallVector<Value, 2> replacements(succOperands); 854 rewriter.eraseOp(branchOp); 855 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 856 rewriter.updateRootInPlace(op, [] {}); 857 return success(); 858 } 859 }; 860 861 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 862 struct TestUndoBlocksMerge : public ConversionPattern { 863 TestUndoBlocksMerge(MLIRContext *ctx) 864 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 865 LogicalResult 866 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 867 ConversionPatternRewriter &rewriter) const final { 868 Block &firstBlock = op->getRegion(0).front(); 869 Operation *branchOp = firstBlock.getTerminator(); 870 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 871 rewriter.setInsertionPointToStart(secondBlock); 872 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 873 auto succOperands = branchOp->getOperands(); 874 SmallVector<Value, 2> replacements(succOperands); 875 rewriter.eraseOp(branchOp); 876 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 877 rewriter.updateRootInPlace(op, [] {}); 878 return success(); 879 } 880 }; 881 882 /// A rewrite mechanism to inline the body of the op into its parent, when both 883 /// ops can have a single block. 884 struct TestMergeSingleBlockOps 885 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 886 using OpConversionPattern< 887 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 888 889 LogicalResult 890 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 891 ConversionPatternRewriter &rewriter) const final { 892 SingleBlockImplicitTerminatorOp parentOp = 893 op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 894 if (!parentOp) 895 return failure(); 896 Block &parentBlock = parentOp.region().front(); 897 Block &innerBlock = op.region().front(); 898 TerminatorOp innerTerminator = 899 cast<TerminatorOp>(innerBlock.getTerminator()); 900 Block *parentPrologue = 901 rewriter.splitBlock(&parentBlock, Block::iterator(op)); 902 rewriter.eraseOp(innerTerminator); 903 rewriter.mergeBlocks(&innerBlock, &parentBlock, {}); 904 rewriter.eraseOp(op); 905 rewriter.mergeBlocks(parentPrologue, &parentBlock, {}); 906 rewriter.updateRootInPlace(op, [] {}); 907 return success(); 908 } 909 }; 910 911 struct TestMergeBlocksPatternDriver 912 : public PassWrapper<TestMergeBlocksPatternDriver, 913 OperationPass<ModuleOp>> { 914 void runOnOperation() override { 915 mlir::OwningRewritePatternList patterns; 916 MLIRContext *context = &getContext(); 917 patterns 918 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 919 context); 920 ConversionTarget target(*context); 921 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 922 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 923 TestReturnOp>(); 924 target.addIllegalOp<ILLegalOpF>(); 925 926 /// Expect the op to have a single block after legalization. 927 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 928 [&](TestMergeBlocksOp op) -> bool { 929 return llvm::hasSingleElement(op.body()); 930 }); 931 932 /// Only allow `test.br` within test.merge_blocks op. 933 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 934 return op.getParentOfType<TestMergeBlocksOp>(); 935 }); 936 937 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 938 /// inlined. 939 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 940 [&](SingleBlockImplicitTerminatorOp op) -> bool { 941 return !op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 942 }); 943 944 DenseSet<Operation *> unlegalizedOps; 945 (void)applyPartialConversion(getOperation(), target, patterns, 946 &unlegalizedOps); 947 for (auto *op : unlegalizedOps) 948 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 949 } 950 }; 951 } // namespace 952 953 //===----------------------------------------------------------------------===// 954 // PassRegistration 955 //===----------------------------------------------------------------------===// 956 957 namespace mlir { 958 void registerPatternsTestPass() { 959 PassRegistration<TestReturnTypeDriver>("test-return-type", 960 "Run return type functions"); 961 962 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 963 "Run test derived attributes"); 964 965 PassRegistration<TestPatternDriver>("test-patterns", 966 "Run test dialect patterns"); 967 968 PassRegistration<TestLegalizePatternDriver>( 969 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 970 return std::make_unique<TestLegalizePatternDriver>( 971 legalizerConversionMode); 972 }); 973 974 PassRegistration<TestRemappedValue>( 975 "test-remapped-value", 976 "Test public remapped value mechanism in ConversionPatternRewriter"); 977 978 PassRegistration<TestUnknownRootOpDriver>( 979 "test-legalize-unknown-root-patterns", 980 "Test public remapped value mechanism in ConversionPatternRewriter"); 981 982 PassRegistration<TestTypeConversionDriver>( 983 "test-legalize-type-conversion", 984 "Test various type conversion functionalities in DialectConversion"); 985 986 PassRegistration<TestMergeBlocksPatternDriver>{ 987 "test-merge-blocks", 988 "Test Merging operation in ConversionPatternRewriter"}; 989 } 990 } // namespace mlir 991