//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// Returns true if `op` carries the unit attribute that marks GPU kernel
/// functions (the attribute named by getKernelFuncAttrName()).
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

/// Registers all GPU dialect operations generated from the ODS definitions.
void GPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

/// Verifies the dialect attribute `attr` attached to operation `op`. Only the
/// container-module unit attribute is checked here: when present on a module,
/// every `gpu.launch_func` directly nested in that module's functions must
/// reference a well-formed kernel module and kernel function, and (for GPU
/// dialect kernels) have operand types matching the kernel signature.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Attributes other than the container-module marker are not ours to check.
  if (!attr.second.isa<UnitAttr>() ||
      attr.first != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    // The kernel may exist either as a gpu.func or as an already-lowered
    // llvm.func.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: if the kernel function has been converted to
    // the LLVM dialect but the caller hasn't (which happens during the
    // separate compilation), do not check type correspondence as it would
    // require the verifier to be aware of the LLVM type conversion.
    if (kernelLLVMFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Per-operand type check against the kernel's function type.
    auto functionType = kernelGPUFunction.getType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  // The walk is interrupted as soon as one launch fails verification.
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Shared verifier for index ops (block/thread id and dim ops): the dimension
/// attribute must be one of "x", "y" or "z".
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

/// Verifies a gpu.all_reduce: it must have exactly one of an `op` attribute or
/// a non-empty reduction body. A body takes two arguments of the result type
/// and must contain a gpu.yield of a single value of that type. The named
/// "and"/"or"/"xor" accumulators only apply to integer types.
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    if (allReduce.body().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    // Count gpu.yield terminators; at least one block must yield the result.
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    StringRef opName = *allReduce.op();
    if ((opName == "and" || opName == "or" || opName == "xor") &&
        !allReduce.getType().isa<IntegerType>()) {
      return allReduce.emitError()
             << '`' << opName << '`'
             << " accumulator is only compatible with Integer type";
    }
  }
  return success();
}

/// Verifies a gpu.shuffle: value operand and result must share the same type,
/// which must be a 32-bit signless integer or float.
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

/// Prints a gpu.shuffle as: `gpu.shuffle` operands mode `:` value-type.
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

/// Parses a gpu.shuffle. The three operands are the value, offset and width;
/// offset and width are implicitly i32, and the op also produces an i1 "valid"
/// result in addition to the shuffled value.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.launch with the six grid/block size operands and an empty
/// kernel body region pre-populated with the config region arguments.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder.getIndexType()));
  kernelRegion->push_back(body);
}

// The first twelve region arguments are, in order: block ids (0-2), thread ids
// (3-5), grid sizes (6-8) and block sizes (9-11); the accessors below slice
// them accordingly.

/// Returns the block id region arguments (arguments 0-2).
KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

/// Returns the thread id region arguments (arguments 3-5).
KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

/// Returns the grid size region arguments (arguments 6-8).
KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

/// Returns the block size region arguments (arguments 9-11).
KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

/// Returns the SSA operands supplying the grid size (operands 0-2).
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA operands supplying the block size (operands 3-5).
KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

/// Verifies a gpu.launch: region argument count must match the config, and
/// every exiting block must end in gpu.terminator.
static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    if (op.body().getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(op.getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ValueRange operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands[0] << ", ";
  p << size.y << " = " << operands[1] << ", ";
  p << size.z << " = " << operands[2] << ')';
}

/// Prints a gpu.launch: the `blocks` and `threads` size-assignment segments,
/// followed by the body region and the attribute dictionary.
static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  ValueRange operands = op.getOperands();

  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
                      op.getThreadIds());

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  // Parse the `%region_arg = %operand` triples.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                             region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
      LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.launch_func targeting `kernelFunc`. The kernel is referenced
/// by a nested symbol (kernel-module :: kernel-function); grid and block size
/// operands come first, followed by the kernel's data operands.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY,
                         Value gridSizeZ, Value blockSizeX, Value blockSizeY,
                         Value blockSizeZ, ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>();
  auto kernelSymbol = builder.getSymbolRefAttr(
      kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())});
  result.addAttribute(getKernelAttrName(), kernelSymbol);
}

/// Convenience overload taking KernelDim3 triples for grid and block sizes.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, ValueRange kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

/// Returns the nested symbol reference identifying the kernel.
SymbolRefAttr LaunchFuncOp::kernel() {
  return getAttrOfType<SymbolRefAttr>(getKernelAttrName());
}

/// Returns the number of operands passed to the kernel (all operands after
/// the six launch-configuration operands).
unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

/// Returns the name of the kernel's containing module (root of the nested
/// kernel symbol).
StringRef LaunchFuncOp::getKernelModuleName() {
  return kernel().getRootReference();
}

/// Returns the name of the kernel function (leaf of the nested kernel symbol).
StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

/// Returns the i-th kernel data operand (skipping the config operands).
Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

/// Returns the SSA operands supplying the grid size (operands 0-2).
KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA operands supplying the block size (operands 3-5).
KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

/// Verifies a gpu.launch_func: it must live in a module carrying the
/// container-module attribute and must specify the kernel symbol attribute.
/// Cross-module kernel resolution is checked by
/// GPUDialect::verifyOperationAttribute.
static LogicalResult verify(LaunchFuncOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("expected to belong to a module");

  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return op.emitOpError(
        "expected the closest surrounding module to have the '" +
        GPUDialect::getContainerModuleAttrName() + "' attribute");

  auto kernelAttr = op.getAttrOfType<SymbolRefAttr>(op.getKernelAttrName());
  if (!kernelAttr)
    return op.emitOpError("symbol reference attribute '" +
                          op.getKernelAttrName() + "' must be specified");

  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
  // Bump the attribution count attribute, then insert the new argument right
  // after the existing workgroup attributions (function arguments come first,
  // then workgroup attributions, then private attributions).
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = getAttrOfType<IntegerAttr>(attrName);
  setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
                                  type);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type);
}

/// Builds a gpu.func with the given signature and memory attributions. The
/// entry block receives one argument per function input, followed by the
/// workgroup and then the private attribution buffers.
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      ArrayRef<Type> workgroupAttributions,
                      ArrayRef<Type> privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;
  entryBlock->addArguments(type.getInputs());
  entryBlock->addArguments(workgroupAttributions);
  entryBlock->addArguments(privateAttributions);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::OperandType arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
  SmallVector<NamedAttrList, 1> argAttrs;
  SmallVector<NamedAttrList, 1> resultAttrs;
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 4> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

/// Prints one attribution list (`workgroup(...)` or `private(...)`) of a GPU
/// function; prints nothing when the list is empty.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Prints a GPU Func op.
611 static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { 612 p << GPUFuncOp::getOperationName() << ' '; 613 p.printSymbolName(op.getName()); 614 615 FunctionType type = op.getType(); 616 impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), 617 /*isVariadic=*/false, type.getResults()); 618 619 printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); 620 printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); 621 if (op.isKernel()) 622 p << ' ' << op.getKernelKeyword(); 623 624 impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), 625 type.getNumResults(), 626 {op.getNumWorkgroupAttributionsAttrName(), 627 GPUDialect::getKernelFuncAttrName()}); 628 p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); 629 } 630 631 void GPUFuncOp::setType(FunctionType newType) { 632 auto oldType = getType(); 633 assert(newType.getNumResults() == oldType.getNumResults() && 634 "unimplemented: changes to the number of results"); 635 636 SmallVector<char, 16> nameBuf; 637 for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) 638 removeAttr(getArgAttrName(i, nameBuf)); 639 640 setAttr(getTypeAttrName(), TypeAttr::get(newType)); 641 } 642 643 /// Hook for FunctionLike verifier. 
/// Hook for the FunctionLike verifier: the type attribute must hold a
/// FunctionType, and kernel functions must not return values.
LogicalResult GPUFuncOp::verifyType() {
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  if (isKernel() && getType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies that every attribution value is a memref in the given memory
/// space.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function: the entry block must carry at least one
/// argument per function input plus workgroup attribution, the leading block
/// arguments must match the function input types, and the attributions must
/// live in the expected memory spaces.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

/// Parses a gpu.return with an optional list of typed operands.
static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
  llvm::SmallVector<OpAsmParser::OperandType, 4> operands;
  llvm::SmallVector<Type, 4> types;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalColonTypeList(types) ||
      parser.resolveOperands(operands, types, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  return success();
}

/// Verifies a gpu.return: the operand count and types must match the results
/// of the enclosing gpu.func.
static LogicalResult verify(gpu::ReturnOp returnOp) {
  GPUFuncOp function = returnOp.getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getType();

  if (funType.getNumResults() != returnOp.operands().size())
    return returnOp.emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  // Check each operand type against the declared result type.
  for (auto pair : llvm::enumerate(
           llvm::zip(function.getType().getResults(), returnOp.operands()))) {
    Type type;
    Value operand;
    std::tie(type, operand) = pair.value();
    if (type != operand.getType())
      return returnOp.emitOpError() << "unexpected type `" << operand.getType()
                                    << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.module with the given symbol name and an empty, properly
/// terminated body region.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

/// Parses a gpu.module: symbol name, optional attribute dictionary, body
/// region.
static ParseResult parseGPUModuleOp(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

/// Prints a gpu.module, eliding the symbol-name attribute (printed inline) and
/// the implicit block terminator.
static void print(OpAsmPrinter &p, GPUModuleOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     {SymbolTable::getSymbolAttrName()});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"