1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Function.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/TypeUtilities.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/InliningUtils.h" 17 #include "llvm/ADT/StringSwitch.h" 18 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // TestDialect Interfaces 23 //===----------------------------------------------------------------------===// 24 25 namespace { 26 27 // Test support for interacting with the AsmPrinter. 28 struct TestOpAsmInterface : public OpAsmDialectInterface { 29 using OpAsmDialectInterface::OpAsmDialectInterface; 30 31 void getAsmResultNames(Operation *op, 32 OpAsmSetValueNameFn setNameFn) const final { 33 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 34 setNameFn(asmOp, "result"); 35 } 36 37 void getAsmBlockArgumentNames(Block *block, 38 OpAsmSetValueNameFn setNameFn) const final { 39 auto op = block->getParentOp(); 40 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 41 if (!arrayAttr) 42 return; 43 auto args = block->getArguments(); 44 auto e = std::min(arrayAttr.size(), args.size()); 45 for (unsigned i = 0; i < e; ++i) { 46 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 47 setNameFn(args[i], strAttr.getValue()); 48 } 49 } 50 }; 51 52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 53 using OpFolderDialectInterface::OpFolderDialectInterface; 54 55 /// Registered hook to check if the given region, which is attached to an 56 /// operation that is *not* isolated from above, should be used when 57 /// materializing constants. 58 bool shouldMaterializeInto(Region *region) const final { 59 // If this is a one region operation, then insert into it. 60 return isa<OneRegionOp>(region->getParentOp()); 61 } 62 }; 63 64 /// This class defines the interface for handling inlining with standard 65 /// operations. 66 struct TestInlinerInterface : public DialectInlinerInterface { 67 using DialectInlinerInterface::DialectInlinerInterface; 68 69 //===--------------------------------------------------------------------===// 70 // Analysis Hooks 71 //===--------------------------------------------------------------------===// 72 73 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 74 // Inlining into test dialect regions is legal. 75 return true; 76 } 77 bool isLegalToInline(Operation *, Region *, 78 BlockAndValueMapping &) const final { 79 return true; 80 } 81 82 bool shouldAnalyzeRecursively(Operation *op) const final { 83 // Analyze recursively if this is not a functional region operation, it 84 // froms a separate functional scope. 85 return !isa<FunctionalRegionOp>(op); 86 } 87 88 //===--------------------------------------------------------------------===// 89 // Transformation Hooks 90 //===--------------------------------------------------------------------===// 91 92 /// Handle the given inlined terminator by replacing it with a new operation 93 /// as necessary. 94 void handleTerminator(Operation *op, 95 ArrayRef<Value> valuesToRepl) const final { 96 // Only handle "test.return" here. 97 auto returnOp = dyn_cast<TestReturnOp>(op); 98 if (!returnOp) 99 return; 100 101 // Replace the values directly with the return operands. 102 assert(returnOp.getNumOperands() == valuesToRepl.size()); 103 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 104 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 105 } 106 107 /// Attempt to materialize a conversion for a type mismatch between a call 108 /// from this dialect, and a callable region. This method should generate an 109 /// operation that takes 'input' as the only operand, and produces a single 110 /// result of 'resultType'. If a conversion can not be generated, nullptr 111 /// should be returned. 112 Operation *materializeCallConversion(OpBuilder &builder, Value input, 113 Type resultType, 114 Location conversionLoc) const final { 115 // Only allow conversion for i16/i32 types. 116 if (!(resultType.isSignlessInteger(16) || 117 resultType.isSignlessInteger(32)) || 118 !(input.getType().isSignlessInteger(16) || 119 input.getType().isSignlessInteger(32))) 120 return nullptr; 121 return builder.create<TestCastOp>(conversionLoc, resultType, input); 122 } 123 }; 124 } // end anonymous namespace 125 126 //===----------------------------------------------------------------------===// 127 // TestDialect 128 //===----------------------------------------------------------------------===// 129 130 TestDialect::TestDialect(MLIRContext *context) 131 : Dialect(getDialectNamespace(), context) { 132 addOperations< 133 #define GET_OP_LIST 134 #include "TestOps.cpp.inc" 135 >(); 136 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 137 TestInlinerInterface>(); 138 allowUnknownOperations(); 139 } 140 141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 142 NamedAttribute namedAttr) { 143 if (namedAttr.first == "test.invalid_attr") 144 return op->emitError() << "invalid to use 'test.invalid_attr'"; 145 return success(); 146 } 147 148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 149 unsigned regionIndex, 150 unsigned argIndex, 151 NamedAttribute namedAttr) { 152 if (namedAttr.first == "test.invalid_attr") 153 return op->emitError() << "invalid to use 'test.invalid_attr'"; 154 return success(); 155 } 156 157 LogicalResult 158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 159 unsigned resultIndex, 160 NamedAttribute namedAttr) { 161 if (namedAttr.first == "test.invalid_attr") 162 return op->emitError() << "invalid to use 'test.invalid_attr'"; 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // TestBranchOp 168 //===----------------------------------------------------------------------===// 169 170 Optional<MutableOperandRange> 171 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 172 assert(index == 0 && "invalid successor index"); 173 return targetOperandsMutable(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // TestFoldToCallOp 178 //===----------------------------------------------------------------------===// 179 180 namespace { 181 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 182 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 183 184 LogicalResult matchAndRewrite(FoldToCallOp op, 185 PatternRewriter &rewriter) const override { 186 rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(), 187 ValueRange()); 188 return success(); 189 } 190 }; 191 } // end anonymous namespace 192 193 void FoldToCallOp::getCanonicalizationPatterns( 194 OwningRewritePatternList &results, MLIRContext *context) { 195 results.insert<FoldToCallOpPattern>(context); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // Test IsolatedRegionOp - parse passthrough region arguments. 200 //===----------------------------------------------------------------------===// 201 202 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 203 OperationState &result) { 204 OpAsmParser::OperandType argInfo; 205 Type argType = parser.getBuilder().getIndexType(); 206 207 // Parse the input operand. 208 if (parser.parseOperand(argInfo) || 209 parser.resolveOperand(argInfo, argType, result.operands)) 210 return failure(); 211 212 // Parse the body region, and reuse the operand info as the argument info. 213 Region *body = result.addRegion(); 214 return parser.parseRegion(*body, argInfo, argType, 215 /*enableNameShadowing=*/true); 216 } 217 218 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 219 p << "test.isolated_region "; 220 p.printOperand(op.getOperand()); 221 p.shadowRegionArgs(op.region(), op.getOperand()); 222 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // Test AffineScopeOp 227 //===----------------------------------------------------------------------===// 228 229 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 230 OperationState &result) { 231 // Parse the body region, and reuse the operand info as the argument info. 232 Region *body = result.addRegion(); 233 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 234 } 235 236 static void print(OpAsmPrinter &p, AffineScopeOp op) { 237 p << "test.affine_scope "; 238 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Test parser. 243 //===----------------------------------------------------------------------===// 244 245 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 246 OperationState &result) { 247 StringRef keyword; 248 if (parser.parseKeyword(&keyword)) 249 return failure(); 250 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 251 return success(); 252 } 253 254 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 255 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 260 261 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 262 OperationState &result) { 263 if (parser.parseKeyword("wraps")) 264 return failure(); 265 266 // Parse the wrapped op in a region 267 Region &body = *result.addRegion(); 268 body.push_back(new Block); 269 Block &block = body.back(); 270 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 271 if (!wrapped_op) 272 return failure(); 273 274 // Create a return terminator in the inner region, pass as operand to the 275 // terminator the returned values from the wrapped operation. 276 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 277 OpBuilder builder(parser.getBuilder().getContext()); 278 builder.setInsertionPointToEnd(&block); 279 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 280 281 // Get the results type for the wrapping op from the terminator operands. 282 Operation &return_op = body.back().back(); 283 result.types.append(return_op.operand_type_begin(), 284 return_op.operand_type_end()); 285 286 // Use the location of the wrapped op for the "test.wrapping_region" op. 287 result.location = wrapped_op->getLoc(); 288 289 return success(); 290 } 291 292 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 293 p << op.getOperationName() << " wraps "; 294 p.printGenericOp(&op.region().front().front()); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // Test PolyForOp - parse list of region arguments. 299 //===----------------------------------------------------------------------===// 300 301 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 302 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 303 // Parse list of region arguments without a delimiter. 304 if (parser.parseRegionArgumentList(ivsInfo)) 305 return failure(); 306 307 // Parse the body region. 308 Region *body = result.addRegion(); 309 auto &builder = parser.getBuilder(); 310 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 311 return parser.parseRegion(*body, ivsInfo, argTypes); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test removing op with inner ops. 316 //===----------------------------------------------------------------------===// 317 318 namespace { 319 struct TestRemoveOpWithInnerOps 320 : public OpRewritePattern<TestOpWithRegionPattern> { 321 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 322 323 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 324 PatternRewriter &rewriter) const override { 325 rewriter.eraseOp(op); 326 return success(); 327 } 328 }; 329 } // end anonymous namespace 330 331 void TestOpWithRegionPattern::getCanonicalizationPatterns( 332 OwningRewritePatternList &results, MLIRContext *context) { 333 results.insert<TestRemoveOpWithInnerOps>(context); 334 } 335 336 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 337 return operand(); 338 } 339 340 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 341 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 342 for (Value input : this->operands()) { 343 results.push_back(input); 344 } 345 return success(); 346 } 347 348 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 349 assert(operands.size() == 1); 350 if (operands.front()) { 351 setAttr("attr", operands.front()); 352 return getResult(); 353 } 354 return {}; 355 } 356 357 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( 358 MLIRContext *, Optional<Location> location, ValueRange operands, 359 DictionaryAttr attributes, RegionRange regions, 360 SmallVectorImpl<Type> &inferredReturnTypes) { 361 if (operands[0].getType() != operands[1].getType()) { 362 return emitOptionalError(location, "operand type mismatch ", 363 operands[0].getType(), " vs ", 364 operands[1].getType()); 365 } 366 inferredReturnTypes.assign({operands[0].getType()}); 367 return success(); 368 } 369 370 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 371 MLIRContext *context, Optional<Location> location, ValueRange operands, 372 DictionaryAttr attributes, RegionRange regions, 373 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 374 // Create return type consisting of the last element of the first operand. 375 auto operandType = *operands.getTypes().begin(); 376 auto sval = operandType.dyn_cast<ShapedType>(); 377 if (!sval) { 378 return emitOptionalError(location, "only shaped type operands allowed"); 379 } 380 int64_t dim = 381 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 382 auto type = IntegerType::get(17, context); 383 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 384 return success(); 385 } 386 387 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 388 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 389 shapes = SmallVector<Value, 1>{ 390 builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; 391 return success(); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // Test SideEffect interfaces 396 //===----------------------------------------------------------------------===// 397 398 namespace { 399 /// A test resource for side effects. 400 struct TestResource : public SideEffects::Resource::Base<TestResource> { 401 StringRef getName() final { return "<Test>"; } 402 }; 403 } // end anonymous namespace 404 405 void SideEffectOp::getEffects( 406 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 407 // Check for an effects attribute on the op instance. 408 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 409 if (!effectsAttr) 410 return; 411 412 // If there is one, it is an array of dictionary attributes that hold 413 // information on the effects of this operation. 414 for (Attribute element : effectsAttr) { 415 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 416 417 // Get the specific memory effect. 418 MemoryEffects::Effect *effect = 419 llvm::StringSwitch<MemoryEffects::Effect *>( 420 effectElement.get("effect").cast<StringAttr>().getValue()) 421 .Case("allocate", MemoryEffects::Allocate::get()) 422 .Case("free", MemoryEffects::Free::get()) 423 .Case("read", MemoryEffects::Read::get()) 424 .Case("write", MemoryEffects::Write::get()); 425 426 // Check for a result to affect. 427 Value value; 428 if (effectElement.get("on_result")) 429 value = getResult(); 430 431 // Check for a non-default resource to use. 432 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 433 if (effectElement.get("test_resource")) 434 resource = TestResource::get(); 435 436 effects.emplace_back(effect, value, resource); 437 } 438 } 439 440 //===----------------------------------------------------------------------===// 441 // StringAttrPrettyNameOp 442 //===----------------------------------------------------------------------===// 443 444 // This op has fancy handling of its SSA result name. 445 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 446 OperationState &result) { 447 // Add the result types. 448 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 449 result.addTypes(parser.getBuilder().getIntegerType(32)); 450 451 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 452 return failure(); 453 454 // If the attribute dictionary contains no 'names' attribute, infer it from 455 // the SSA name (if specified). 456 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 457 return attr.first == "names"; 458 }); 459 460 // If there was no name specified, check to see if there was a useful name 461 // specified in the asm file. 462 if (hadNames || parser.getNumResults() == 0) 463 return success(); 464 465 SmallVector<StringRef, 4> names; 466 auto *context = result.getContext(); 467 468 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 469 auto resultName = parser.getResultName(i); 470 StringRef nameStr; 471 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 472 nameStr = resultName.first; 473 474 names.push_back(nameStr); 475 } 476 477 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 478 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 479 return success(); 480 } 481 482 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 483 p << "test.string_attr_pretty_name"; 484 485 // Note that we only need to print the "name" attribute if the asmprinter 486 // result name disagrees with it. This can happen in strange cases, e.g. 487 // when there are conflicts. 488 bool namesDisagree = op.names().size() != op.getNumResults(); 489 490 SmallString<32> resultNameStr; 491 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 492 resultNameStr.clear(); 493 llvm::raw_svector_ostream tmpStream(resultNameStr); 494 p.printOperand(op.getResult(i), tmpStream); 495 496 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 497 if (!expectedName || 498 tmpStream.str().drop_front() != expectedName.getValue()) { 499 namesDisagree = true; 500 } 501 } 502 503 if (namesDisagree) 504 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 505 else 506 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 507 } 508 509 // We set the SSA name in the asm syntax to the contents of the name 510 // attribute. 511 void StringAttrPrettyNameOp::getAsmResultNames( 512 function_ref<void(Value, StringRef)> setNameFn) { 513 514 auto value = names(); 515 for (size_t i = 0, e = value.size(); i != e; ++i) 516 if (auto str = value[i].dyn_cast<StringAttr>()) 517 if (!str.getValue().empty()) 518 setNameFn(getResult(i), str.getValue()); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // Dialect Registration 523 //===----------------------------------------------------------------------===// 524 525 // Static initialization for Test dialect registration. 526 static mlir::DialectRegistration<mlir::TestDialect> testDialect; 527 528 #include "TestOpEnums.cpp.inc" 529 #include "TestOpStructs.cpp.inc" 530 531 #define GET_OP_CLASSES 532 #include "TestOps.cpp.inc" 533