1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Function.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/TypeUtilities.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/InliningUtils.h" 17 #include "llvm/ADT/StringSwitch.h" 18 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // TestDialect Interfaces 23 //===----------------------------------------------------------------------===// 24 25 namespace { 26 27 // Test support for interacting with the AsmPrinter. 28 struct TestOpAsmInterface : public OpAsmDialectInterface { 29 using OpAsmDialectInterface::OpAsmDialectInterface; 30 31 void getAsmResultNames(Operation *op, 32 OpAsmSetValueNameFn setNameFn) const final { 33 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 34 setNameFn(asmOp, "result"); 35 } 36 37 void getAsmBlockArgumentNames(Block *block, 38 OpAsmSetValueNameFn setNameFn) const final { 39 auto op = block->getParentOp(); 40 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 41 if (!arrayAttr) 42 return; 43 auto args = block->getArguments(); 44 auto e = std::min(arrayAttr.size(), args.size()); 45 for (unsigned i = 0; i < e; ++i) { 46 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 47 setNameFn(args[i], strAttr.getValue()); 48 } 49 } 50 }; 51 52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 53 using OpFolderDialectInterface::OpFolderDialectInterface; 54 55 /// Registered hook to check if the given region, which is attached to an 56 /// operation that is *not* isolated from above, should be used when 57 /// materializing constants. 58 bool shouldMaterializeInto(Region *region) const final { 59 // If this is a one region operation, then insert into it. 60 return isa<OneRegionOp>(region->getParentOp()); 61 } 62 }; 63 64 /// This class defines the interface for handling inlining with standard 65 /// operations. 66 struct TestInlinerInterface : public DialectInlinerInterface { 67 using DialectInlinerInterface::DialectInlinerInterface; 68 69 //===--------------------------------------------------------------------===// 70 // Analysis Hooks 71 //===--------------------------------------------------------------------===// 72 73 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 74 // Inlining into test dialect regions is legal. 75 return true; 76 } 77 bool isLegalToInline(Operation *, Region *, 78 BlockAndValueMapping &) const final { 79 return true; 80 } 81 82 bool shouldAnalyzeRecursively(Operation *op) const final { 83 // Analyze recursively if this is not a functional region operation, it 84 // froms a separate functional scope. 85 return !isa<FunctionalRegionOp>(op); 86 } 87 88 //===--------------------------------------------------------------------===// 89 // Transformation Hooks 90 //===--------------------------------------------------------------------===// 91 92 /// Handle the given inlined terminator by replacing it with a new operation 93 /// as necessary. 94 void handleTerminator(Operation *op, 95 ArrayRef<Value> valuesToRepl) const final { 96 // Only handle "test.return" here. 97 auto returnOp = dyn_cast<TestReturnOp>(op); 98 if (!returnOp) 99 return; 100 101 // Replace the values directly with the return operands. 102 assert(returnOp.getNumOperands() == valuesToRepl.size()); 103 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 104 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 105 } 106 107 /// Attempt to materialize a conversion for a type mismatch between a call 108 /// from this dialect, and a callable region. This method should generate an 109 /// operation that takes 'input' as the only operand, and produces a single 110 /// result of 'resultType'. If a conversion can not be generated, nullptr 111 /// should be returned. 112 Operation *materializeCallConversion(OpBuilder &builder, Value input, 113 Type resultType, 114 Location conversionLoc) const final { 115 // Only allow conversion for i16/i32 types. 116 if (!(resultType.isSignlessInteger(16) || 117 resultType.isSignlessInteger(32)) || 118 !(input.getType().isSignlessInteger(16) || 119 input.getType().isSignlessInteger(32))) 120 return nullptr; 121 return builder.create<TestCastOp>(conversionLoc, resultType, input); 122 } 123 }; 124 } // end anonymous namespace 125 126 //===----------------------------------------------------------------------===// 127 // TestDialect 128 //===----------------------------------------------------------------------===// 129 130 TestDialect::TestDialect(MLIRContext *context) 131 : Dialect(getDialectNamespace(), context) { 132 addOperations< 133 #define GET_OP_LIST 134 #include "TestOps.cpp.inc" 135 >(); 136 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 137 TestInlinerInterface>(); 138 allowUnknownOperations(); 139 } 140 141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 142 NamedAttribute namedAttr) { 143 if (namedAttr.first == "test.invalid_attr") 144 return op->emitError() << "invalid to use 'test.invalid_attr'"; 145 return success(); 146 } 147 148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 149 unsigned regionIndex, 150 unsigned argIndex, 151 NamedAttribute namedAttr) { 152 if (namedAttr.first == "test.invalid_attr") 153 return op->emitError() << "invalid to use 'test.invalid_attr'"; 154 return success(); 155 } 156 157 LogicalResult 158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 159 unsigned resultIndex, 160 NamedAttribute namedAttr) { 161 if (namedAttr.first == "test.invalid_attr") 162 return op->emitError() << "invalid to use 'test.invalid_attr'"; 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // TestBranchOp 168 //===----------------------------------------------------------------------===// 169 170 Optional<MutableOperandRange> 171 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 172 assert(index == 0 && "invalid successor index"); 173 return targetOperandsMutable(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Test IsolatedRegionOp - parse passthrough region arguments. 178 //===----------------------------------------------------------------------===// 179 180 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 181 OperationState &result) { 182 OpAsmParser::OperandType argInfo; 183 Type argType = parser.getBuilder().getIndexType(); 184 185 // Parse the input operand. 186 if (parser.parseOperand(argInfo) || 187 parser.resolveOperand(argInfo, argType, result.operands)) 188 return failure(); 189 190 // Parse the body region, and reuse the operand info as the argument info. 191 Region *body = result.addRegion(); 192 return parser.parseRegion(*body, argInfo, argType, 193 /*enableNameShadowing=*/true); 194 } 195 196 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 197 p << "test.isolated_region "; 198 p.printOperand(op.getOperand()); 199 p.shadowRegionArgs(op.region(), op.getOperand()); 200 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // Test PolyhedralScopeOp 205 //===----------------------------------------------------------------------===// 206 207 static ParseResult parsePolyhedralScopeOp(OpAsmParser &parser, 208 OperationState &result) { 209 // Parse the body region, and reuse the operand info as the argument info. 210 Region *body = result.addRegion(); 211 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 212 } 213 214 static void print(OpAsmPrinter &p, PolyhedralScopeOp op) { 215 p << "test.polyhedral_scope "; 216 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Test parser. 221 //===----------------------------------------------------------------------===// 222 223 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 224 OperationState &result) { 225 StringRef keyword; 226 if (parser.parseKeyword(&keyword)) 227 return failure(); 228 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 229 return success(); 230 } 231 232 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 233 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 238 239 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 240 OperationState &result) { 241 if (parser.parseKeyword("wraps")) 242 return failure(); 243 244 // Parse the wrapped op in a region 245 Region &body = *result.addRegion(); 246 body.push_back(new Block); 247 Block &block = body.back(); 248 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 249 if (!wrapped_op) 250 return failure(); 251 252 // Create a return terminator in the inner region, pass as operand to the 253 // terminator the returned values from the wrapped operation. 254 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 255 OpBuilder builder(parser.getBuilder().getContext()); 256 builder.setInsertionPointToEnd(&block); 257 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 258 259 // Get the results type for the wrapping op from the terminator operands. 260 Operation &return_op = body.back().back(); 261 result.types.append(return_op.operand_type_begin(), 262 return_op.operand_type_end()); 263 264 // Use the location of the wrapped op for the "test.wrapping_region" op. 265 result.location = wrapped_op->getLoc(); 266 267 return success(); 268 } 269 270 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 271 p << op.getOperationName() << " wraps "; 272 p.printGenericOp(&op.region().front().front()); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // Test PolyForOp - parse list of region arguments. 277 //===----------------------------------------------------------------------===// 278 279 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 280 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 281 // Parse list of region arguments without a delimiter. 282 if (parser.parseRegionArgumentList(ivsInfo)) 283 return failure(); 284 285 // Parse the body region. 286 Region *body = result.addRegion(); 287 auto &builder = parser.getBuilder(); 288 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 289 return parser.parseRegion(*body, ivsInfo, argTypes); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // Test removing op with inner ops. 294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 struct TestRemoveOpWithInnerOps 298 : public OpRewritePattern<TestOpWithRegionPattern> { 299 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 300 301 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 302 PatternRewriter &rewriter) const override { 303 rewriter.eraseOp(op); 304 return success(); 305 } 306 }; 307 } // end anonymous namespace 308 309 void TestOpWithRegionPattern::getCanonicalizationPatterns( 310 OwningRewritePatternList &results, MLIRContext *context) { 311 results.insert<TestRemoveOpWithInnerOps>(context); 312 } 313 314 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 315 return operand(); 316 } 317 318 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 319 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 320 for (Value input : this->operands()) { 321 results.push_back(input); 322 } 323 return success(); 324 } 325 326 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( 327 MLIRContext *, Optional<Location> location, ValueRange operands, 328 ArrayRef<NamedAttribute> attributes, RegionRange regions, 329 SmallVectorImpl<Type> &inferredReturnTypes) { 330 if (operands[0].getType() != operands[1].getType()) { 331 return emitOptionalError(location, "operand type mismatch ", 332 operands[0].getType(), " vs ", 333 operands[1].getType()); 334 } 335 inferredReturnTypes.assign({operands[0].getType()}); 336 return success(); 337 } 338 339 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 340 MLIRContext *context, Optional<Location> location, ValueRange operands, 341 ArrayRef<NamedAttribute> attributes, RegionRange regions, 342 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 343 // Create return type consisting of the last element of the first operand. 344 auto operandType = *operands.getTypes().begin(); 345 auto sval = operandType.dyn_cast<ShapedType>(); 346 if (!sval) { 347 return emitOptionalError(location, "only shaped type operands allowed"); 348 } 349 int64_t dim = 350 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 351 auto type = IntegerType::get(17, context); 352 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 353 return success(); 354 } 355 356 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 357 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 358 shapes = SmallVector<Value, 1>{ 359 builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; 360 return success(); 361 } 362 363 //===----------------------------------------------------------------------===// 364 // Test SideEffect interfaces 365 //===----------------------------------------------------------------------===// 366 367 namespace { 368 /// A test resource for side effects. 369 struct TestResource : public SideEffects::Resource::Base<TestResource> { 370 StringRef getName() final { return "<Test>"; } 371 }; 372 } // end anonymous namespace 373 374 void SideEffectOp::getEffects( 375 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 376 // Check for an effects attribute on the op instance. 377 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 378 if (!effectsAttr) 379 return; 380 381 // If there is one, it is an array of dictionary attributes that hold 382 // information on the effects of this operation. 383 for (Attribute element : effectsAttr) { 384 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 385 386 // Get the specific memory effect. 387 MemoryEffects::Effect *effect = 388 llvm::StringSwitch<MemoryEffects::Effect *>( 389 effectElement.get("effect").cast<StringAttr>().getValue()) 390 .Case("allocate", MemoryEffects::Allocate::get()) 391 .Case("free", MemoryEffects::Free::get()) 392 .Case("read", MemoryEffects::Read::get()) 393 .Case("write", MemoryEffects::Write::get()); 394 395 // Check for a result to affect. 396 Value value; 397 if (effectElement.get("on_result")) 398 value = getResult(); 399 400 // Check for a non-default resource to use. 401 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 402 if (effectElement.get("test_resource")) 403 resource = TestResource::get(); 404 405 effects.emplace_back(effect, value, resource); 406 } 407 } 408 409 //===----------------------------------------------------------------------===// 410 // StringAttrPrettyNameOp 411 //===----------------------------------------------------------------------===// 412 413 // This op has fancy handling of its SSA result name. 414 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 415 OperationState &result) { 416 // Add the result types. 417 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 418 result.addTypes(parser.getBuilder().getIntegerType(32)); 419 420 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 421 return failure(); 422 423 // If the attribute dictionary contains no 'names' attribute, infer it from 424 // the SSA name (if specified). 425 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 426 return attr.first == "names"; 427 }); 428 429 // If there was no name specified, check to see if there was a useful name 430 // specified in the asm file. 431 if (hadNames || parser.getNumResults() == 0) 432 return success(); 433 434 SmallVector<StringRef, 4> names; 435 auto *context = result.getContext(); 436 437 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 438 auto resultName = parser.getResultName(i); 439 StringRef nameStr; 440 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 441 nameStr = resultName.first; 442 443 names.push_back(nameStr); 444 } 445 446 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 447 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 448 return success(); 449 } 450 451 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 452 p << "test.string_attr_pretty_name"; 453 454 // Note that we only need to print the "name" attribute if the asmprinter 455 // result name disagrees with it. This can happen in strange cases, e.g. 456 // when there are conflicts. 457 bool namesDisagree = op.names().size() != op.getNumResults(); 458 459 SmallString<32> resultNameStr; 460 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 461 resultNameStr.clear(); 462 llvm::raw_svector_ostream tmpStream(resultNameStr); 463 p.printOperand(op.getResult(i), tmpStream); 464 465 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 466 if (!expectedName || 467 tmpStream.str().drop_front() != expectedName.getValue()) { 468 namesDisagree = true; 469 } 470 } 471 472 if (namesDisagree) 473 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 474 else 475 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 476 } 477 478 // We set the SSA name in the asm syntax to the contents of the name 479 // attribute. 480 void StringAttrPrettyNameOp::getAsmResultNames( 481 function_ref<void(Value, StringRef)> setNameFn) { 482 483 auto value = names(); 484 for (size_t i = 0, e = value.size(); i != e; ++i) 485 if (auto str = value[i].dyn_cast<StringAttr>()) 486 if (!str.getValue().empty()) 487 setNameFn(getResult(i), str.getValue()); 488 } 489 490 //===----------------------------------------------------------------------===// 491 // Dialect Registration 492 //===----------------------------------------------------------------------===// 493 494 // Static initialization for Test dialect registration. 495 static mlir::DialectRegistration<mlir::TestDialect> testDialect; 496 497 #include "TestOpEnums.cpp.inc" 498 #include "TestOpStructs.cpp.inc" 499 500 #define GET_OP_CLASSES 501 #include "TestOps.cpp.inc" 502