//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

// Returns the dialect namespace, i.e. the "gpu" prefix of all operations in
// this dialect.
StringRef GPUDialect::getDialectName() { return "gpu"; }

// Returns true if `op` carries the unit attribute that marks it as a GPU
// kernel entry point.
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

// Registers the operations of this dialect: the hand-written LaunchOp and
// LaunchFuncOp plus all ODS-generated operations pulled in from GPUOps.cpp.inc.
GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<LaunchOp, LaunchFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
                >();
}

// Verifies GPU-dialect attributes attached to operations of other dialects.
// Only the container-module unit attribute is checked here: when a module
// carries it, every `gpu.launch_func` nested directly inside the module's
// functions must refer to a kernel module and kernel function defined within
// this module.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Any dialect attribute other than the container-module marker is accepted
  // without further checks.
  if (!attr.second.isa<UnitAttr>() ||
      !attr.first.is(getContainerModuleAttrName()))
    return success();

  // The marker only makes sense on a module.
  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<StringAttr>(
            LaunchFuncOp::getKernelAttrName()) ||
        !launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelModuleAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";
    if (!kernelModule.getAttrOfType<UnitAttr>(
            GPUDialect::getKernelModuleAttrName()))
      return launchOp.emitOpError("module '")
           << kernelModuleName << "' is missing the '"
           << GPUDialect::getKernelModuleAttrName() << "' attribute";

    // Check that `launch_func` refers to a well-formed kernel function. The
    // kernel may be either a standard FuncOp or an LLVM function, depending on
    // how far lowering has progressed.
    StringRef kernelName = launchOp.kernel();
    Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
    auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelStdFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << kernelName << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // The launch must pass exactly as many operands as the kernel declares
    // parameters, regardless of which function form the kernel currently has.
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelLLVMFunction
                                        ? kernelLLVMFunction.getNumArguments()
                                        : kernelStdFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Due to the ordering of the current impl of lowering and LLVMLowering,
    // type checks need to be temporarily disabled.
    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
    // to encode target module" has landed.
    // auto functionType = kernelFunc.getType();
    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    //   if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
    //     return emitOpError("type of function argument ")
    //            << i << " does not match";
    //   }
    // }

    return success();
  });

  // The walk is interrupted as soon as the first launch fails to verify;
  // propagate that as a failure of the attribute verification.
  return walkResult.wasInterrupted() ? failure() : success();
}

// Common verifier for index operations (block/thread id and dim ops): the
// `dimension` attribute must name one of the three spatial dimensions.
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

// Verifies a gpu.all_reduce operation: the reduction must be specified either
// by the `op` attribute or by a non-empty region (exactly one of the two).
// When a region is present, it must take two arguments of the result type and
// every gpu.yield in it must yield a single value of that type.
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    if (allReduce.body().front().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto *argument : allReduce.body().front().getArguments()) {
      if (argument->getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    // Count the yields so that a region without any gpu.yield terminator can
    // be diagnosed explicitly.
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0)->getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  }
  return success();
}

// Namespace avoids ambiguous ReturnOpOperandAdaptor.
168 namespace mlir { 169 namespace gpu { 170 #define GET_OP_CLASSES 171 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 172 } // namespace gpu 173 } // namespace mlir 174 175 //===----------------------------------------------------------------------===// 176 // LaunchOp 177 //===----------------------------------------------------------------------===// 178 179 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) { 180 SmallVector<Type, 4> types; 181 types.reserve(values.size()); 182 for (Value *v : values) 183 types.push_back(v->getType()); 184 return types; 185 } 186 187 void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX, 188 Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, 189 Value *blockSizeY, Value *blockSizeZ, 190 ArrayRef<Value *> operands) { 191 // Add grid and block sizes as op operands, followed by the data operands. 192 result.addOperands( 193 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 194 result.addOperands(operands); 195 196 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 197 // where the first kNumConfigRegionAttributes arguments have `index` type and 198 // the rest have the same types as the data operands. 
199 Region *kernelRegion = result.addRegion(); 200 Block *body = new Block(); 201 body->addArguments( 202 std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType())); 203 body->addArguments(getValueTypes(operands)); 204 kernelRegion->push_back(body); 205 } 206 207 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); } 208 209 KernelDim3 LaunchOp::getBlockIds() { 210 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 211 auto args = getBody().getBlocks().front().getArguments(); 212 return KernelDim3{args[0], args[1], args[2]}; 213 } 214 215 KernelDim3 LaunchOp::getThreadIds() { 216 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 217 auto args = getBody().getBlocks().front().getArguments(); 218 return KernelDim3{args[3], args[4], args[5]}; 219 } 220 221 KernelDim3 LaunchOp::getGridSize() { 222 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 223 auto args = getBody().getBlocks().front().getArguments(); 224 return KernelDim3{args[6], args[7], args[8]}; 225 } 226 227 KernelDim3 LaunchOp::getBlockSize() { 228 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 229 auto args = getBody().getBlocks().front().getArguments(); 230 return KernelDim3{args[9], args[10], args[11]}; 231 } 232 233 LaunchOp::operand_range LaunchOp::getKernelOperandValues() { 234 return llvm::drop_begin(getOperands(), kNumConfigOperands); 235 } 236 237 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() { 238 return llvm::drop_begin(getOperandTypes(), kNumConfigOperands); 239 } 240 241 KernelDim3 LaunchOp::getGridSizeOperandValues() { 242 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 243 } 244 245 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 246 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 247 } 248 249 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { 250 auto args = 
getBody().getBlocks().front().getArguments(); 251 return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes); 252 } 253 254 LogicalResult LaunchOp::verify() { 255 // Kernel launch takes kNumConfigOperands leading operands for grid/block 256 // sizes and transforms them into kNumConfigRegionAttributes region arguments 257 // for block/thread identifiers and grid/block sizes. 258 if (!getBody().empty()) { 259 Block &entryBlock = getBody().front(); 260 if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands()) 261 return emitOpError("unexpected number of region arguments"); 262 } 263 264 // Block terminators without successors are expected to exit the kernel region 265 // and must be `gpu.launch`. 266 for (Block &block : getBody()) { 267 if (block.empty()) 268 continue; 269 if (block.back().getNumSuccessors() != 0) 270 continue; 271 if (!isa<gpu::ReturnOp>(&block.back())) { 272 return block.back() 273 .emitError("expected 'gpu.terminator' or a terminator with " 274 "successors") 275 .attachNote(getLoc()) 276 << "in '" << getOperationName() << "' body region"; 277 } 278 } 279 280 return success(); 281 } 282 283 // Pretty-print the kernel grid/block size assignment as 284 // (%iter-x, %iter-y, %iter-z) in 285 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 286 // where %size-* and %iter-* will correspond to the body region arguments. 287 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, 288 ArrayRef<Value *> operands, KernelDim3 ids) { 289 p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in ("; 290 p << *size.x << " = " << *operands[0] << ", "; 291 p << *size.y << " = " << *operands[1] << ", "; 292 p << *size.z << " = " << *operands[2] << ')'; 293 } 294 295 void LaunchOp::print(OpAsmPrinter &p) { 296 SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end()); 297 ArrayRef<Value *> operands(operandContainer); 298 299 // Print the launch configuration. 
300 p << getOperationName() << ' ' << getBlocksKeyword(); 301 printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds()); 302 p << ' ' << getThreadsKeyword(); 303 printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds()); 304 305 // From now on, the first kNumConfigOperands operands corresponding to grid 306 // and block sizes are irrelevant, so we can drop them. 307 operands = operands.drop_front(kNumConfigOperands); 308 309 // Print the data argument remapping. 310 if (!getBody().empty() && !operands.empty()) { 311 p << ' ' << getArgsKeyword() << '('; 312 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 313 if (i != 0) 314 p << ", "; 315 p << *getBody().front().getArgument(kNumConfigRegionAttributes + i) 316 << " = " << *operands[i]; 317 } 318 p << ") "; 319 } 320 321 // Print the types of data arguments. 322 if (!operands.empty()) { 323 p << ": "; 324 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 325 if (i != 0) 326 p << ", "; 327 p << operands[i]->getType(); 328 } 329 } 330 331 p.printRegion(getBody(), /*printEntryBlockArgs=*/false); 332 p.printOptionalAttrDict(getAttrs()); 333 } 334 335 // Parse the size assignment blocks for blocks and threads. These have the form 336 // (%region_arg, %region_arg, %region_arg) in 337 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 338 // where %region_arg are percent-identifiers for the region arguments to be 339 // introduced further (SSA defs), and %operand are percent-identifiers for the 340 // SSA value uses. 
// Parses one size-assignment segment of a gpu.launch:
//   (%id, %id, %id) in (%size = %operand, %size = %operand, %size = %operand)
// The three identifier region arguments are written into `indices`, the three
// size region arguments into `regionSizes`, and the three SSA operand uses
// into `sizes`. All three output arrays must have room for three entries.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  // Parse the "%region_arg = %operand" triples inside the parentheses.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                            (`args` ssa-reassignment `:` type-list)?
//                            region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
      kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values: block ids occupy region-argument slots 0-2, thread
  // ids slots 3-5, grid sizes slots 6-8, and block sizes slots 9-11.
  if (parser.parseKeyword(getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If kernel argument renaming segment is present, parse it. When present,
  // the segment should have at least one element. If this segment is present,
  // so is the trailing type list. Parse it as well and use the parsed types
  // to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    // First "%region_arg = %operand" pair is mandatory after `args (`.
    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    // Subsequent pairs are comma-separated.
    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // Follow the actual kernel arguments.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttributeDict(result.attributes));
}

// Removes the `index`-th kernel (data) argument: erases both the trailing
// region argument and the corresponding op operand. `index` is relative to
// the data arguments, i.e. 0 is the first argument after the config values.
void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = getBody().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    // Clones are inserted at the top of the kernel body; the original
    // insertion point is restored afterwards.
    auto origInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(&launchOp.getBody().front());

    // Traverse operands passed to kernel and check if some of them are known
    // constants. If so, clone the constant operation inside the kernel region
    // and use it instead of passing the value from the parent region. Perform
    // the traversal in the inverse order to simplify index arithmetics when
    // dropping arguments.
    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
                                     launchOp.getKernelOperandValues().end());
    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
                                       launchOp.getKernelArguments().end());
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value *operand = operands[index];
      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
        continue;
      }

      found = true;
      Value *internalConstant =
          rewriter.clone(*operand->getDefiningOp())->getResult(0);
      Value *kernelArg = kernelArgs[index];
      kernelArg->replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }
    rewriter.restoreInsertionPoint(origInsertionPoint);

    // Report failure when nothing changed so the driver stops reapplying the
    // pattern.
    if (!found)
      return matchFailure();

    rewriter.updatedRootInPlace(launchOp);
    return matchSuccess();
  }
};
} // end namespace

// Registers the canonicalization patterns of gpu.launch.
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===// 502 503 void LaunchFuncOp::build(Builder *builder, OperationState &result, 504 FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, 505 Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, 506 Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { 507 // Add grid and block sizes as op operands, followed by the data operands. 508 result.addOperands( 509 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 510 result.addOperands(kernelOperands); 511 result.addAttribute(getKernelAttrName(), 512 builder->getStringAttr(kernelFunc.getName())); 513 auto kernelModule = kernelFunc.getParentOfType<ModuleOp>(); 514 if (Optional<StringRef> kernelModuleName = kernelModule.getName()) 515 result.addAttribute(getKernelModuleAttrName(), 516 builder->getSymbolRefAttr(*kernelModuleName)); 517 } 518 519 void LaunchFuncOp::build(Builder *builder, OperationState &result, 520 FuncOp kernelFunc, KernelDim3 gridSize, 521 KernelDim3 blockSize, 522 ArrayRef<Value *> kernelOperands) { 523 build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, 524 blockSize.x, blockSize.y, blockSize.z, kernelOperands); 525 } 526 527 StringRef LaunchFuncOp::kernel() { 528 return getAttrOfType<StringAttr>(getKernelAttrName()).getValue(); 529 } 530 531 unsigned LaunchFuncOp::getNumKernelOperands() { 532 return getNumOperands() - kNumConfigOperands; 533 } 534 535 StringRef LaunchFuncOp::getKernelModuleName() { 536 return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue(); 537 } 538 539 Value *LaunchFuncOp::getKernelOperand(unsigned i) { 540 return getOperation()->getOperand(i + kNumConfigOperands); 541 } 542 543 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 544 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 545 } 546 547 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 548 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 549 } 550 551 
LogicalResult LaunchFuncOp::verify() { 552 auto module = getParentOfType<ModuleOp>(); 553 if (!module) 554 return emitOpError("expected to belong to a module"); 555 556 if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName())) 557 return emitOpError("expected the closest surrounding module to have the '" + 558 GPUDialect::getContainerModuleAttrName() + 559 "' attribute"); 560 561 auto kernelAttr = getAttrOfType<StringAttr>(getKernelAttrName()); 562 if (!kernelAttr) 563 return emitOpError("string attribute '" + getKernelAttrName() + 564 "' must be specified"); 565 566 auto kernelModuleAttr = 567 getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()); 568 if (!kernelModuleAttr) 569 return emitOpError("symbol reference attribute '" + 570 getKernelModuleAttrName() + "' must be specified"); 571 572 return success(); 573 } 574