1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 17 using namespace mlir; 18 19 // Native function for testing NativeCodeCall 20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 21 return choice.getValue() ? input1 : input2; 22 } 23 24 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 25 rewriter.create<OpI>(loc, input); 26 } 27 28 static void handleNoResultOp(PatternRewriter &rewriter, 29 OpSymbolBindingNoResult op) { 30 // Turn the no result op to a one-result op. 31 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 32 op.operand()); 33 } 34 35 // Test that natives calls are only called once during rewrites. 36 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 37 // This let us check the number of times OpM_Test was called by inspecting 38 // the returned value in the MLIR output. 39 static int64_t opMIncreasingValue = 314159265; 40 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 41 int64_t i = opMIncreasingValue++; 42 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 43 } 44 45 namespace { 46 #include "TestPatterns.inc" 47 } // end anonymous namespace 48 49 //===----------------------------------------------------------------------===// 50 // Canonicalizer Driver. 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 struct FoldingPattern : public RewritePattern { 55 public: 56 FoldingPattern(MLIRContext *context) 57 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 58 /*benefit=*/1, context) {} 59 60 LogicalResult matchAndRewrite(Operation *op, 61 PatternRewriter &rewriter) const override { 62 // Exercice OperationFolder API for a single-result operation that is folded 63 // upon construction. The operation being created through the folder has an 64 // in-place folder, and it should be still present in the output. 65 // Furthermore, the folder should not crash when attempting to recover the 66 // (unchanged) operation result. 67 OperationFolder folder(op->getContext()); 68 Value result = folder.create<TestOpInPlaceFold>( 69 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 70 rewriter.getI32IntegerAttr(0)); 71 assert(result); 72 rewriter.replaceOp(op, result); 73 return success(); 74 } 75 }; 76 77 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 78 void runOnFunction() override { 79 mlir::OwningRewritePatternList patterns; 80 populateWithGenerated(&getContext(), &patterns); 81 82 // Verify named pattern is generated with expected name. 83 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext()); 84 85 applyPatternsAndFoldGreedily(getFunction(), patterns); 86 } 87 }; 88 } // end anonymous namespace 89 90 //===----------------------------------------------------------------------===// 91 // ReturnType Driver. 92 //===----------------------------------------------------------------------===// 93 94 namespace { 95 // Generate ops for each instance where the type can be successfully inferred. 96 template <typename OpTy> 97 static void invokeCreateWithInferredReturnType(Operation *op) { 98 auto *context = op->getContext(); 99 auto fop = op->getParentOfType<FuncOp>(); 100 auto location = UnknownLoc::get(context); 101 OpBuilder b(op); 102 b.setInsertionPointAfter(op); 103 104 // Use permutations of 2 args as operands. 105 assert(fop.getNumArguments() >= 2); 106 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 107 for (int j = 0; j < e; ++j) { 108 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 109 SmallVector<Type, 2> inferredReturnTypes; 110 if (succeeded(OpTy::inferReturnTypes( 111 context, llvm::None, values, op->getAttrDictionary(), 112 op->getRegions(), inferredReturnTypes))) { 113 OperationState state(location, OpTy::getOperationName()); 114 // TODO: Expand to regions. 115 OpTy::build(b, state, values, op->getAttrs()); 116 (void)b.createOperation(state); 117 } 118 } 119 } 120 } 121 122 static void reifyReturnShape(Operation *op) { 123 OpBuilder b(op); 124 125 // Use permutations of 2 args as operands. 126 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 127 SmallVector<Value, 2> shapes; 128 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 129 return; 130 for (auto it : llvm::enumerate(shapes)) 131 op->emitRemark() << "value " << it.index() << ": " 132 << it.value().getDefiningOp(); 133 } 134 135 struct TestReturnTypeDriver 136 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 137 void runOnFunction() override { 138 if (getFunction().getName() == "testCreateFunctions") { 139 std::vector<Operation *> ops; 140 // Collect ops to avoid triggering on inserted ops. 141 for (auto &op : getFunction().getBody().front()) 142 ops.push_back(&op); 143 // Generate test patterns for each, but skip terminator. 144 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 145 // Test create method of each of the Op classes below. The resultant 146 // output would be in reverse order underneath `op` from which 147 // the attributes and regions are used. 148 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 149 invokeCreateWithInferredReturnType< 150 OpWithShapedTypeInferTypeInterfaceOp>(op); 151 }; 152 return; 153 } 154 if (getFunction().getName() == "testReifyFunctions") { 155 std::vector<Operation *> ops; 156 // Collect ops to avoid triggering on inserted ops. 157 for (auto &op : getFunction().getBody().front()) 158 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 159 ops.push_back(&op); 160 // Generate test patterns for each, but skip terminator. 161 for (auto *op : ops) 162 reifyReturnShape(op); 163 } 164 } 165 }; 166 } // end anonymous namespace 167 168 namespace { 169 struct TestDerivedAttributeDriver 170 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 171 void runOnFunction() override; 172 }; 173 } // end anonymous namespace 174 175 void TestDerivedAttributeDriver::runOnFunction() { 176 getFunction().walk([](DerivedAttributeOpInterface dOp) { 177 auto dAttr = dOp.materializeDerivedAttributes(); 178 if (!dAttr) 179 return; 180 for (auto d : dAttr) 181 dOp.emitRemark() << d.first << " = " << d.second; 182 }); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // Legalization Driver. 187 //===----------------------------------------------------------------------===// 188 189 namespace { 190 //===----------------------------------------------------------------------===// 191 // Region-Block Rewrite Testing 192 193 /// This pattern is a simple pattern that inlines the first region of a given 194 /// operation into the parent region. 195 struct TestRegionRewriteBlockMovement : public ConversionPattern { 196 TestRegionRewriteBlockMovement(MLIRContext *ctx) 197 : ConversionPattern("test.region", 1, ctx) {} 198 199 LogicalResult 200 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 201 ConversionPatternRewriter &rewriter) const final { 202 // Inline this region into the parent region. 203 auto &parentRegion = *op->getParentRegion(); 204 if (op->getAttr("legalizer.should_clone")) 205 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 206 parentRegion.end()); 207 else 208 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 209 parentRegion.end()); 210 211 // Drop this operation. 212 rewriter.eraseOp(op); 213 return success(); 214 } 215 }; 216 /// This pattern is a simple pattern that generates a region containing an 217 /// illegal operation. 218 struct TestRegionRewriteUndo : public RewritePattern { 219 TestRegionRewriteUndo(MLIRContext *ctx) 220 : RewritePattern("test.region_builder", 1, ctx) {} 221 222 LogicalResult matchAndRewrite(Operation *op, 223 PatternRewriter &rewriter) const final { 224 // Create the region operation with an entry block containing arguments. 225 OperationState newRegion(op->getLoc(), "test.region"); 226 newRegion.addRegion(); 227 auto *regionOp = rewriter.createOperation(newRegion); 228 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 229 entryBlock->addArgument(rewriter.getIntegerType(64)); 230 231 // Add an explicitly illegal operation to ensure the conversion fails. 232 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 233 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 234 235 // Drop this operation. 236 rewriter.eraseOp(op); 237 return success(); 238 } 239 }; 240 /// A simple pattern that creates a block at the end of the parent region of the 241 /// matched operation. 242 struct TestCreateBlock : public RewritePattern { 243 TestCreateBlock(MLIRContext *ctx) 244 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 245 246 LogicalResult matchAndRewrite(Operation *op, 247 PatternRewriter &rewriter) const final { 248 Region ®ion = *op->getParentRegion(); 249 Type i32Type = rewriter.getIntegerType(32); 250 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 251 rewriter.create<TerminatorOp>(op->getLoc()); 252 rewriter.replaceOp(op, {}); 253 return success(); 254 } 255 }; 256 257 /// A simple pattern that creates a block containing an invalid operation in 258 /// order to trigger the block creation undo mechanism. 259 struct TestCreateIllegalBlock : public RewritePattern { 260 TestCreateIllegalBlock(MLIRContext *ctx) 261 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 262 263 LogicalResult matchAndRewrite(Operation *op, 264 PatternRewriter &rewriter) const final { 265 Region ®ion = *op->getParentRegion(); 266 Type i32Type = rewriter.getIntegerType(32); 267 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 268 // Create an illegal op to ensure the conversion fails. 269 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 270 rewriter.create<TerminatorOp>(op->getLoc()); 271 rewriter.replaceOp(op, {}); 272 return success(); 273 } 274 }; 275 276 /// A simple pattern that tests the undo mechanism when replacing the uses of a 277 /// block argument. 278 struct TestUndoBlockArgReplace : public ConversionPattern { 279 TestUndoBlockArgReplace(MLIRContext *ctx) 280 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 281 282 LogicalResult 283 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 284 ConversionPatternRewriter &rewriter) const final { 285 auto illegalOp = 286 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 287 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 288 illegalOp); 289 rewriter.updateRootInPlace(op, [] {}); 290 return success(); 291 } 292 }; 293 294 /// A rewrite pattern that tests the undo mechanism when erasing a block. 295 struct TestUndoBlockErase : public ConversionPattern { 296 TestUndoBlockErase(MLIRContext *ctx) 297 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 298 299 LogicalResult 300 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 301 ConversionPatternRewriter &rewriter) const final { 302 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 303 rewriter.setInsertionPointToStart(secondBlock); 304 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 305 rewriter.eraseBlock(secondBlock); 306 rewriter.updateRootInPlace(op, [] {}); 307 return success(); 308 } 309 }; 310 311 //===----------------------------------------------------------------------===// 312 // Type-Conversion Rewrite Testing 313 314 /// This patterns erases a region operation that has had a type conversion. 315 struct TestDropOpSignatureConversion : public ConversionPattern { 316 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 317 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} 318 LogicalResult 319 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 320 ConversionPatternRewriter &rewriter) const override { 321 Region ®ion = op->getRegion(0); 322 Block *entry = ®ion.front(); 323 324 // Convert the original entry arguments. 325 TypeConverter &converter = *getTypeConverter(); 326 TypeConverter::SignatureConversion result(entry->getNumArguments()); 327 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 328 result)) || 329 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 330 return failure(); 331 332 // Convert the region signature and just drop the operation. 333 rewriter.eraseOp(op); 334 return success(); 335 } 336 }; 337 /// This pattern simply updates the operands of the given operation. 338 struct TestPassthroughInvalidOp : public ConversionPattern { 339 TestPassthroughInvalidOp(MLIRContext *ctx) 340 : ConversionPattern("test.invalid", 1, ctx) {} 341 LogicalResult 342 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 343 ConversionPatternRewriter &rewriter) const final { 344 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 345 llvm::None); 346 return success(); 347 } 348 }; 349 /// This pattern handles the case of a split return value. 350 struct TestSplitReturnType : public ConversionPattern { 351 TestSplitReturnType(MLIRContext *ctx) 352 : ConversionPattern("test.return", 1, ctx) {} 353 LogicalResult 354 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 355 ConversionPatternRewriter &rewriter) const final { 356 // Check for a return of F32. 357 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 358 return failure(); 359 360 // Check if the first operation is a cast operation, if it is we use the 361 // results directly. 362 auto *defOp = operands[0].getDefiningOp(); 363 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 364 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 365 return success(); 366 } 367 368 // Otherwise, fail to match. 369 return failure(); 370 } 371 }; 372 373 //===----------------------------------------------------------------------===// 374 // Multi-Level Type-Conversion Rewrite Testing 375 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 376 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 377 : ConversionPattern("test.type_producer", 1, ctx) {} 378 LogicalResult 379 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 380 ConversionPatternRewriter &rewriter) const final { 381 // If the type is I32, change the type to F32. 382 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 383 return failure(); 384 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 385 return success(); 386 } 387 }; 388 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 389 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 390 : ConversionPattern("test.type_producer", 1, ctx) {} 391 LogicalResult 392 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 393 ConversionPatternRewriter &rewriter) const final { 394 // If the type is F32, change the type to F64. 395 if (!Type(*op->result_type_begin()).isF32()) 396 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 397 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 398 return success(); 399 } 400 }; 401 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 402 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 403 : ConversionPattern("test.type_producer", 10, ctx) {} 404 LogicalResult 405 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 406 ConversionPatternRewriter &rewriter) const final { 407 // Always convert to B16, even though it is not a legal type. This tests 408 // that values are unmapped correctly. 409 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 410 return success(); 411 } 412 }; 413 struct TestUpdateConsumerType : public ConversionPattern { 414 TestUpdateConsumerType(MLIRContext *ctx) 415 : ConversionPattern("test.type_consumer", 1, ctx) {} 416 LogicalResult 417 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 418 ConversionPatternRewriter &rewriter) const final { 419 // Verify that the incoming operand has been successfully remapped to F64. 420 if (!operands[0].getType().isF64()) 421 return failure(); 422 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 423 return success(); 424 } 425 }; 426 427 //===----------------------------------------------------------------------===// 428 // Non-Root Replacement Rewrite Testing 429 /// This pattern generates an invalid operation, but replaces it before the 430 /// pattern is finished. This checks that we don't need to legalize the 431 /// temporary op. 432 struct TestNonRootReplacement : public RewritePattern { 433 TestNonRootReplacement(MLIRContext *ctx) 434 : RewritePattern("test.replace_non_root", 1, ctx) {} 435 436 LogicalResult matchAndRewrite(Operation *op, 437 PatternRewriter &rewriter) const final { 438 auto resultType = *op->result_type_begin(); 439 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 440 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 441 442 rewriter.replaceOp(illegalOp, {legalOp}); 443 rewriter.replaceOp(op, {illegalOp}); 444 return success(); 445 } 446 }; 447 448 //===----------------------------------------------------------------------===// 449 // Recursive Rewrite Testing 450 /// This pattern is applied to the same operation multiple times, but has a 451 /// bounded recursion. 452 struct TestBoundedRecursiveRewrite 453 : public OpRewritePattern<TestRecursiveRewriteOp> { 454 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 455 456 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 457 PatternRewriter &rewriter) const final { 458 // Decrement the depth of the op in-place. 459 rewriter.updateRootInPlace(op, [&] { 460 op.setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 461 }); 462 return success(); 463 } 464 465 /// The conversion target handles bounding the recursion of this pattern. 466 bool hasBoundedRewriteRecursion() const final { return true; } 467 }; 468 469 struct TestNestedOpCreationUndoRewrite 470 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 471 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 472 473 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 474 PatternRewriter &rewriter) const final { 475 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 476 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 477 return success(); 478 }; 479 }; 480 } // namespace 481 482 namespace { 483 struct TestTypeConverter : public TypeConverter { 484 using TypeConverter::TypeConverter; 485 TestTypeConverter() { 486 addConversion(convertType); 487 addArgumentMaterialization(materializeCast); 488 addArgumentMaterialization(materializeOneToOneCast); 489 addSourceMaterialization(materializeCast); 490 } 491 492 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 493 // Drop I16 types. 494 if (t.isSignlessInteger(16)) 495 return success(); 496 497 // Convert I64 to F64. 498 if (t.isSignlessInteger(64)) { 499 results.push_back(FloatType::getF64(t.getContext())); 500 return success(); 501 } 502 503 // Convert I42 to I43. 504 if (t.isInteger(42)) { 505 results.push_back(IntegerType::get(43, t.getContext())); 506 return success(); 507 } 508 509 // Split F32 into F16,F16. 510 if (t.isF32()) { 511 results.assign(2, FloatType::getF16(t.getContext())); 512 return success(); 513 } 514 515 // Otherwise, convert the type directly. 516 results.push_back(t); 517 return success(); 518 } 519 520 /// Hook for materializing a conversion. This is necessary because we generate 521 /// 1->N type mappings. 522 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 523 ValueRange inputs, Location loc) { 524 if (inputs.size() == 1) 525 return inputs[0]; 526 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 527 } 528 529 /// Materialize the cast for one-to-one conversion from i64 to f64. 530 static Optional<Value> materializeOneToOneCast(OpBuilder &builder, 531 IntegerType resultType, 532 ValueRange inputs, 533 Location loc) { 534 if (resultType.getWidth() == 42 && inputs.size() == 1) 535 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 536 return llvm::None; 537 } 538 }; 539 540 struct TestLegalizePatternDriver 541 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 542 /// The mode of conversion to use with the driver. 543 enum class ConversionMode { Analysis, Full, Partial }; 544 545 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 546 547 void runOnOperation() override { 548 TestTypeConverter converter; 549 mlir::OwningRewritePatternList patterns; 550 populateWithGenerated(&getContext(), &patterns); 551 patterns.insert< 552 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, 553 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, 554 TestPassthroughInvalidOp, TestSplitReturnType, 555 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 556 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 557 TestNonRootReplacement, TestBoundedRecursiveRewrite, 558 TestNestedOpCreationUndoRewrite>(&getContext()); 559 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 560 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 561 converter); 562 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 563 converter); 564 565 // Define the conversion target used for the test. 566 ConversionTarget target(getContext()); 567 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 568 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 569 TerminatorOp>(); 570 target 571 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 572 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 573 // Don't allow F32 operands. 574 return llvm::none_of(op.getOperandTypes(), 575 [](Type type) { return type.isF32(); }); 576 }); 577 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 578 return converter.isSignatureLegal(op.getType()) && 579 converter.isLegal(&op.getBody()); 580 }); 581 582 // Expect the type_producer/type_consumer operations to only operate on f64. 583 target.addDynamicallyLegalOp<TestTypeProducerOp>( 584 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 585 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 586 return op.getOperand().getType().isF64(); 587 }); 588 589 // Check support for marking certain operations as recursively legal. 590 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 591 return static_cast<bool>( 592 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 593 }); 594 595 // Mark the bound recursion operation as dynamically legal. 596 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 597 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 598 599 // Handle a partial conversion. 600 if (mode == ConversionMode::Partial) { 601 DenseSet<Operation *> unlegalizedOps; 602 (void)applyPartialConversion(getOperation(), target, patterns, 603 &unlegalizedOps); 604 // Emit remarks for each legalizable operation. 605 for (auto *op : unlegalizedOps) 606 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 607 return; 608 } 609 610 // Handle a full conversion. 611 if (mode == ConversionMode::Full) { 612 // Check support for marking unknown operations as dynamically legal. 613 target.markUnknownOpDynamicallyLegal([](Operation *op) { 614 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 615 }); 616 617 (void)applyFullConversion(getOperation(), target, patterns); 618 return; 619 } 620 621 // Otherwise, handle an analysis conversion. 622 assert(mode == ConversionMode::Analysis); 623 624 // Analyze the convertible operations. 625 DenseSet<Operation *> legalizedOps; 626 if (failed(applyAnalysisConversion(getOperation(), target, patterns, 627 legalizedOps))) 628 return signalPassFailure(); 629 630 // Emit remarks for each legalizable operation. 631 for (auto *op : legalizedOps) 632 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 633 } 634 635 /// The mode of conversion to use. 636 ConversionMode mode; 637 }; 638 } // end anonymous namespace 639 640 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 641 legalizerConversionMode( 642 "test-legalize-mode", 643 llvm::cl::desc("The legalization mode to use with the test driver"), 644 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 645 llvm::cl::values( 646 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 647 "analysis", "Perform an analysis conversion"), 648 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 649 "Perform a full conversion"), 650 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 651 "partial", "Perform a partial conversion"))); 652 653 //===----------------------------------------------------------------------===// 654 // ConversionPatternRewriter::getRemappedValue testing. This method is used 655 // to get the remapped value of an original value that was replaced using 656 // ConversionPatternRewriter. 657 namespace { 658 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 659 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 660 /// operand twice. 661 /// 662 /// Example: 663 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 664 /// is replaced with: 665 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 666 struct OneVResOneVOperandOp1Converter 667 : public OpConversionPattern<OneVResOneVOperandOp1> { 668 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 669 670 LogicalResult 671 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 672 ConversionPatternRewriter &rewriter) const override { 673 auto origOps = op.getOperands(); 674 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 675 "One operand expected"); 676 Value origOp = *origOps.begin(); 677 SmallVector<Value, 2> remappedOperands; 678 // Replicate the remapped original operand twice. Note that we don't used 679 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 680 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 681 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 682 683 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 684 remappedOperands); 685 return success(); 686 } 687 }; 688 689 struct TestRemappedValue 690 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 691 void runOnFunction() override { 692 mlir::OwningRewritePatternList patterns; 693 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 694 695 mlir::ConversionTarget target(getContext()); 696 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 697 // We make OneVResOneVOperandOp1 legal only when it has more that one 698 // operand. This will trigger the conversion that will replace one-operand 699 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 700 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 701 [](Operation *op) -> bool { 702 return std::distance(op->operand_begin(), op->operand_end()) > 1; 703 }); 704 705 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 706 signalPassFailure(); 707 } 708 } 709 }; 710 } // end anonymous namespace 711 712 //===----------------------------------------------------------------------===// 713 // Test patterns without a specific root operation kind 714 //===----------------------------------------------------------------------===// 715 716 namespace { 717 /// This pattern matches and removes any operation in the test dialect. 718 struct RemoveTestDialectOps : public RewritePattern { 719 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 720 721 LogicalResult matchAndRewrite(Operation *op, 722 PatternRewriter &rewriter) const override { 723 if (!isa<TestDialect>(op->getDialect())) 724 return failure(); 725 rewriter.eraseOp(op); 726 return success(); 727 } 728 }; 729 730 struct TestUnknownRootOpDriver 731 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 732 void runOnFunction() override { 733 mlir::OwningRewritePatternList patterns; 734 patterns.insert<RemoveTestDialectOps>(); 735 736 mlir::ConversionTarget target(getContext()); 737 target.addIllegalDialect<TestDialect>(); 738 if (failed(applyPartialConversion(getFunction(), target, patterns))) 739 signalPassFailure(); 740 } 741 }; 742 } // end anonymous namespace 743 744 //===----------------------------------------------------------------------===// 745 // Test type conversions 746 //===----------------------------------------------------------------------===// 747 748 namespace { 749 struct TestTypeConversionProducer 750 : public OpConversionPattern<TestTypeProducerOp> { 751 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 752 LogicalResult 753 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 754 ConversionPatternRewriter &rewriter) const final { 755 Type resultType = op.getType(); 756 if (resultType.isa<FloatType>()) 757 resultType = rewriter.getF64Type(); 758 else if (resultType.isInteger(16)) 759 resultType = rewriter.getIntegerType(64); 760 else 761 return failure(); 762 763 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 764 return success(); 765 } 766 }; 767 768 struct TestTypeConversionDriver 769 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 770 void getDependentDialects(DialectRegistry ®istry) const override { 771 registry.insert<TestDialect>(); 772 } 773 774 void runOnOperation() override { 775 // Initialize the type converter. 776 TypeConverter converter; 777 778 /// Add the legal set of type conversions. 779 converter.addConversion([](Type type) -> Type { 780 // Treat F64 as legal. 781 if (type.isF64()) 782 return type; 783 // Allow converting BF16/F16/F32 to F64. 784 if (type.isBF16() || type.isF16() || type.isF32()) 785 return FloatType::getF64(type.getContext()); 786 // Otherwise, the type is illegal. 787 return nullptr; 788 }); 789 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 790 // Drop all integer types. 791 return success(); 792 }); 793 794 /// Add the legal set of type materializations. 795 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 796 ValueRange inputs, 797 Location loc) -> Value { 798 // Allow casting from F64 back to F32. 799 if (!resultType.isF16() && inputs.size() == 1 && 800 inputs[0].getType().isF64()) 801 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 802 // Allow producing an i32 or i64 from nothing. 803 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 804 inputs.empty()) 805 return builder.create<TestTypeProducerOp>(loc, resultType); 806 // Allow producing an i64 from an integer. 807 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 808 inputs[0].getType().isa<IntegerType>()) 809 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 810 // Otherwise, fail. 811 return nullptr; 812 }); 813 814 // Initialize the conversion target. 815 mlir::ConversionTarget target(getContext()); 816 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 817 return op.getType().isF64() || op.getType().isInteger(64); 818 }); 819 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 820 return converter.isSignatureLegal(op.getType()) && 821 converter.isLegal(&op.getBody()); 822 }); 823 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 824 // Allow casts from F64 to F32. 825 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 826 }); 827 828 // Initialize the set of rewrite patterns. 829 OwningRewritePatternList patterns; 830 patterns.insert<TestTypeConversionProducer>(converter, &getContext()); 831 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 832 converter); 833 834 if (failed(applyPartialConversion(getOperation(), target, patterns))) 835 signalPassFailure(); 836 } 837 }; 838 } // end anonymous namespace 839 840 namespace { 841 /// A rewriter pattern that tests that blocks can be merged. 842 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 843 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 844 845 LogicalResult 846 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 847 ConversionPatternRewriter &rewriter) const final { 848 Block &firstBlock = op.body().front(); 849 Operation *branchOp = firstBlock.getTerminator(); 850 Block *secondBlock = &*(std::next(op.body().begin())); 851 auto succOperands = branchOp->getOperands(); 852 SmallVector<Value, 2> replacements(succOperands); 853 rewriter.eraseOp(branchOp); 854 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 855 rewriter.updateRootInPlace(op, [] {}); 856 return success(); 857 } 858 }; 859 860 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 861 struct TestUndoBlocksMerge : public ConversionPattern { 862 TestUndoBlocksMerge(MLIRContext *ctx) 863 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 864 LogicalResult 865 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 866 ConversionPatternRewriter &rewriter) const final { 867 Block &firstBlock = op->getRegion(0).front(); 868 Operation *branchOp = firstBlock.getTerminator(); 869 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 870 rewriter.setInsertionPointToStart(secondBlock); 871 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 872 auto succOperands = branchOp->getOperands(); 873 SmallVector<Value, 2> replacements(succOperands); 874 rewriter.eraseOp(branchOp); 875 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 876 rewriter.updateRootInPlace(op, [] {}); 877 return success(); 878 } 879 }; 880 881 /// A rewrite mechanism to inline the body of the op into its parent, when both 882 /// ops can have a single block. 883 struct TestMergeSingleBlockOps 884 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 885 using OpConversionPattern< 886 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 887 888 LogicalResult 889 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 890 ConversionPatternRewriter &rewriter) const final { 891 SingleBlockImplicitTerminatorOp parentOp = 892 op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 893 if (!parentOp) 894 return failure(); 895 Block &innerBlock = op.region().front(); 896 TerminatorOp innerTerminator = 897 cast<TerminatorOp>(innerBlock.getTerminator()); 898 rewriter.mergeBlockBefore(&innerBlock, op); 899 rewriter.eraseOp(innerTerminator); 900 rewriter.eraseOp(op); 901 rewriter.updateRootInPlace(op, [] {}); 902 return success(); 903 } 904 }; 905 906 struct TestMergeBlocksPatternDriver 907 : public PassWrapper<TestMergeBlocksPatternDriver, 908 OperationPass<ModuleOp>> { 909 void runOnOperation() override { 910 mlir::OwningRewritePatternList patterns; 911 MLIRContext *context = &getContext(); 912 patterns 913 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 914 context); 915 ConversionTarget target(*context); 916 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp, 917 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp, 918 TestReturnOp>(); 919 target.addIllegalOp<ILLegalOpF>(); 920 921 /// Expect the op to have a single block after legalization. 922 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 923 [&](TestMergeBlocksOp op) -> bool { 924 return llvm::hasSingleElement(op.body()); 925 }); 926 927 /// Only allow `test.br` within test.merge_blocks op. 928 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 929 return op.getParentOfType<TestMergeBlocksOp>(); 930 }); 931 932 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 933 /// inlined. 934 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 935 [&](SingleBlockImplicitTerminatorOp op) -> bool { 936 return !op.getParentOfType<SingleBlockImplicitTerminatorOp>(); 937 }); 938 939 DenseSet<Operation *> unlegalizedOps; 940 (void)applyPartialConversion(getOperation(), target, patterns, 941 &unlegalizedOps); 942 for (auto *op : unlegalizedOps) 943 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 944 } 945 }; 946 } // namespace 947 948 //===----------------------------------------------------------------------===// 949 // PassRegistration 950 //===----------------------------------------------------------------------===// 951 952 namespace mlir { 953 void registerPatternsTestPass() { 954 PassRegistration<TestReturnTypeDriver>("test-return-type", 955 "Run return type functions"); 956 957 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr", 958 "Run test derived attributes"); 959 960 PassRegistration<TestPatternDriver>("test-patterns", 961 "Run test dialect patterns"); 962 963 PassRegistration<TestLegalizePatternDriver>( 964 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 965 return std::make_unique<TestLegalizePatternDriver>( 966 legalizerConversionMode); 967 }); 968 969 PassRegistration<TestRemappedValue>( 970 "test-remapped-value", 971 "Test public remapped value mechanism in ConversionPatternRewriter"); 972 973 PassRegistration<TestUnknownRootOpDriver>( 974 "test-legalize-unknown-root-patterns", 975 "Test public remapped value mechanism in ConversionPatternRewriter"); 976 977 PassRegistration<TestTypeConversionDriver>( 978 "test-legalize-type-conversion", 979 "Test various type conversion functionalities in DialectConversion"); 980 981 PassRegistration<TestMergeBlocksPatternDriver>{ 982 "test-merge-blocks", 983 "Test Merging operation in ConversionPatternRewriter"}; 984 } 985 } // namespace mlir 986