1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 17 using namespace mlir; 18 19 // Native function for testing NativeCodeCall 20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 21 return choice.getValue() ? input1 : input2; 22 } 23 24 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 25 rewriter.create<OpI>(loc, input); 26 } 27 28 static void handleNoResultOp(PatternRewriter &rewriter, 29 OpSymbolBindingNoResult op) { 30 // Turn the no result op to a one-result op. 31 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 32 op.operand()); 33 } 34 35 // Test that natives calls are only called once during rewrites. 36 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 37 // This let us check the number of times OpM_Test was called by inspecting 38 // the returned value in the MLIR output. 39 static int64_t opMIncreasingValue = 314159265; 40 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 41 int64_t i = opMIncreasingValue++; 42 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 43 } 44 45 namespace { 46 #include "TestPatterns.inc" 47 } // end anonymous namespace 48 49 //===----------------------------------------------------------------------===// 50 // Canonicalizer Driver. 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 struct FoldingPattern : public RewritePattern { 55 public: 56 FoldingPattern(MLIRContext *context) 57 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 58 /*benefit=*/1, context) {} 59 60 LogicalResult matchAndRewrite(Operation *op, 61 PatternRewriter &rewriter) const override { 62 // Exercice OperationFolder API for a single-result operation that is folded 63 // upon construction. The operation being created through the folder has an 64 // in-place folder, and it should be still present in the output. 65 // Furthermore, the folder should not crash when attempting to recover the 66 // (unchanged) opeation result. 67 OperationFolder folder(op->getContext()); 68 Value result = folder.create<TestOpInPlaceFold>( 69 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 70 rewriter.getI32IntegerAttr(0)); 71 assert(result); 72 rewriter.replaceOp(op, result); 73 return success(); 74 } 75 }; 76 77 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 78 void runOnFunction() override { 79 mlir::OwningRewritePatternList patterns; 80 populateWithGenerated(&getContext(), &patterns); 81 82 // Verify named pattern is generated with expected name. 83 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 84 85 applyPatternsAndFoldGreedily(getFunction(), patterns); 86 } 87 }; 88 } // end anonymous namespace 89 90 //===----------------------------------------------------------------------===// 91 // ReturnType Driver. 92 //===----------------------------------------------------------------------===// 93 94 namespace { 95 // Generate ops for each instance where the type can be successfully inferred. 96 template <typename OpTy> 97 static void invokeCreateWithInferredReturnType(Operation *op) { 98 auto *context = op->getContext(); 99 auto fop = op->getParentOfType<FuncOp>(); 100 auto location = UnknownLoc::get(context); 101 OpBuilder b(op); 102 b.setInsertionPointAfter(op); 103 104 // Use permutations of 2 args as operands. 105 assert(fop.getNumArguments() >= 2); 106 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 107 for (int j = 0; j < e; ++j) { 108 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 109 SmallVector<Type, 2> inferredReturnTypes; 110 if (succeeded(OpTy::inferReturnTypes( 111 context, llvm::None, values, op->getAttrDictionary(), 112 op->getRegions(), inferredReturnTypes))) { 113 OperationState state(location, OpTy::getOperationName()); 114 // TODO: Expand to regions. 115 OpTy::build(b, state, values, op->getAttrs()); 116 (void)b.createOperation(state); 117 } 118 } 119 } 120 } 121 122 static void reifyReturnShape(Operation *op) { 123 OpBuilder b(op); 124 125 // Use permutations of 2 args as operands. 126 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 127 SmallVector<Value, 2> shapes; 128 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 129 return; 130 for (auto it : llvm::enumerate(shapes)) 131 op->emitRemark() << "value " << it.index() << ": " 132 << it.value().getDefiningOp(); 133 } 134 135 struct TestReturnTypeDriver 136 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 137 void runOnFunction() override { 138 if (getFunction().getName() == "testCreateFunctions") { 139 std::vector<Operation *> ops; 140 // Collect ops to avoid triggering on inserted ops. 141 for (auto &op : getFunction().getBody().front()) 142 ops.push_back(&op); 143 // Generate test patterns for each, but skip terminator. 144 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 145 // Test create method of each of the Op classes below. The resultant 146 // output would be in reverse order underneath `op` from which 147 // the attributes and regions are used. 148 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 149 invokeCreateWithInferredReturnType< 150 OpWithShapedTypeInferTypeInterfaceOp>(op); 151 }; 152 return; 153 } 154 if (getFunction().getName() == "testReifyFunctions") { 155 std::vector<Operation *> ops; 156 // Collect ops to avoid triggering on inserted ops. 157 for (auto &op : getFunction().getBody().front()) 158 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 159 ops.push_back(&op); 160 // Generate test patterns for each, but skip terminator. 161 for (auto *op : ops) 162 reifyReturnShape(op); 163 } 164 } 165 }; 166 } // end anonymous namespace 167 168 namespace { 169 struct TestDerivedAttributeDriver 170 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 171 void runOnFunction() override; 172 }; 173 } // end anonymous namespace 174 175 void TestDerivedAttributeDriver::runOnFunction() { 176 getFunction().walk([](DerivedAttributeOpInterface dOp) { 177 auto dAttr = dOp.materializeDerivedAttributes(); 178 if (!dAttr) 179 return; 180 for (auto d : dAttr) 181 dOp.emitRemark() << d.first << " = " << d.second; 182 }); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // Legalization Driver. 187 //===----------------------------------------------------------------------===// 188 189 namespace { 190 //===----------------------------------------------------------------------===// 191 // Region-Block Rewrite Testing 192 193 /// This pattern is a simple pattern that inlines the first region of a given 194 /// operation into the parent region. 195 struct TestRegionRewriteBlockMovement : public ConversionPattern { 196 TestRegionRewriteBlockMovement(MLIRContext *ctx) 197 : ConversionPattern("test.region", 1, ctx) {} 198 199 LogicalResult 200 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 201 ConversionPatternRewriter &rewriter) const final { 202 // Inline this region into the parent region. 203 auto &parentRegion = *op->getParentRegion(); 204 if (op->getAttr("legalizer.should_clone")) 205 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 206 parentRegion.end()); 207 else 208 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 209 parentRegion.end()); 210 211 // Drop this operation. 212 rewriter.eraseOp(op); 213 return success(); 214 } 215 }; 216 /// This pattern is a simple pattern that generates a region containing an 217 /// illegal operation. 218 struct TestRegionRewriteUndo : public RewritePattern { 219 TestRegionRewriteUndo(MLIRContext *ctx) 220 : RewritePattern("test.region_builder", 1, ctx) {} 221 222 LogicalResult matchAndRewrite(Operation *op, 223 PatternRewriter &rewriter) const final { 224 // Create the region operation with an entry block containing arguments. 225 OperationState newRegion(op->getLoc(), "test.region"); 226 newRegion.addRegion(); 227 auto *regionOp = rewriter.createOperation(newRegion); 228 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 229 entryBlock->addArgument(rewriter.getIntegerType(64)); 230 231 // Add an explicitly illegal operation to ensure the conversion fails. 232 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 233 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 234 235 // Drop this operation. 236 rewriter.eraseOp(op); 237 return success(); 238 } 239 }; 240 /// A simple pattern that creates a block at the end of the parent region of the 241 /// matched operation. 242 struct TestCreateBlock : public RewritePattern { 243 TestCreateBlock(MLIRContext *ctx) 244 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 245 246 LogicalResult matchAndRewrite(Operation *op, 247 PatternRewriter &rewriter) const final { 248 Region ®ion = *op->getParentRegion(); 249 Type i32Type = rewriter.getIntegerType(32); 250 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 251 rewriter.create<TerminatorOp>(op->getLoc()); 252 rewriter.replaceOp(op, {}); 253 return success(); 254 } 255 }; 256 257 /// A simple pattern that creates a block containing an invalid operaiton in 258 /// order to trigger the block creation undo mechanism. 259 struct TestCreateIllegalBlock : public RewritePattern { 260 TestCreateIllegalBlock(MLIRContext *ctx) 261 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 262 263 LogicalResult matchAndRewrite(Operation *op, 264 PatternRewriter &rewriter) const final { 265 Region ®ion = *op->getParentRegion(); 266 Type i32Type = rewriter.getIntegerType(32); 267 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 268 // Create an illegal op to ensure the conversion fails. 269 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 270 rewriter.create<TerminatorOp>(op->getLoc()); 271 rewriter.replaceOp(op, {}); 272 return success(); 273 } 274 }; 275 276 /// A simple pattern that tests the undo mechanism when replacing the uses of a 277 /// block argument. 278 struct TestUndoBlockArgReplace : public ConversionPattern { 279 TestUndoBlockArgReplace(MLIRContext *ctx) 280 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 281 282 LogicalResult 283 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 284 ConversionPatternRewriter &rewriter) const final { 285 auto illegalOp = 286 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 287 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 288 illegalOp); 289 rewriter.updateRootInPlace(op, [] {}); 290 return success(); 291 } 292 }; 293 294 /// A rewrite pattern that tests the undo mechanism when erasing a block. 295 struct TestUndoBlockErase : public ConversionPattern { 296 TestUndoBlockErase(MLIRContext *ctx) 297 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 298 299 LogicalResult 300 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 301 ConversionPatternRewriter &rewriter) const final { 302 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 303 rewriter.setInsertionPointToStart(secondBlock); 304 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 305 rewriter.eraseBlock(secondBlock); 306 rewriter.updateRootInPlace(op, [] {}); 307 return success(); 308 } 309 }; 310 311 //===----------------------------------------------------------------------===// 312 // Type-Conversion Rewrite Testing 313 314 /// This patterns erases a region operation that has had a type conversion. 315 struct TestDropOpSignatureConversion : public ConversionPattern { 316 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 317 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 318 LogicalResult 319 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 320 ConversionPatternRewriter &rewriter) const override { 321 Region ®ion = op->getRegion(0); 322 Block *entry = ®ion.front(); 323 324 // Convert the original entry arguments. 325 TypeConverter &converter = *getTypeConverter(); 326 TypeConverter::SignatureConversion result(entry->getNumArguments()); 327 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 328 result)) || 329 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 330 return failure(); 331 332 // Convert the region signature and just drop the operation. 333 rewriter.eraseOp(op); 334 return success(); 335 } 336 }; 337 /// This pattern simply updates the operands of the given operation. 338 struct TestPassthroughInvalidOp : public ConversionPattern { 339 TestPassthroughInvalidOp(MLIRContext *ctx) 340 : ConversionPattern("test.invalid", 1, ctx) {} 341 LogicalResult 342 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 343 ConversionPatternRewriter &rewriter) const final { 344 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 345 llvm::None); 346 return success(); 347 } 348 }; 349 /// This pattern handles the case of a split return value. 350 struct TestSplitReturnType : public ConversionPattern { 351 TestSplitReturnType(MLIRContext *ctx) 352 : ConversionPattern("test.return", 1, ctx) {} 353 LogicalResult 354 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 355 ConversionPatternRewriter &rewriter) const final { 356 // Check for a return of F32. 357 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 358 return failure(); 359 360 // Check if the first operation is a cast operation, if it is we use the 361 // results directly. 362 auto *defOp = operands[0].getDefiningOp(); 363 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 364 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 365 return success(); 366 } 367 368 // Otherwise, fail to match. 369 return failure(); 370 } 371 }; 372 373 //===----------------------------------------------------------------------===// 374 // Multi-Level Type-Conversion Rewrite Testing 375 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 376 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 377 : ConversionPattern("test.type_producer", 1, ctx) {} 378 LogicalResult 379 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 380 ConversionPatternRewriter &rewriter) const final { 381 // If the type is I32, change the type to F32. 382 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 383 return failure(); 384 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 385 return success(); 386 } 387 }; 388 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 389 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 390 : ConversionPattern("test.type_producer", 1, ctx) {} 391 LogicalResult 392 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 393 ConversionPatternRewriter &rewriter) const final { 394 // If the type is F32, change the type to F64. 395 if (!Type(*op->result_type_begin()).isF32()) 396 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 397 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 398 return success(); 399 } 400 }; 401 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 402 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 403 : ConversionPattern("test.type_producer", 10, ctx) {} 404 LogicalResult 405 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 406 ConversionPatternRewriter &rewriter) const final { 407 // Always convert to B16, even though it is not a legal type. This tests 408 // that values are unmapped correctly. 409 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 410 return success(); 411 } 412 }; 413 struct TestUpdateConsumerType : public ConversionPattern { 414 TestUpdateConsumerType(MLIRContext *ctx) 415 : ConversionPattern("test.type_consumer", 1, ctx) {} 416 LogicalResult 417 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 418 ConversionPatternRewriter &rewriter) const final { 419 // Verify that the incoming operand has been successfully remapped to F64. 420 if (!operands[0].getType().isF64()) 421 return failure(); 422 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 423 return success(); 424 } 425 }; 426 427 //===----------------------------------------------------------------------===// 428 // Non-Root Replacement Rewrite Testing 429 /// This pattern generates an invalid operation, but replaces it before the 430 /// pattern is finished. This checks that we don't need to legalize the 431 /// temporary op. 432 struct TestNonRootReplacement : public RewritePattern { 433 TestNonRootReplacement(MLIRContext *ctx) 434 : RewritePattern("test.replace_non_root", 1, ctx) {} 435 436 LogicalResult matchAndRewrite(Operation *op, 437 PatternRewriter &rewriter) const final { 438 auto resultType = *op->result_type_begin(); 439 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 440 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 441 442 rewriter.replaceOp(illegalOp, {legalOp}); 443 rewriter.replaceOp(op, {illegalOp}); 444 return success(); 445 } 446 }; 447 448 //===----------------------------------------------------------------------===// 449 // Recursive Rewrite Testing 450 /// This pattern is applied to the same operation multiple times, but has a 451 /// bounded recursion. 452 struct TestBoundedRecursiveRewrite 453 : public OpRewritePattern<TestRecursiveRewriteOp> { 454 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 455 456 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 457 PatternRewriter &rewriter) const final { 458 // Decrement the depth of the op in-place. 459 rewriter.updateRootInPlace(op, [&] { 460 op.setAttr("depth", 461 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1)); 462 }); 463 return success(); 464 } 465 466 /// The conversion target handles bounding the recursion of this pattern. 467 bool hasBoundedRewriteRecursion() const final { return true; } 468 }; 469 470 struct TestNestedOpCreationUndoRewrite 471 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 472 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 473 474 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 475 PatternRewriter &rewriter) const final { 476 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 477 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 478 return success(); 479 }; 480 }; 481 } // namespace 482 483 namespace { 484 struct TestTypeConverter : public TypeConverter { 485 using TypeConverter::TypeConverter; 486 TestTypeConverter() { 487 addConversion(convertType); 488 addArgumentMaterialization(materializeCast); 489 addArgumentMaterialization(materializeOneToOneCast); 490 addSourceMaterialization(materializeCast); 491 } 492 493 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 494 // Drop I16 types. 495 if (t.isSignlessInteger(16)) 496 return success(); 497 498 // Convert I64 to F64. 499 if (t.isSignlessInteger(64)) { 500 results.push_back(FloatType::getF64(t.getContext())); 501 return success(); 502 } 503 504 // Convert I42 to I43. 505 if (t.isInteger(42)) { 506 results.push_back(IntegerType::get(43, t.getContext())); 507 return success(); 508 } 509 510 // Split F32 into F16,F16. 511 if (t.isF32()) { 512 results.assign(2, FloatType::getF16(t.getContext())); 513 return success(); 514 } 515 516 // Otherwise, convert the type directly. 517 results.push_back(t); 518 return success(); 519 } 520 521 /// Hook for materializing a conversion. This is necessary because we generate 522 /// 1->N type mappings. 523 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 524 ValueRange inputs, Location loc) { 525 if (inputs.size() == 1) 526 return inputs[0]; 527 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 528 } 529 530 /// Materialize the cast for one-to-one conversion from i64 to f64. 531 static Optional<Value> materializeOneToOneCast(OpBuilder &builder, 532 IntegerType resultType, 533 ValueRange inputs, 534 Location loc) { 535 if (resultType.getWidth() == 42 && inputs.size() == 1) 536 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 537 return llvm::None; 538 } 539 }; 540 541 struct TestLegalizePatternDriver 542 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 543 /// The mode of conversion to use with the driver. 544 enum class ConversionMode { Analysis, Full, Partial }; 545 546 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 547 548 void runOnOperation() override { 549 TestTypeConverter converter; 550 mlir::OwningRewritePatternList patterns; 551 populateWithGenerated(&getContext(), &patterns); 552 patterns.insert< 553 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 554 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 555 TestPassthroughInvalidOp, TestSplitReturnType, 556 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 557 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 558 TestNonRootReplacement, TestBoundedRecursiveRewrite, 559 TestNestedOpCreationUndoRewrite>(&getContext()); 560 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 561 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 562 converter); 563 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 564 converter); 565 566 // Define the conversion target used for the test. 567 ConversionTarget target(getContext()); 568 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 569 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 570 TerminatorOp>(); 571 target 572 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 573 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 574 // Don't allow F32 operands. 575 return llvm::none_of(op.getOperandTypes(), 576 [](Type type) { return type.isF32(); }); 577 }); 578 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 579 return converter.isSignatureLegal(op.getType()) && 580 converter.isLegal(&op.getBody()); 581 }); 582 583 // Expect the type_producer/type_consumer operations to only operate on f64. 584 target.addDynamicallyLegalOp<TestTypeProducerOp>( 585 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 586 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 587 return op.getOperand().getType().isF64(); 588 }); 589 590 // Check support for marking certain operations as recursively legal. 591 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 592 return static_cast<bool>( 593 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 594 }); 595 596 // Mark the bound recursion operation as dynamically legal. 597 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 598 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 599 600 // Handle a partial conversion. 601 if (mode == ConversionMode::Partial) { 602 DenseSet<Operation *> unlegalizedOps; 603 (void)applyPartialConversion(getOperation(), target, patterns, 604 &unlegalizedOps); 605 // Emit remarks for each legalizable operation. 606 for (auto *op : unlegalizedOps) 607 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 608 return; 609 } 610 611 // Handle a full conversion. 612 if (mode == ConversionMode::Full) { 613 // Check support for marking unknown operations as dynamically legal. 614 target.markUnknownOpDynamicallyLegal([](Operation *op) { 615 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 616 }); 617 618 (void)applyFullConversion(getOperation(), target, patterns); 619 return; 620 } 621 622 // Otherwise, handle an analysis conversion. 623 assert(mode == ConversionMode::Analysis); 624 625 // Analyze the convertible operations. 626 DenseSet<Operation *> legalizedOps; 627 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 628 legalizedOps))) 629 return signalPassFailure(); 630 631 // Emit remarks for each legalizable operation. 632 for (auto *op : legalizedOps) 633 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 634 } 635 636 /// The mode of conversion to use. 637 ConversionMode mode; 638 }; 639 } // end anonymous namespace 640 641 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 642 legalizerConversionMode( 643 "test-legalize-mode", 644 llvm::cl::desc("The legalization mode to use with the test driver"), 645 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 646 llvm::cl::values( 647 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 648 "analysis", "Perform an analysis conversion"), 649 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 650 "Perform a full conversion"), 651 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 652 "partial", "Perform a partial conversion"))); 653 654 //===----------------------------------------------------------------------===// 655 // ConversionPatternRewriter::getRemappedValue testing. This method is used 656 // to get the remapped value of an original value that was replaced using 657 // ConversionPatternRewriter. 658 namespace { 659 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 660 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 661 /// operand twice. 662 /// 663 /// Example: 664 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 665 /// is replaced with: 666 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 667 struct OneVResOneVOperandOp1Converter 668 : public OpConversionPattern<OneVResOneVOperandOp1> { 669 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 670 671 LogicalResult 672 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 673 ConversionPatternRewriter &rewriter) const override { 674 auto origOps = op.getOperands(); 675 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 676 "One operand expected"); 677 Value origOp = *origOps.begin(); 678 SmallVector<Value, 2> remappedOperands; 679 // Replicate the remapped original operand twice. Note that we don't used 680 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 681 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 682 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 683 684 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 685 remappedOperands); 686 return success(); 687 } 688 }; 689 690 struct TestRemappedValue 691 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 692 void runOnFunction() override { 693 mlir::OwningRewritePatternList patterns; 694 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 695 696 mlir::ConversionTarget target(getContext()); 697 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 698 // We make OneVResOneVOperandOp1 legal only when it has more that one 699 // operand. This will trigger the conversion that will replace one-operand 700 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 701 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 702 [](Operation *op) -> bool { 703 return std::distance(op->operand_begin(), op->operand_end()) > 1; 704 }); 705 706 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 707 signalPassFailure(); 708 } 709 } 710 }; 711 } // end anonymous namespace 712 713 //===----------------------------------------------------------------------===// 714 // Test patterns without a specific root operation kind 715 //===----------------------------------------------------------------------===// 716 717 namespace { 718 /// This pattern matches and removes any operation in the test dialect. 719 struct RemoveTestDialectOps : public RewritePattern { 720 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 721 722 LogicalResult matchAndRewrite(Operation *op, 723 PatternRewriter &rewriter) const override { 724 if (!isa<TestDialect>(op->getDialect())) 725 return failure(); 726 rewriter.eraseOp(op); 727 return success(); 728 } 729 }; 730 731 struct TestUnknownRootOpDriver 732 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 733 void runOnFunction() override { 734 mlir::OwningRewritePatternList patterns; 735 patterns.insert<RemoveTestDialectOps>(); 736 737 mlir::ConversionTarget target(getContext()); 738 target.addIllegalDialect<TestDialect>(); 739 if (failed(applyPartialConversion(getFunction(), target, patterns))) 740 signalPassFailure(); 741 } 742 }; 743 } // end anonymous namespace 744 745 //===----------------------------------------------------------------------===// 746 // Test type conversions 747 //===----------------------------------------------------------------------===// 748 749 namespace { 750 struct TestTypeConversionProducer 751 : public OpConversionPattern<TestTypeProducerOp> { 752 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 753 LogicalResult 754 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 755 ConversionPatternRewriter &rewriter) const final { 756 Type resultType = op.getType(); 757 if (resultType.isa<FloatType>()) 758 resultType = rewriter.getF64Type(); 759 else if (resultType.isInteger(16)) 760 resultType = rewriter.getIntegerType(64); 761 else 762 return failure(); 763 764 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 765 return success(); 766 } 767 }; 768 769 struct TestTypeConversionDriver 770 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 771 void runOnOperation() override { 772 // Initialize the type converter. 773 TypeConverter converter; 774 775 /// Add the legal set of type conversions. 776 converter.addConversion([](Type type) -> Type { 777 // Treat F64 as legal. 778 if (type.isF64()) 779 return type; 780 // Allow converting BF16/F16/F32 to F64. 781 if (type.isBF16() || type.isF16() || type.isF32()) 782 return FloatType::getF64(type.getContext()); 783 // Otherwise, the type is illegal. 784 return nullptr; 785 }); 786 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 787 // Drop all integer types. 788 return success(); 789 }); 790 791 /// Add the legal set of type materializations. 792 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 793 ValueRange inputs, 794 Location loc) -> Value { 795 // Allow casting from F64 back to F32. 796 if (!resultType.isF16() && inputs.size() == 1 && 797 inputs[0].getType().isF64()) 798 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 799 // Allow producing an i32 or i64 from nothing. 800 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 801 inputs.empty()) 802 return builder.create<TestTypeProducerOp>(loc, resultType); 803 // Allow producing an i64 from an integer. 804 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 805 inputs[0].getType().isa<IntegerType>()) 806 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 807 // Otherwise, fail. 808 return nullptr; 809 }); 810 811 // Initialize the conversion target. 812 mlir::ConversionTarget target(getContext()); 813 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 814 return op.getType().isF64() || op.getType().isInteger(64); 815 }); 816 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 817 return converter.isSignatureLegal(op.getType()) && 818 converter.isLegal(&op.getBody()); 819 }); 820 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 821 // Allow casts from F64 to F32. 822 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 823 }); 824 825 // Initialize the set of rewrite patterns. 826 OwningRewritePatternList patterns; 827 patterns.insert<TestTypeConversionProducer>(converter, &getContext()); 828 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 829 converter); 830 831 if (failed(applyPartialConversion(getOperation(), target, patterns))) 832 signalPassFailure(); 833 } 834 }; 835 } // end anonymous namespace 836 837 namespace { 838 /// A rewriter pattern that tests that blocks can be merged. 839 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 840 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 841 842 LogicalResult 843 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 844 ConversionPatternRewriter &rewriter) const final { 845 Block &firstBlock = op.body().front(); 846 Operation *branchOp = firstBlock.getTerminator(); 847 Block *secondBlock = &*(std::next(op.body().begin())); 848 auto succOperands = branchOp->getOperands(); 849 SmallVector<Value, 2> replacements(succOperands); 850 rewriter.eraseOp(branchOp); 851 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 852 rewriter.updateRootInPlace(op, [] {}); 853 return success(); 854 } 855 }; 856 857 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 858 struct TestUndoBlocksMerge : public ConversionPattern { 859 TestUndoBlocksMerge(MLIRContext *ctx) 860 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 861 LogicalResult 862 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 863 ConversionPatternRewriter &rewriter) const final { 864 Block &firstBlock = op->getRegion(0).front(); 865 Operation *branchOp = firstBlock.getTerminator(); 866 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 867 rewriter.setInsertionPointToStart(secondBlock); 868 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 869 auto succOperands = branchOp->getOperands(); 870 SmallVector<Value, 2> replacements(succOperands); 871 rewriter.eraseOp(branchOp); 872 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 873 rewriter.updateRootInPlace(op, [] {}); 874 return success(); 875 } 876 }; 877 878 /// A rewrite mechanism to inline the body of the op into its parent, when both 879 /// ops can have a single block. 880 struct TestMergeSingleBlockOps 881 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 882 using OpConversionPattern< 883 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 884 885 LogicalResult 886 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 887 ConversionPatternRewriter &rewriter) const final { 888 SingleBlockImplicitTerminatorOp parentOp = 889 op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 890 if (!parentOp) 891 return failure(); 892 Block &parentBlock = parentOp.region().front(); 893 Block &innerBlock = op.region().front(); 894 TerminatorOp innerTerminator = 895 cast<TerminatorOp>(innerBlock.getTerminator()); 896 Block *parentPrologue = 897 rewriter.splitBlock(&parentBlock, Block::iterator(op)); 898 rewriter.eraseOp(innerTerminator); 899 rewriter.mergeBlocks(&innerBlock, &parentBlock, {}); 900 rewriter.eraseOp(op); 901 rewriter.mergeBlocks(parentPrologue, &parentBlock, {}); 902 rewriter.updateRootInPlace(op, [] {}); 903 return success(); 904 } 905 }; 906 907 struct TestMergeBlocksPatternDriver 908 : public PassWrapper<TestMergeBlocksPatternDriver, 909 OperationPass<ModuleOp>> { 910 void runOnOperation() override { 911 mlir::OwningRewritePatternList patterns; 912 MLIRContext *context = &getContext(); 913 patterns 914 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 915 context); 916 ConversionTarget target(*context); 917 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 918 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 919 TestReturnOp>(); 920 target.addIllegalOp<ILLegalOpF>(); 921 922 /// Expect the op to have a single block after legalization. 923 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 924 [&](TestMergeBlocksOp op) -> bool { 925 return llvm::hasSingleElement(op.body()); 926 }); 927 928 /// Only allow `test.br` within test.merge_blocks op. 929 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 930 return op.getParentOfType<TestMergeBlocksOp>(); 931 }); 932 933 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 934 /// inlined. 935 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 936 [&](SingleBlockImplicitTerminatorOp op) -> bool { 937 return !op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 938 }); 939 940 DenseSet<Operation *> unlegalizedOps; 941 (void)applyPartialConversion(getOperation(), target, patterns, 942 &unlegalizedOps); 943 for (auto *op : unlegalizedOps) 944 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 945 } 946 }; 947 } // namespace 948 949 //===----------------------------------------------------------------------===// 950 // PassRegistration 951 //===----------------------------------------------------------------------===// 952 953 namespace mlir { 954 void registerPatternsTestPass() { 955 PassRegistration<TestReturnTypeDriver>("test-return-type", 956 "Run return type functions"); 957 958 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 959 "Run test derived attributes"); 960 961 PassRegistration<TestPatternDriver>("test-patterns", 962 "Run test dialect patterns"); 963 964 PassRegistration<TestLegalizePatternDriver>( 965 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 966 return std::make_unique<TestLegalizePatternDriver>( 967 legalizerConversionMode); 968 }); 969 970 PassRegistration<TestRemappedValue>( 971 "test-remapped-value", 972 "Test public remapped value mechanism in ConversionPatternRewriter"); 973 974 PassRegistration<TestUnknownRootOpDriver>( 975 "test-legalize-unknown-root-patterns", 976 "Test public remapped value mechanism in ConversionPatternRewriter"); 977 978 PassRegistration<TestTypeConversionDriver>( 979 "test-legalize-type-conversion", 980 "Test various type conversion functionalities in DialectConversion"); 981 982 PassRegistration<TestMergeBlocksPatternDriver>{ 983 "test-merge-blocks", 984 "Test Merging operation in ConversionPatternRewriter"}; 985 } 986 } // namespace mlir 987