1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Function.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/TypeUtilities.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/InliningUtils.h" 17 #include "llvm/ADT/StringSwitch.h" 18 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // TestDialect Interfaces 23 //===----------------------------------------------------------------------===// 24 25 namespace { 26 27 // Test support for interacting with the AsmPrinter. 28 struct TestOpAsmInterface : public OpAsmDialectInterface { 29 using OpAsmDialectInterface::OpAsmDialectInterface; 30 31 void getAsmResultNames(Operation *op, 32 OpAsmSetValueNameFn setNameFn) const final { 33 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 34 setNameFn(asmOp, "result"); 35 } 36 37 void getAsmBlockArgumentNames(Block *block, 38 OpAsmSetValueNameFn setNameFn) const final { 39 auto op = block->getParentOp(); 40 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 41 if (!arrayAttr) 42 return; 43 auto args = block->getArguments(); 44 auto e = std::min(arrayAttr.size(), args.size()); 45 for (unsigned i = 0; i < e; ++i) { 46 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 47 setNameFn(args[i], strAttr.getValue()); 48 } 49 } 50 }; 51 52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 53 using OpFolderDialectInterface::OpFolderDialectInterface; 54 55 /// Registered hook to check if the given region, which is attached to an 56 /// operation that is *not* isolated from above, should be used when 57 /// materializing constants. 58 bool shouldMaterializeInto(Region *region) const final { 59 // If this is a one region operation, then insert into it. 60 return isa<OneRegionOp>(region->getParentOp()); 61 } 62 }; 63 64 /// This class defines the interface for handling inlining with standard 65 /// operations. 66 struct TestInlinerInterface : public DialectInlinerInterface { 67 using DialectInlinerInterface::DialectInlinerInterface; 68 69 //===--------------------------------------------------------------------===// 70 // Analysis Hooks 71 //===--------------------------------------------------------------------===// 72 73 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 74 // Inlining into test dialect regions is legal. 75 return true; 76 } 77 bool isLegalToInline(Operation *, Region *, 78 BlockAndValueMapping &) const final { 79 return true; 80 } 81 82 bool shouldAnalyzeRecursively(Operation *op) const final { 83 // Analyze recursively if this is not a functional region operation, it 84 // froms a separate functional scope. 85 return !isa<FunctionalRegionOp>(op); 86 } 87 88 //===--------------------------------------------------------------------===// 89 // Transformation Hooks 90 //===--------------------------------------------------------------------===// 91 92 /// Handle the given inlined terminator by replacing it with a new operation 93 /// as necessary. 94 void handleTerminator(Operation *op, 95 ArrayRef<Value> valuesToRepl) const final { 96 // Only handle "test.return" here. 97 auto returnOp = dyn_cast<TestReturnOp>(op); 98 if (!returnOp) 99 return; 100 101 // Replace the values directly with the return operands. 102 assert(returnOp.getNumOperands() == valuesToRepl.size()); 103 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 104 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 105 } 106 107 /// Attempt to materialize a conversion for a type mismatch between a call 108 /// from this dialect, and a callable region. This method should generate an 109 /// operation that takes 'input' as the only operand, and produces a single 110 /// result of 'resultType'. If a conversion can not be generated, nullptr 111 /// should be returned. 112 Operation *materializeCallConversion(OpBuilder &builder, Value input, 113 Type resultType, 114 Location conversionLoc) const final { 115 // Only allow conversion for i16/i32 types. 116 if (!(resultType.isSignlessInteger(16) || 117 resultType.isSignlessInteger(32)) || 118 !(input.getType().isSignlessInteger(16) || 119 input.getType().isSignlessInteger(32))) 120 return nullptr; 121 return builder.create<TestCastOp>(conversionLoc, resultType, input); 122 } 123 }; 124 } // end anonymous namespace 125 126 //===----------------------------------------------------------------------===// 127 // TestDialect 128 //===----------------------------------------------------------------------===// 129 130 TestDialect::TestDialect(MLIRContext *context) 131 : Dialect(getDialectNamespace(), context) { 132 addOperations< 133 #define GET_OP_LIST 134 #include "TestOps.cpp.inc" 135 >(); 136 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 137 TestInlinerInterface>(); 138 allowUnknownOperations(); 139 } 140 141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 142 NamedAttribute namedAttr) { 143 if (namedAttr.first == "test.invalid_attr") 144 return op->emitError() << "invalid to use 'test.invalid_attr'"; 145 return success(); 146 } 147 148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 149 unsigned regionIndex, 150 unsigned argIndex, 151 NamedAttribute namedAttr) { 152 if (namedAttr.first == "test.invalid_attr") 153 return op->emitError() << "invalid to use 'test.invalid_attr'"; 154 return success(); 155 } 156 157 LogicalResult 158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 159 unsigned resultIndex, 160 NamedAttribute namedAttr) { 161 if (namedAttr.first == "test.invalid_attr") 162 return op->emitError() << "invalid to use 'test.invalid_attr'"; 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // TestBranchOp 168 //===----------------------------------------------------------------------===// 169 170 Optional<MutableOperandRange> 171 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 172 assert(index == 0 && "invalid successor index"); 173 return targetOperandsMutable(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Test IsolatedRegionOp - parse passthrough region arguments. 178 //===----------------------------------------------------------------------===// 179 180 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 181 OperationState &result) { 182 OpAsmParser::OperandType argInfo; 183 Type argType = parser.getBuilder().getIndexType(); 184 185 // Parse the input operand. 186 if (parser.parseOperand(argInfo) || 187 parser.resolveOperand(argInfo, argType, result.operands)) 188 return failure(); 189 190 // Parse the body region, and reuse the operand info as the argument info. 191 Region *body = result.addRegion(); 192 return parser.parseRegion(*body, argInfo, argType, 193 /*enableNameShadowing=*/true); 194 } 195 196 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 197 p << "test.isolated_region "; 198 p.printOperand(op.getOperand()); 199 p.shadowRegionArgs(op.region(), op.getOperand()); 200 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // Test AffineScopeOp 205 //===----------------------------------------------------------------------===// 206 207 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 208 OperationState &result) { 209 // Parse the body region, and reuse the operand info as the argument info. 210 Region *body = result.addRegion(); 211 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 212 } 213 214 static void print(OpAsmPrinter &p, AffineScopeOp op) { 215 p << "test.affine_scope "; 216 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Test parser. 221 //===----------------------------------------------------------------------===// 222 223 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 224 OperationState &result) { 225 StringRef keyword; 226 if (parser.parseKeyword(&keyword)) 227 return failure(); 228 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 229 return success(); 230 } 231 232 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 233 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 238 239 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 240 OperationState &result) { 241 if (parser.parseKeyword("wraps")) 242 return failure(); 243 244 // Parse the wrapped op in a region 245 Region &body = *result.addRegion(); 246 body.push_back(new Block); 247 Block &block = body.back(); 248 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 249 if (!wrapped_op) 250 return failure(); 251 252 // Create a return terminator in the inner region, pass as operand to the 253 // terminator the returned values from the wrapped operation. 254 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 255 OpBuilder builder(parser.getBuilder().getContext()); 256 builder.setInsertionPointToEnd(&block); 257 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 258 259 // Get the results type for the wrapping op from the terminator operands. 260 Operation &return_op = body.back().back(); 261 result.types.append(return_op.operand_type_begin(), 262 return_op.operand_type_end()); 263 264 // Use the location of the wrapped op for the "test.wrapping_region" op. 265 result.location = wrapped_op->getLoc(); 266 267 return success(); 268 } 269 270 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 271 p << op.getOperationName() << " wraps "; 272 p.printGenericOp(&op.region().front().front()); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // Test PolyForOp - parse list of region arguments. 277 //===----------------------------------------------------------------------===// 278 279 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 280 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 281 // Parse list of region arguments without a delimiter. 282 if (parser.parseRegionArgumentList(ivsInfo)) 283 return failure(); 284 285 // Parse the body region. 286 Region *body = result.addRegion(); 287 auto &builder = parser.getBuilder(); 288 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 289 return parser.parseRegion(*body, ivsInfo, argTypes); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // Test removing op with inner ops. 294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 struct TestRemoveOpWithInnerOps 298 : public OpRewritePattern<TestOpWithRegionPattern> { 299 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 300 301 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 302 PatternRewriter &rewriter) const override { 303 rewriter.eraseOp(op); 304 return success(); 305 } 306 }; 307 } // end anonymous namespace 308 309 void TestOpWithRegionPattern::getCanonicalizationPatterns( 310 OwningRewritePatternList &results, MLIRContext *context) { 311 results.insert<TestRemoveOpWithInnerOps>(context); 312 } 313 314 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 315 return operand(); 316 } 317 318 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 319 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 320 for (Value input : this->operands()) { 321 results.push_back(input); 322 } 323 return success(); 324 } 325 326 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 327 assert(operands.size() == 1); 328 if (operands.front()) { 329 setAttr("attr", operands.front()); 330 return getResult(); 331 } 332 return {}; 333 } 334 335 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( 336 MLIRContext *, Optional<Location> location, ValueRange operands, 337 DictionaryAttr attributes, RegionRange regions, 338 SmallVectorImpl<Type> &inferredReturnTypes) { 339 if (operands[0].getType() != operands[1].getType()) { 340 return emitOptionalError(location, "operand type mismatch ", 341 operands[0].getType(), " vs ", 342 operands[1].getType()); 343 } 344 inferredReturnTypes.assign({operands[0].getType()}); 345 return success(); 346 } 347 348 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 349 MLIRContext *context, Optional<Location> location, ValueRange operands, 350 DictionaryAttr attributes, RegionRange regions, 351 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 352 // Create return type consisting of the last element of the first operand. 353 auto operandType = *operands.getTypes().begin(); 354 auto sval = operandType.dyn_cast<ShapedType>(); 355 if (!sval) { 356 return emitOptionalError(location, "only shaped type operands allowed"); 357 } 358 int64_t dim = 359 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 360 auto type = IntegerType::get(17, context); 361 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 362 return success(); 363 } 364 365 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 366 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 367 shapes = SmallVector<Value, 1>{ 368 builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; 369 return success(); 370 } 371 372 //===----------------------------------------------------------------------===// 373 // Test SideEffect interfaces 374 //===----------------------------------------------------------------------===// 375 376 namespace { 377 /// A test resource for side effects. 378 struct TestResource : public SideEffects::Resource::Base<TestResource> { 379 StringRef getName() final { return "<Test>"; } 380 }; 381 } // end anonymous namespace 382 383 void SideEffectOp::getEffects( 384 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 385 // Check for an effects attribute on the op instance. 386 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 387 if (!effectsAttr) 388 return; 389 390 // If there is one, it is an array of dictionary attributes that hold 391 // information on the effects of this operation. 392 for (Attribute element : effectsAttr) { 393 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 394 395 // Get the specific memory effect. 396 MemoryEffects::Effect *effect = 397 llvm::StringSwitch<MemoryEffects::Effect *>( 398 effectElement.get("effect").cast<StringAttr>().getValue()) 399 .Case("allocate", MemoryEffects::Allocate::get()) 400 .Case("free", MemoryEffects::Free::get()) 401 .Case("read", MemoryEffects::Read::get()) 402 .Case("write", MemoryEffects::Write::get()); 403 404 // Check for a result to affect. 405 Value value; 406 if (effectElement.get("on_result")) 407 value = getResult(); 408 409 // Check for a non-default resource to use. 410 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 411 if (effectElement.get("test_resource")) 412 resource = TestResource::get(); 413 414 effects.emplace_back(effect, value, resource); 415 } 416 } 417 418 //===----------------------------------------------------------------------===// 419 // StringAttrPrettyNameOp 420 //===----------------------------------------------------------------------===// 421 422 // This op has fancy handling of its SSA result name. 423 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 424 OperationState &result) { 425 // Add the result types. 426 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 427 result.addTypes(parser.getBuilder().getIntegerType(32)); 428 429 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 430 return failure(); 431 432 // If the attribute dictionary contains no 'names' attribute, infer it from 433 // the SSA name (if specified). 434 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 435 return attr.first == "names"; 436 }); 437 438 // If there was no name specified, check to see if there was a useful name 439 // specified in the asm file. 440 if (hadNames || parser.getNumResults() == 0) 441 return success(); 442 443 SmallVector<StringRef, 4> names; 444 auto *context = result.getContext(); 445 446 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 447 auto resultName = parser.getResultName(i); 448 StringRef nameStr; 449 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 450 nameStr = resultName.first; 451 452 names.push_back(nameStr); 453 } 454 455 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 456 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 457 return success(); 458 } 459 460 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 461 p << "test.string_attr_pretty_name"; 462 463 // Note that we only need to print the "name" attribute if the asmprinter 464 // result name disagrees with it. This can happen in strange cases, e.g. 465 // when there are conflicts. 466 bool namesDisagree = op.names().size() != op.getNumResults(); 467 468 SmallString<32> resultNameStr; 469 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 470 resultNameStr.clear(); 471 llvm::raw_svector_ostream tmpStream(resultNameStr); 472 p.printOperand(op.getResult(i), tmpStream); 473 474 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 475 if (!expectedName || 476 tmpStream.str().drop_front() != expectedName.getValue()) { 477 namesDisagree = true; 478 } 479 } 480 481 if (namesDisagree) 482 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 483 else 484 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 485 } 486 487 // We set the SSA name in the asm syntax to the contents of the name 488 // attribute. 489 void StringAttrPrettyNameOp::getAsmResultNames( 490 function_ref<void(Value, StringRef)> setNameFn) { 491 492 auto value = names(); 493 for (size_t i = 0, e = value.size(); i != e; ++i) 494 if (auto str = value[i].dyn_cast<StringAttr>()) 495 if (!str.getValue().empty()) 496 setNameFn(getResult(i), str.getValue()); 497 } 498 499 //===----------------------------------------------------------------------===// 500 // Dialect Registration 501 //===----------------------------------------------------------------------===// 502 503 // Static initialization for Test dialect registration. 504 static mlir::DialectRegistration<mlir::TestDialect> testDialect; 505 506 #include "TestOpEnums.cpp.inc" 507 #include "TestOpStructs.cpp.inc" 508 509 #define GET_OP_CLASSES 510 #include "TestOps.cpp.inc" 511