1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Conversion/StandardToStandard/StandardToStandard.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "mlir/Pass/Pass.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 using namespace mlir; 15 16 // Native function for testing NativeCodeCall 17 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 18 return choice.getValue() ? input1 : input2; 19 } 20 21 static void createOpI(PatternRewriter &rewriter, Value input) { 22 rewriter.create<OpI>(rewriter.getUnknownLoc(), input); 23 } 24 25 static void handleNoResultOp(PatternRewriter &rewriter, 26 OpSymbolBindingNoResult op) { 27 // Turn the no result op to a one-result op. 28 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 29 op.operand()); 30 } 31 32 namespace { 33 #include "TestPatterns.inc" 34 } // end anonymous namespace 35 36 //===----------------------------------------------------------------------===// 37 // Canonicalizer Driver. 38 //===----------------------------------------------------------------------===// 39 40 namespace { 41 struct TestPatternDriver : public FunctionPass<TestPatternDriver> { 42 void runOnFunction() override { 43 mlir::OwningRewritePatternList patterns; 44 populateWithGenerated(&getContext(), &patterns); 45 46 // Verify named pattern is generated with expected name. 47 patterns.insert<TestNamedPatternRule>(&getContext()); 48 49 applyPatternsGreedily(getFunction(), patterns); 50 } 51 }; 52 } // end anonymous namespace 53 54 //===----------------------------------------------------------------------===// 55 // ReturnType Driver. 56 //===----------------------------------------------------------------------===// 57 58 namespace { 59 // Generate ops for each instance where the type can be successfully inferred. 60 template <typename OpTy> 61 static void invokeCreateWithInferredReturnType(Operation *op) { 62 auto *context = op->getContext(); 63 auto fop = op->getParentOfType<FuncOp>(); 64 auto location = UnknownLoc::get(context); 65 OpBuilder b(op); 66 b.setInsertionPointAfter(op); 67 68 // Use permutations of 2 args as operands. 69 assert(fop.getNumArguments() >= 2); 70 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 71 for (int j = 0; j < e; ++j) { 72 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 73 SmallVector<Type, 2> inferredReturnTypes; 74 if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, 75 op->getAttrs(), op->getRegions(), 76 inferredReturnTypes))) { 77 OperationState state(location, OpTy::getOperationName()); 78 // TODO(jpienaar): Expand to regions. 79 OpTy::build(&b, state, values, op->getAttrs()); 80 (void)b.createOperation(state); 81 } 82 } 83 } 84 } 85 86 static void reifyReturnShape(Operation *op) { 87 OpBuilder b(op); 88 89 // Use permutations of 2 args as operands. 90 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 91 SmallVector<Value, 2> shapes; 92 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) 93 return; 94 for (auto it : llvm::enumerate(shapes)) 95 op->emitRemark() << "value " << it.index() << ": " 96 << it.value().getDefiningOp(); 97 } 98 99 struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> { 100 void runOnFunction() override { 101 if (getFunction().getName() == "testCreateFunctions") { 102 std::vector<Operation *> ops; 103 // Collect ops to avoid triggering on inserted ops. 104 for (auto &op : getFunction().getBody().front()) 105 ops.push_back(&op); 106 // Generate test patterns for each, but skip terminator. 107 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 108 // Test create method of each of the Op classes below. The resultant 109 // output would be in reverse order underneath `op` from which 110 // the attributes and regions are used. 111 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 112 invokeCreateWithInferredReturnType< 113 OpWithShapedTypeInferTypeInterfaceOp>(op); 114 }; 115 return; 116 } 117 if (getFunction().getName() == "testReifyFunctions") { 118 std::vector<Operation *> ops; 119 // Collect ops to avoid triggering on inserted ops. 120 for (auto &op : getFunction().getBody().front()) 121 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 122 ops.push_back(&op); 123 // Generate test patterns for each, but skip terminator. 124 for (auto *op : ops) 125 reifyReturnShape(op); 126 } 127 } 128 }; 129 } // end anonymous namespace 130 131 //===----------------------------------------------------------------------===// 132 // Legalization Driver. 133 //===----------------------------------------------------------------------===// 134 135 namespace { 136 //===----------------------------------------------------------------------===// 137 // Region-Block Rewrite Testing 138 139 /// This pattern is a simple pattern that inlines the first region of a given 140 /// operation into the parent region. 141 struct TestRegionRewriteBlockMovement : public ConversionPattern { 142 TestRegionRewriteBlockMovement(MLIRContext *ctx) 143 : ConversionPattern("test.region", 1, ctx) {} 144 145 LogicalResult 146 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 147 ConversionPatternRewriter &rewriter) const final { 148 // Inline this region into the parent region. 149 auto &parentRegion = *op->getParentRegion(); 150 if (op->getAttr("legalizer.should_clone")) 151 rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, 152 parentRegion.end()); 153 else 154 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 155 parentRegion.end()); 156 157 // Drop this operation. 158 rewriter.eraseOp(op); 159 return success(); 160 } 161 }; 162 /// This pattern is a simple pattern that generates a region containing an 163 /// illegal operation. 164 struct TestRegionRewriteUndo : public RewritePattern { 165 TestRegionRewriteUndo(MLIRContext *ctx) 166 : RewritePattern("test.region_builder", 1, ctx) {} 167 168 LogicalResult matchAndRewrite(Operation *op, 169 PatternRewriter &rewriter) const final { 170 // Create the region operation with an entry block containing arguments. 171 OperationState newRegion(op->getLoc(), "test.region"); 172 newRegion.addRegion(); 173 auto *regionOp = rewriter.createOperation(newRegion); 174 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 175 entryBlock->addArgument(rewriter.getIntegerType(64)); 176 177 // Add an explicitly illegal operation to ensure the conversion fails. 178 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 179 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 180 181 // Drop this operation. 182 rewriter.eraseOp(op); 183 return success(); 184 } 185 }; 186 187 //===----------------------------------------------------------------------===// 188 // Type-Conversion Rewrite Testing 189 190 /// This patterns erases a region operation that has had a type conversion. 191 struct TestDropOpSignatureConversion : public ConversionPattern { 192 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 193 : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { 194 } 195 LogicalResult 196 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 197 ConversionPatternRewriter &rewriter) const override { 198 Region ®ion = op->getRegion(0); 199 Block *entry = ®ion.front(); 200 201 // Convert the original entry arguments. 202 TypeConverter::SignatureConversion result(entry->getNumArguments()); 203 for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) 204 if (failed(converter.convertSignatureArg( 205 i, entry->getArgument(i).getType(), result))) 206 return failure(); 207 208 // Convert the region signature and just drop the operation. 209 rewriter.applySignatureConversion(®ion, result); 210 rewriter.eraseOp(op); 211 return success(); 212 } 213 214 /// The type converter to use when rewriting the signature. 215 TypeConverter &converter; 216 }; 217 /// This pattern simply updates the operands of the given operation. 218 struct TestPassthroughInvalidOp : public ConversionPattern { 219 TestPassthroughInvalidOp(MLIRContext *ctx) 220 : ConversionPattern("test.invalid", 1, ctx) {} 221 LogicalResult 222 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 223 ConversionPatternRewriter &rewriter) const final { 224 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 225 llvm::None); 226 return success(); 227 } 228 }; 229 /// This pattern handles the case of a split return value. 230 struct TestSplitReturnType : public ConversionPattern { 231 TestSplitReturnType(MLIRContext *ctx) 232 : ConversionPattern("test.return", 1, ctx) {} 233 LogicalResult 234 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 235 ConversionPatternRewriter &rewriter) const final { 236 // Check for a return of F32. 237 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 238 return failure(); 239 240 // Check if the first operation is a cast operation, if it is we use the 241 // results directly. 242 auto *defOp = operands[0].getDefiningOp(); 243 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 244 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 245 return success(); 246 } 247 248 // Otherwise, fail to match. 249 return failure(); 250 } 251 }; 252 253 //===----------------------------------------------------------------------===// 254 // Multi-Level Type-Conversion Rewrite Testing 255 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 256 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 257 : ConversionPattern("test.type_producer", 1, ctx) {} 258 LogicalResult 259 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 260 ConversionPatternRewriter &rewriter) const final { 261 // If the type is I32, change the type to F32. 262 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 263 return failure(); 264 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 265 return success(); 266 } 267 }; 268 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 269 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 270 : ConversionPattern("test.type_producer", 1, ctx) {} 271 LogicalResult 272 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 273 ConversionPatternRewriter &rewriter) const final { 274 // If the type is F32, change the type to F64. 275 if (!Type(*op->result_type_begin()).isF32()) 276 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 277 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 278 return success(); 279 } 280 }; 281 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 282 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 283 : ConversionPattern("test.type_producer", 10, ctx) {} 284 LogicalResult 285 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 286 ConversionPatternRewriter &rewriter) const final { 287 // Always convert to B16, even though it is not a legal type. This tests 288 // that values are unmapped correctly. 289 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 290 return success(); 291 } 292 }; 293 struct TestUpdateConsumerType : public ConversionPattern { 294 TestUpdateConsumerType(MLIRContext *ctx) 295 : ConversionPattern("test.type_consumer", 1, ctx) {} 296 LogicalResult 297 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 298 ConversionPatternRewriter &rewriter) const final { 299 // Verify that the incoming operand has been successfully remapped to F64. 300 if (!operands[0].getType().isF64()) 301 return failure(); 302 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 303 return success(); 304 } 305 }; 306 307 //===----------------------------------------------------------------------===// 308 // Non-Root Replacement Rewrite Testing 309 /// This pattern generates an invalid operation, but replaces it before the 310 /// pattern is finished. This checks that we don't need to legalize the 311 /// temporary op. 312 struct TestNonRootReplacement : public RewritePattern { 313 TestNonRootReplacement(MLIRContext *ctx) 314 : RewritePattern("test.replace_non_root", 1, ctx) {} 315 316 LogicalResult matchAndRewrite(Operation *op, 317 PatternRewriter &rewriter) const final { 318 auto resultType = *op->result_type_begin(); 319 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 320 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 321 322 rewriter.replaceOp(illegalOp, {legalOp}); 323 rewriter.replaceOp(op, {illegalOp}); 324 return success(); 325 } 326 }; 327 } // namespace 328 329 namespace { 330 struct TestTypeConverter : public TypeConverter { 331 using TypeConverter::TypeConverter; 332 TestTypeConverter() { addConversion(convertType); } 333 334 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 335 // Drop I16 types. 336 if (t.isSignlessInteger(16)) 337 return success(); 338 339 // Convert I64 to F64. 340 if (t.isSignlessInteger(64)) { 341 results.push_back(FloatType::getF64(t.getContext())); 342 return success(); 343 } 344 345 // Split F32 into F16,F16. 346 if (t.isF32()) { 347 results.assign(2, FloatType::getF16(t.getContext())); 348 return success(); 349 } 350 351 // Otherwise, convert the type directly. 352 results.push_back(t); 353 return success(); 354 } 355 356 /// Override the hook to materialize a conversion. This is necessary because 357 /// we generate 1->N type mappings. 358 Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, 359 ArrayRef<Value> inputs, 360 Location loc) override { 361 return rewriter.create<TestCastOp>(loc, resultType, inputs); 362 } 363 }; 364 365 struct TestLegalizePatternDriver 366 : public ModulePass<TestLegalizePatternDriver> { 367 /// The mode of conversion to use with the driver. 368 enum class ConversionMode { Analysis, Full, Partial }; 369 370 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 371 372 void runOnModule() override { 373 TestTypeConverter converter; 374 mlir::OwningRewritePatternList patterns; 375 populateWithGenerated(&getContext(), &patterns); 376 patterns 377 .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 378 TestPassthroughInvalidOp, TestSplitReturnType, 379 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 380 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 381 TestNonRootReplacement>(&getContext()); 382 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); 383 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 384 converter); 385 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), 386 converter); 387 388 // Define the conversion target used for the test. 389 ConversionTarget target(getContext()); 390 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 391 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>(); 392 target 393 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 394 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 395 // Don't allow F32 operands. 396 return llvm::none_of(op.getOperandTypes(), 397 [](Type type) { return type.isF32(); }); 398 }); 399 target.addDynamicallyLegalOp<FuncOp>( 400 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 401 402 // Expect the type_producer/type_consumer operations to only operate on f64. 403 target.addDynamicallyLegalOp<TestTypeProducerOp>( 404 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 405 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 406 return op.getOperand().getType().isF64(); 407 }); 408 409 // Check support for marking certain operations as recursively legal. 410 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 411 return static_cast<bool>( 412 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 413 }); 414 415 // Handle a partial conversion. 416 if (mode == ConversionMode::Partial) { 417 (void)applyPartialConversion(getModule(), target, patterns, &converter); 418 return; 419 } 420 421 // Handle a full conversion. 422 if (mode == ConversionMode::Full) { 423 // Check support for marking unknown operations as dynamically legal. 424 target.markUnknownOpDynamicallyLegal([](Operation *op) { 425 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 426 }); 427 428 (void)applyFullConversion(getModule(), target, patterns, &converter); 429 return; 430 } 431 432 // Otherwise, handle an analysis conversion. 433 assert(mode == ConversionMode::Analysis); 434 435 // Analyze the convertible operations. 436 DenseSet<Operation *> legalizedOps; 437 if (failed(applyAnalysisConversion(getModule(), target, patterns, 438 legalizedOps, &converter))) 439 return signalPassFailure(); 440 441 // Emit remarks for each legalizable operation. 442 for (auto *op : legalizedOps) 443 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 444 } 445 446 /// The mode of conversion to use. 447 ConversionMode mode; 448 }; 449 } // end anonymous namespace 450 451 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 452 legalizerConversionMode( 453 "test-legalize-mode", 454 llvm::cl::desc("The legalization mode to use with the test driver"), 455 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 456 llvm::cl::values( 457 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 458 "analysis", "Perform an analysis conversion"), 459 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 460 "Perform a full conversion"), 461 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 462 "partial", "Perform a partial conversion"))); 463 464 //===----------------------------------------------------------------------===// 465 // ConversionPatternRewriter::getRemappedValue testing. This method is used 466 // to get the remapped value of a original value that was replaced using 467 // ConversionPatternRewriter. 468 namespace { 469 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 470 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 471 /// operand twice. 472 /// 473 /// Example: 474 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 475 /// is replaced with: 476 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 477 struct OneVResOneVOperandOp1Converter 478 : public OpConversionPattern<OneVResOneVOperandOp1> { 479 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 480 481 LogicalResult 482 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 483 ConversionPatternRewriter &rewriter) const override { 484 auto origOps = op.getOperands(); 485 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 486 "One operand expected"); 487 Value origOp = *origOps.begin(); 488 SmallVector<Value, 2> remappedOperands; 489 // Replicate the remapped original operand twice. Note that we don't used 490 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 491 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 492 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 493 494 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 495 remappedOperands); 496 return success(); 497 } 498 }; 499 500 struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> { 501 void runOnFunction() override { 502 mlir::OwningRewritePatternList patterns; 503 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); 504 505 mlir::ConversionTarget target(getContext()); 506 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); 507 // We make OneVResOneVOperandOp1 legal only when it has more that one 508 // operand. This will trigger the conversion that will replace one-operand 509 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 510 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 511 [](Operation *op) -> bool { 512 return std::distance(op->operand_begin(), op->operand_end()) > 1; 513 }); 514 515 if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { 516 signalPassFailure(); 517 } 518 } 519 }; 520 } // end anonymous namespace 521 522 namespace mlir { 523 void registerPatternsTestPass() { 524 mlir::PassRegistration<TestReturnTypeDriver>("test-return-type", 525 "Run return type functions"); 526 527 mlir::PassRegistration<TestPatternDriver>("test-patterns", 528 "Run test dialect patterns"); 529 530 mlir::PassRegistration<TestLegalizePatternDriver>( 531 "test-legalize-patterns", "Run test dialect legalization patterns", [] { 532 return std::make_unique<TestLegalizePatternDriver>( 533 legalizerConversionMode); 534 }); 535 536 PassRegistration<TestRemappedValue>( 537 "test-remapped-value", 538 "Test public remapped value mechanism in ConversionPatternRewriter"); 539 } 540 } // namespace mlir 541