1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/StringSwitch.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // TestDialect Interfaces 25 //===----------------------------------------------------------------------===// 26 27 namespace { 28 29 // Test support for interacting with the AsmPrinter. 30 struct TestOpAsmInterface : public OpAsmDialectInterface { 31 using OpAsmDialectInterface::OpAsmDialectInterface; 32 33 void getAsmResultNames(Operation *op, 34 OpAsmSetValueNameFn setNameFn) const final { 35 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 36 setNameFn(asmOp, "result"); 37 } 38 39 void getAsmBlockArgumentNames(Block *block, 40 OpAsmSetValueNameFn setNameFn) const final { 41 auto op = block->getParentOp(); 42 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 43 if (!arrayAttr) 44 return; 45 auto args = block->getArguments(); 46 auto e = std::min(arrayAttr.size(), args.size()); 47 for (unsigned i = 0; i < e; ++i) { 48 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 49 setNameFn(args[i], strAttr.getValue()); 50 } 51 } 52 }; 53 54 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 55 using OpFolderDialectInterface::OpFolderDialectInterface; 56 57 /// Registered hook to check if the given region, which is attached to an 58 /// operation that is *not* isolated from above, should be used when 59 /// materializing constants. 60 bool shouldMaterializeInto(Region *region) const final { 61 // If this is a one region operation, then insert into it. 62 return isa<OneRegionOp>(region->getParentOp()); 63 } 64 }; 65 66 /// This class defines the interface for handling inlining with standard 67 /// operations. 68 struct TestInlinerInterface : public DialectInlinerInterface { 69 using DialectInlinerInterface::DialectInlinerInterface; 70 71 //===--------------------------------------------------------------------===// 72 // Analysis Hooks 73 //===--------------------------------------------------------------------===// 74 75 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 76 // Inlining into test dialect regions is legal. 77 return true; 78 } 79 bool isLegalToInline(Operation *, Region *, 80 BlockAndValueMapping &) const final { 81 return true; 82 } 83 84 bool shouldAnalyzeRecursively(Operation *op) const final { 85 // Analyze recursively if this is not a functional region operation, it 86 // froms a separate functional scope. 87 return !isa<FunctionalRegionOp>(op); 88 } 89 90 //===--------------------------------------------------------------------===// 91 // Transformation Hooks 92 //===--------------------------------------------------------------------===// 93 94 /// Handle the given inlined terminator by replacing it with a new operation 95 /// as necessary. 96 void handleTerminator(Operation *op, 97 ArrayRef<Value> valuesToRepl) const final { 98 // Only handle "test.return" here. 99 auto returnOp = dyn_cast<TestReturnOp>(op); 100 if (!returnOp) 101 return; 102 103 // Replace the values directly with the return operands. 104 assert(returnOp.getNumOperands() == valuesToRepl.size()); 105 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 106 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 107 } 108 109 /// Attempt to materialize a conversion for a type mismatch between a call 110 /// from this dialect, and a callable region. This method should generate an 111 /// operation that takes 'input' as the only operand, and produces a single 112 /// result of 'resultType'. If a conversion can not be generated, nullptr 113 /// should be returned. 114 Operation *materializeCallConversion(OpBuilder &builder, Value input, 115 Type resultType, 116 Location conversionLoc) const final { 117 // Only allow conversion for i16/i32 types. 118 if (!(resultType.isSignlessInteger(16) || 119 resultType.isSignlessInteger(32)) || 120 !(input.getType().isSignlessInteger(16) || 121 input.getType().isSignlessInteger(32))) 122 return nullptr; 123 return builder.create<TestCastOp>(conversionLoc, resultType, input); 124 } 125 }; 126 } // end anonymous namespace 127 128 //===----------------------------------------------------------------------===// 129 // TestDialect 130 //===----------------------------------------------------------------------===// 131 132 TestDialect::TestDialect(MLIRContext *context) 133 : Dialect(getDialectNamespace(), context) { 134 addOperations< 135 #define GET_OP_LIST 136 #include "TestOps.cpp.inc" 137 >(); 138 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 139 TestInlinerInterface>(); 140 addTypes<TestType>(); 141 allowUnknownOperations(); 142 } 143 144 Type TestDialect::parseType(DialectAsmParser &parser) const { 145 if (failed(parser.parseKeyword("test_type"))) 146 return Type(); 147 return TestType::get(getContext()); 148 } 149 150 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 151 assert(type.isa<TestType>() && "unexpected type"); 152 printer << "test_type"; 153 } 154 155 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 156 NamedAttribute namedAttr) { 157 if (namedAttr.first == "test.invalid_attr") 158 return op->emitError() << "invalid to use 'test.invalid_attr'"; 159 return success(); 160 } 161 162 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 163 unsigned regionIndex, 164 unsigned argIndex, 165 NamedAttribute namedAttr) { 166 if (namedAttr.first == "test.invalid_attr") 167 return op->emitError() << "invalid to use 'test.invalid_attr'"; 168 return success(); 169 } 170 171 LogicalResult 172 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 173 unsigned resultIndex, 174 NamedAttribute namedAttr) { 175 if (namedAttr.first == "test.invalid_attr") 176 return op->emitError() << "invalid to use 'test.invalid_attr'"; 177 return success(); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // TestBranchOp 182 //===----------------------------------------------------------------------===// 183 184 Optional<MutableOperandRange> 185 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 186 assert(index == 0 && "invalid successor index"); 187 return targetOperandsMutable(); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TestFoldToCallOp 192 //===----------------------------------------------------------------------===// 193 194 namespace { 195 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 196 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 197 198 LogicalResult matchAndRewrite(FoldToCallOp op, 199 PatternRewriter &rewriter) const override { 200 rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(), 201 ValueRange()); 202 return success(); 203 } 204 }; 205 } // end anonymous namespace 206 207 void FoldToCallOp::getCanonicalizationPatterns( 208 OwningRewritePatternList &results, MLIRContext *context) { 209 results.insert<FoldToCallOpPattern>(context); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // Test IsolatedRegionOp - parse passthrough region arguments. 214 //===----------------------------------------------------------------------===// 215 216 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 217 OperationState &result) { 218 OpAsmParser::OperandType argInfo; 219 Type argType = parser.getBuilder().getIndexType(); 220 221 // Parse the input operand. 222 if (parser.parseOperand(argInfo) || 223 parser.resolveOperand(argInfo, argType, result.operands)) 224 return failure(); 225 226 // Parse the body region, and reuse the operand info as the argument info. 227 Region *body = result.addRegion(); 228 return parser.parseRegion(*body, argInfo, argType, 229 /*enableNameShadowing=*/true); 230 } 231 232 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 233 p << "test.isolated_region "; 234 p.printOperand(op.getOperand()); 235 p.shadowRegionArgs(op.region(), op.getOperand()); 236 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // Test AffineScopeOp 241 //===----------------------------------------------------------------------===// 242 243 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 244 OperationState &result) { 245 // Parse the body region, and reuse the operand info as the argument info. 246 Region *body = result.addRegion(); 247 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 248 } 249 250 static void print(OpAsmPrinter &p, AffineScopeOp op) { 251 p << "test.affine_scope "; 252 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // Test parser. 257 //===----------------------------------------------------------------------===// 258 259 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 260 OperationState &result) { 261 StringRef keyword; 262 if (parser.parseKeyword(&keyword)) 263 return failure(); 264 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 265 return success(); 266 } 267 268 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 269 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 270 } 271 272 //===----------------------------------------------------------------------===// 273 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 274 275 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 276 OperationState &result) { 277 if (parser.parseKeyword("wraps")) 278 return failure(); 279 280 // Parse the wrapped op in a region 281 Region &body = *result.addRegion(); 282 body.push_back(new Block); 283 Block &block = body.back(); 284 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 285 if (!wrapped_op) 286 return failure(); 287 288 // Create a return terminator in the inner region, pass as operand to the 289 // terminator the returned values from the wrapped operation. 290 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 291 OpBuilder builder(parser.getBuilder().getContext()); 292 builder.setInsertionPointToEnd(&block); 293 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 294 295 // Get the results type for the wrapping op from the terminator operands. 296 Operation &return_op = body.back().back(); 297 result.types.append(return_op.operand_type_begin(), 298 return_op.operand_type_end()); 299 300 // Use the location of the wrapped op for the "test.wrapping_region" op. 301 result.location = wrapped_op->getLoc(); 302 303 return success(); 304 } 305 306 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 307 p << op.getOperationName() << " wraps "; 308 p.printGenericOp(&op.region().front().front()); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // Test PolyForOp - parse list of region arguments. 313 //===----------------------------------------------------------------------===// 314 315 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 316 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 317 // Parse list of region arguments without a delimiter. 318 if (parser.parseRegionArgumentList(ivsInfo)) 319 return failure(); 320 321 // Parse the body region. 322 Region *body = result.addRegion(); 323 auto &builder = parser.getBuilder(); 324 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 325 return parser.parseRegion(*body, ivsInfo, argTypes); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // Test removing op with inner ops. 330 //===----------------------------------------------------------------------===// 331 332 namespace { 333 struct TestRemoveOpWithInnerOps 334 : public OpRewritePattern<TestOpWithRegionPattern> { 335 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 336 337 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 338 PatternRewriter &rewriter) const override { 339 rewriter.eraseOp(op); 340 return success(); 341 } 342 }; 343 } // end anonymous namespace 344 345 void TestOpWithRegionPattern::getCanonicalizationPatterns( 346 OwningRewritePatternList &results, MLIRContext *context) { 347 results.insert<TestRemoveOpWithInnerOps>(context); 348 } 349 350 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 351 return operand(); 352 } 353 354 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 355 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 356 for (Value input : this->operands()) { 357 results.push_back(input); 358 } 359 return success(); 360 } 361 362 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 363 assert(operands.size() == 1); 364 if (operands.front()) { 365 setAttr("attr", operands.front()); 366 return getResult(); 367 } 368 return {}; 369 } 370 371 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( 372 MLIRContext *, Optional<Location> location, ValueRange operands, 373 DictionaryAttr attributes, RegionRange regions, 374 SmallVectorImpl<Type> &inferredReturnTypes) { 375 if (operands[0].getType() != operands[1].getType()) { 376 return emitOptionalError(location, "operand type mismatch ", 377 operands[0].getType(), " vs ", 378 operands[1].getType()); 379 } 380 inferredReturnTypes.assign({operands[0].getType()}); 381 return success(); 382 } 383 384 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 385 MLIRContext *context, Optional<Location> location, ValueRange operands, 386 DictionaryAttr attributes, RegionRange regions, 387 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 388 // Create return type consisting of the last element of the first operand. 389 auto operandType = *operands.getTypes().begin(); 390 auto sval = operandType.dyn_cast<ShapedType>(); 391 if (!sval) { 392 return emitOptionalError(location, "only shaped type operands allowed"); 393 } 394 int64_t dim = 395 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 396 auto type = IntegerType::get(17, context); 397 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 398 return success(); 399 } 400 401 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 402 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 403 shapes = SmallVector<Value, 1>{ 404 builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; 405 return success(); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // Test SideEffect interfaces 410 //===----------------------------------------------------------------------===// 411 412 namespace { 413 /// A test resource for side effects. 414 struct TestResource : public SideEffects::Resource::Base<TestResource> { 415 StringRef getName() final { return "<Test>"; } 416 }; 417 } // end anonymous namespace 418 419 void SideEffectOp::getEffects( 420 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 421 // Check for an effects attribute on the op instance. 422 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 423 if (!effectsAttr) 424 return; 425 426 // If there is one, it is an array of dictionary attributes that hold 427 // information on the effects of this operation. 428 for (Attribute element : effectsAttr) { 429 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 430 431 // Get the specific memory effect. 432 MemoryEffects::Effect *effect = 433 llvm::StringSwitch<MemoryEffects::Effect *>( 434 effectElement.get("effect").cast<StringAttr>().getValue()) 435 .Case("allocate", MemoryEffects::Allocate::get()) 436 .Case("free", MemoryEffects::Free::get()) 437 .Case("read", MemoryEffects::Read::get()) 438 .Case("write", MemoryEffects::Write::get()); 439 440 // Check for a result to affect. 441 Value value; 442 if (effectElement.get("on_result")) 443 value = getResult(); 444 445 // Check for a non-default resource to use. 446 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 447 if (effectElement.get("test_resource")) 448 resource = TestResource::get(); 449 450 effects.emplace_back(effect, value, resource); 451 } 452 } 453 454 //===----------------------------------------------------------------------===// 455 // StringAttrPrettyNameOp 456 //===----------------------------------------------------------------------===// 457 458 // This op has fancy handling of its SSA result name. 459 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 460 OperationState &result) { 461 // Add the result types. 462 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 463 result.addTypes(parser.getBuilder().getIntegerType(32)); 464 465 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 466 return failure(); 467 468 // If the attribute dictionary contains no 'names' attribute, infer it from 469 // the SSA name (if specified). 470 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 471 return attr.first == "names"; 472 }); 473 474 // If there was no name specified, check to see if there was a useful name 475 // specified in the asm file. 476 if (hadNames || parser.getNumResults() == 0) 477 return success(); 478 479 SmallVector<StringRef, 4> names; 480 auto *context = result.getContext(); 481 482 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 483 auto resultName = parser.getResultName(i); 484 StringRef nameStr; 485 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 486 nameStr = resultName.first; 487 488 names.push_back(nameStr); 489 } 490 491 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 492 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 493 return success(); 494 } 495 496 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 497 p << "test.string_attr_pretty_name"; 498 499 // Note that we only need to print the "name" attribute if the asmprinter 500 // result name disagrees with it. This can happen in strange cases, e.g. 501 // when there are conflicts. 502 bool namesDisagree = op.names().size() != op.getNumResults(); 503 504 SmallString<32> resultNameStr; 505 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 506 resultNameStr.clear(); 507 llvm::raw_svector_ostream tmpStream(resultNameStr); 508 p.printOperand(op.getResult(i), tmpStream); 509 510 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 511 if (!expectedName || 512 tmpStream.str().drop_front() != expectedName.getValue()) { 513 namesDisagree = true; 514 } 515 } 516 517 if (namesDisagree) 518 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 519 else 520 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 521 } 522 523 // We set the SSA name in the asm syntax to the contents of the name 524 // attribute. 525 void StringAttrPrettyNameOp::getAsmResultNames( 526 function_ref<void(Value, StringRef)> setNameFn) { 527 528 auto value = names(); 529 for (size_t i = 0, e = value.size(); i != e; ++i) 530 if (auto str = value[i].dyn_cast<StringAttr>()) 531 if (!str.getValue().empty()) 532 setNameFn(getResult(i), str.getValue()); 533 } 534 535 //===----------------------------------------------------------------------===// 536 // RegionIfOp 537 //===----------------------------------------------------------------------===// 538 539 static void print(OpAsmPrinter &p, RegionIfOp op) { 540 p << RegionIfOp::getOperationName() << " "; 541 p.printOperands(op.getOperands()); 542 p << ": " << op.getOperandTypes(); 543 p.printArrowTypeList(op.getResultTypes()); 544 p << " then"; 545 p.printRegion(op.thenRegion(), 546 /*printEntryBlockArgs=*/true, 547 /*printBlockTerminators=*/true); 548 p << " else"; 549 p.printRegion(op.elseRegion(), 550 /*printEntryBlockArgs=*/true, 551 /*printBlockTerminators=*/true); 552 p << " join"; 553 p.printRegion(op.joinRegion(), 554 /*printEntryBlockArgs=*/true, 555 /*printBlockTerminators=*/true); 556 } 557 558 static ParseResult parseRegionIfOp(OpAsmParser &parser, 559 OperationState &result) { 560 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 561 SmallVector<Type, 2> operandTypes; 562 563 result.regions.reserve(3); 564 Region *thenRegion = result.addRegion(); 565 Region *elseRegion = result.addRegion(); 566 Region *joinRegion = result.addRegion(); 567 568 // Parse operand, type and arrow type lists. 569 if (parser.parseOperandList(operandInfos) || 570 parser.parseColonTypeList(operandTypes) || 571 parser.parseArrowTypeList(result.types)) 572 return failure(); 573 574 // Parse all attached regions. 575 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 576 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 577 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 578 return failure(); 579 580 return parser.resolveOperands(operandInfos, operandTypes, 581 parser.getCurrentLocation(), result.operands); 582 } 583 584 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 585 assert(index < 2 && "invalid region index"); 586 return getOperands(); 587 } 588 589 void RegionIfOp::getSuccessorRegions( 590 Optional<unsigned> index, ArrayRef<Attribute> operands, 591 SmallVectorImpl<RegionSuccessor> ®ions) { 592 // We always branch to the join region. 593 if (index.hasValue()) { 594 if (index.getValue() < 2) 595 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 596 else 597 regions.push_back(RegionSuccessor(getResults())); 598 return; 599 } 600 601 // The then and else regions are the entry regions of this op. 602 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 603 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 604 } 605 606 //===----------------------------------------------------------------------===// 607 // Dialect Registration 608 //===----------------------------------------------------------------------===// 609 610 // Static initialization for Test dialect registration. 611 static mlir::DialectRegistration<mlir::TestDialect> testDialect; 612 613 #include "TestOpEnums.cpp.inc" 614 #include "TestOpStructs.cpp.inc" 615 #include "TestTypeInterfaces.cpp.inc" 616 617 #define GET_OP_CLASSES 618 #include "TestOps.cpp.inc" 619