1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Function.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/TypeUtilities.h" 15 #include "mlir/Transforms/FoldUtils.h" 16 #include "mlir/Transforms/InliningUtils.h" 17 #include "llvm/ADT/StringSwitch.h" 18 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // TestDialect Interfaces 23 //===----------------------------------------------------------------------===// 24 25 namespace { 26 27 // Test support for interacting with the AsmPrinter. 28 struct TestOpAsmInterface : public OpAsmDialectInterface { 29 using OpAsmDialectInterface::OpAsmDialectInterface; 30 31 void getAsmResultNames(Operation *op, 32 OpAsmSetValueNameFn setNameFn) const final { 33 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 34 setNameFn(asmOp, "result"); 35 } 36 37 void getAsmBlockArgumentNames(Block *block, 38 OpAsmSetValueNameFn setNameFn) const final { 39 auto op = block->getParentOp(); 40 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 41 if (!arrayAttr) 42 return; 43 auto args = block->getArguments(); 44 auto e = std::min(arrayAttr.size(), args.size()); 45 for (unsigned i = 0; i < e; ++i) { 46 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 47 setNameFn(args[i], strAttr.getValue()); 48 } 49 } 50 }; 51 52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 53 using OpFolderDialectInterface::OpFolderDialectInterface; 54 55 /// Registered hook to check if the given region, which is attached to an 56 /// operation that is *not* isolated from above, should be used when 57 /// materializing constants. 58 bool shouldMaterializeInto(Region *region) const final { 59 // If this is a one region operation, then insert into it. 60 return isa<OneRegionOp>(region->getParentOp()); 61 } 62 }; 63 64 /// This class defines the interface for handling inlining with standard 65 /// operations. 66 struct TestInlinerInterface : public DialectInlinerInterface { 67 using DialectInlinerInterface::DialectInlinerInterface; 68 69 //===--------------------------------------------------------------------===// 70 // Analysis Hooks 71 //===--------------------------------------------------------------------===// 72 73 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 74 // Inlining into test dialect regions is legal. 75 return true; 76 } 77 bool isLegalToInline(Operation *, Region *, 78 BlockAndValueMapping &) const final { 79 return true; 80 } 81 82 bool shouldAnalyzeRecursively(Operation *op) const final { 83 // Analyze recursively if this is not a functional region operation, it 84 // froms a separate functional scope. 85 return !isa<FunctionalRegionOp>(op); 86 } 87 88 //===--------------------------------------------------------------------===// 89 // Transformation Hooks 90 //===--------------------------------------------------------------------===// 91 92 /// Handle the given inlined terminator by replacing it with a new operation 93 /// as necessary. 94 void handleTerminator(Operation *op, 95 ArrayRef<Value> valuesToRepl) const final { 96 // Only handle "test.return" here. 97 auto returnOp = dyn_cast<TestReturnOp>(op); 98 if (!returnOp) 99 return; 100 101 // Replace the values directly with the return operands. 102 assert(returnOp.getNumOperands() == valuesToRepl.size()); 103 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 104 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 105 } 106 107 /// Attempt to materialize a conversion for a type mismatch between a call 108 /// from this dialect, and a callable region. This method should generate an 109 /// operation that takes 'input' as the only operand, and produces a single 110 /// result of 'resultType'. If a conversion can not be generated, nullptr 111 /// should be returned. 112 Operation *materializeCallConversion(OpBuilder &builder, Value input, 113 Type resultType, 114 Location conversionLoc) const final { 115 // Only allow conversion for i16/i32 types. 116 if (!(resultType.isSignlessInteger(16) || 117 resultType.isSignlessInteger(32)) || 118 !(input.getType().isSignlessInteger(16) || 119 input.getType().isSignlessInteger(32))) 120 return nullptr; 121 return builder.create<TestCastOp>(conversionLoc, resultType, input); 122 } 123 }; 124 } // end anonymous namespace 125 126 //===----------------------------------------------------------------------===// 127 // TestDialect 128 //===----------------------------------------------------------------------===// 129 130 TestDialect::TestDialect(MLIRContext *context) 131 : Dialect(getDialectNamespace(), context) { 132 addOperations< 133 #define GET_OP_LIST 134 #include "TestOps.cpp.inc" 135 >(); 136 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 137 TestInlinerInterface>(); 138 allowUnknownOperations(); 139 } 140 141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 142 NamedAttribute namedAttr) { 143 if (namedAttr.first == "test.invalid_attr") 144 return op->emitError() << "invalid to use 'test.invalid_attr'"; 145 return success(); 146 } 147 148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 149 unsigned regionIndex, 150 unsigned argIndex, 151 NamedAttribute namedAttr) { 152 if (namedAttr.first == "test.invalid_attr") 153 return op->emitError() << "invalid to use 'test.invalid_attr'"; 154 return success(); 155 } 156 157 LogicalResult 158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 159 unsigned resultIndex, 160 NamedAttribute namedAttr) { 161 if (namedAttr.first == "test.invalid_attr") 162 return op->emitError() << "invalid to use 'test.invalid_attr'"; 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // TestBranchOp 168 //===----------------------------------------------------------------------===// 169 170 Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) { 171 assert(index == 0 && "invalid successor index"); 172 return getOperands(); 173 } 174 175 bool TestBranchOp::canEraseSuccessorOperand() { return true; } 176 177 //===----------------------------------------------------------------------===// 178 // Test IsolatedRegionOp - parse passthrough region arguments. 179 //===----------------------------------------------------------------------===// 180 181 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 182 OperationState &result) { 183 OpAsmParser::OperandType argInfo; 184 Type argType = parser.getBuilder().getIndexType(); 185 186 // Parse the input operand. 187 if (parser.parseOperand(argInfo) || 188 parser.resolveOperand(argInfo, argType, result.operands)) 189 return failure(); 190 191 // Parse the body region, and reuse the operand info as the argument info. 192 Region *body = result.addRegion(); 193 return parser.parseRegion(*body, argInfo, argType, 194 /*enableNameShadowing=*/true); 195 } 196 197 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 198 p << "test.isolated_region "; 199 p.printOperand(op.getOperand()); 200 p.shadowRegionArgs(op.region(), op.getOperand()); 201 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // Test parser. 206 //===----------------------------------------------------------------------===// 207 208 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 209 OperationState &result) { 210 StringRef keyword; 211 if (parser.parseKeyword(&keyword)) 212 return failure(); 213 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 214 return success(); 215 } 216 217 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 218 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 223 224 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 225 OperationState &result) { 226 if (parser.parseKeyword("wraps")) 227 return failure(); 228 229 // Parse the wrapped op in a region 230 Region &body = *result.addRegion(); 231 body.push_back(new Block); 232 Block &block = body.back(); 233 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 234 if (!wrapped_op) 235 return failure(); 236 237 // Create a return terminator in the inner region, pass as operand to the 238 // terminator the returned values from the wrapped operation. 239 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 240 OpBuilder builder(parser.getBuilder().getContext()); 241 builder.setInsertionPointToEnd(&block); 242 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 243 244 // Get the results type for the wrapping op from the terminator operands. 245 Operation &return_op = body.back().back(); 246 result.types.append(return_op.operand_type_begin(), 247 return_op.operand_type_end()); 248 249 // Use the location of the wrapped op for the "test.wrapping_region" op. 250 result.location = wrapped_op->getLoc(); 251 252 return success(); 253 } 254 255 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 256 p << op.getOperationName() << " wraps "; 257 p.printGenericOp(&op.region().front().front()); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // Test PolyForOp - parse list of region arguments. 262 //===----------------------------------------------------------------------===// 263 264 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 265 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 266 // Parse list of region arguments without a delimiter. 267 if (parser.parseRegionArgumentList(ivsInfo)) 268 return failure(); 269 270 // Parse the body region. 271 Region *body = result.addRegion(); 272 auto &builder = parser.getBuilder(); 273 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 274 return parser.parseRegion(*body, ivsInfo, argTypes); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // Test removing op with inner ops. 279 //===----------------------------------------------------------------------===// 280 281 namespace { 282 struct TestRemoveOpWithInnerOps 283 : public OpRewritePattern<TestOpWithRegionPattern> { 284 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 285 286 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 287 PatternRewriter &rewriter) const override { 288 rewriter.eraseOp(op); 289 return success(); 290 } 291 }; 292 } // end anonymous namespace 293 294 void TestOpWithRegionPattern::getCanonicalizationPatterns( 295 OwningRewritePatternList &results, MLIRContext *context) { 296 results.insert<TestRemoveOpWithInnerOps>(context); 297 } 298 299 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 300 return operand(); 301 } 302 303 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 304 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 305 for (Value input : this->operands()) { 306 results.push_back(input); 307 } 308 return success(); 309 } 310 311 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( 312 MLIRContext *, Optional<Location> location, ValueRange operands, 313 ArrayRef<NamedAttribute> attributes, RegionRange regions, 314 SmallVectorImpl<Type> &inferredReturnTypes) { 315 if (operands[0].getType() != operands[1].getType()) { 316 return emitOptionalError(location, "operand type mismatch ", 317 operands[0].getType(), " vs ", 318 operands[1].getType()); 319 } 320 inferredReturnTypes.assign({operands[0].getType()}); 321 return success(); 322 } 323 324 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 325 MLIRContext *context, Optional<Location> location, ValueRange operands, 326 ArrayRef<NamedAttribute> attributes, RegionRange regions, 327 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 328 // Create return type consisting of the last element of the first operand. 329 auto operandType = *operands.getTypes().begin(); 330 auto sval = operandType.dyn_cast<ShapedType>(); 331 if (!sval) { 332 return emitOptionalError(location, "only shaped type operands allowed"); 333 } 334 int64_t dim = 335 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 336 auto type = IntegerType::get(17, context); 337 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 338 return success(); 339 } 340 341 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 342 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 343 shapes = SmallVector<Value, 1>{ 344 builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; 345 return success(); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // Test SideEffect interfaces 350 //===----------------------------------------------------------------------===// 351 352 namespace { 353 /// A test resource for side effects. 354 struct TestResource : public SideEffects::Resource::Base<TestResource> { 355 StringRef getName() final { return "<Test>"; } 356 }; 357 } // end anonymous namespace 358 359 void SideEffectOp::getEffects( 360 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 361 // Check for an effects attribute on the op instance. 362 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 363 if (!effectsAttr) 364 return; 365 366 // If there is one, it is an array of dictionary attributes that hold 367 // information on the effects of this operation. 368 for (Attribute element : effectsAttr) { 369 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 370 371 // Get the specific memory effect. 372 MemoryEffects::Effect *effect = 373 llvm::StringSwitch<MemoryEffects::Effect *>( 374 effectElement.get("effect").cast<StringAttr>().getValue()) 375 .Case("allocate", MemoryEffects::Allocate::get()) 376 .Case("free", MemoryEffects::Free::get()) 377 .Case("read", MemoryEffects::Read::get()) 378 .Case("write", MemoryEffects::Write::get()); 379 380 // Check for a result to affect. 381 Value value; 382 if (effectElement.get("on_result")) 383 value = getResult(); 384 385 // Check for a non-default resource to use. 386 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 387 if (effectElement.get("test_resource")) 388 resource = TestResource::get(); 389 390 effects.emplace_back(effect, value, resource); 391 } 392 } 393 394 //===----------------------------------------------------------------------===// 395 // StringAttrPrettyNameOp 396 //===----------------------------------------------------------------------===// 397 398 // This op has fancy handling of its SSA result name. 399 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 400 OperationState &result) { 401 // Add the result types. 402 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 403 result.addTypes(parser.getBuilder().getIntegerType(32)); 404 405 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 406 return failure(); 407 408 // If the attribute dictionary contains no 'names' attribute, infer it from 409 // the SSA name (if specified). 410 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 411 return attr.first == "names"; 412 }); 413 414 // If there was no name specified, check to see if there was a useful name 415 // specified in the asm file. 416 if (hadNames || parser.getNumResults() == 0) 417 return success(); 418 419 SmallVector<StringRef, 4> names; 420 auto *context = result.getContext(); 421 422 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 423 auto resultName = parser.getResultName(i); 424 StringRef nameStr; 425 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 426 nameStr = resultName.first; 427 428 names.push_back(nameStr); 429 } 430 431 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 432 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 433 return success(); 434 } 435 436 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 437 p << "test.string_attr_pretty_name"; 438 439 // Note that we only need to print the "name" attribute if the asmprinter 440 // result name disagrees with it. This can happen in strange cases, e.g. 441 // when there are conflicts. 442 bool namesDisagree = op.names().size() != op.getNumResults(); 443 444 SmallString<32> resultNameStr; 445 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 446 resultNameStr.clear(); 447 llvm::raw_svector_ostream tmpStream(resultNameStr); 448 p.printOperand(op.getResult(i), tmpStream); 449 450 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 451 if (!expectedName || 452 tmpStream.str().drop_front() != expectedName.getValue()) { 453 namesDisagree = true; 454 } 455 } 456 457 if (namesDisagree) 458 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 459 else 460 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 461 } 462 463 // We set the SSA name in the asm syntax to the contents of the name 464 // attribute. 465 void StringAttrPrettyNameOp::getAsmResultNames( 466 function_ref<void(Value, StringRef)> setNameFn) { 467 468 auto value = names(); 469 for (size_t i = 0, e = value.size(); i != e; ++i) 470 if (auto str = value[i].dyn_cast<StringAttr>()) 471 if (!str.getValue().empty()) 472 setNameFn(getResult(i), str.getValue()); 473 } 474 475 //===----------------------------------------------------------------------===// 476 // Dialect Registration 477 //===----------------------------------------------------------------------===// 478 479 // Static initialization for Test dialect registration. 480 static mlir::DialectRegistration<mlir::TestDialect> testDialect; 481 482 #include "TestOpEnums.cpp.inc" 483 484 #define GET_OP_CLASSES 485 #include "TestOps.cpp.inc" 486