//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// Returns true if `op` carries the unit attribute that marks it as a GPU
/// kernel entry point.
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

/// Registers the dialect's types and the ops generated from GPUOps.td.
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

/// Parses a type registered to this dialect. Currently the only custom type
/// is `gpu.async.token`; anything else is reported as an error.
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

/// Prints a type registered to this dialect. Mirrors parseType above.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

/// Verifies the `gpu.container_module` attribute: when attached to a module,
/// every top-level `gpu.launch_func` inside it must reference a well-formed
/// kernel module/function whose signature matches the launch operands.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Only the container-module unit attribute is verified here; any other
  // attribute is accepted as-is.
  if (!attr.second.isa<UnitAttr>() ||
      attr.first != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: if the kernel function has been converted to
    // the LLVM dialect but the caller hasn't (which happens during the
    // separate compilation), do not check type correspondence as it would
    // require the verifier to be aware of the LLVM type conversion.
    if (kernelLLVMFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Operand-by-operand type check against the kernel's function type.
    auto functionType = kernelGPUFunction.getType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  // A walk interrupted by any emitted error means verification failed.
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Shared verifier for the index ops (block/thread id and dim ops): the
/// `dimension` attribute must be one of "x", "y" or "z".
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

/// Verifies gpu.all_reduce: either an `op` attribute or a non-empty reduction
/// body must be present (exactly one of the two forms).
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    // Region form: two arguments of the result type, and at least one
    // gpu.yield terminator yielding a single value of the result type.
    if (allReduce.body().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    // Attribute form: the bitwise accumulators only make sense on integers.
    StringRef opName = *allReduce.op();
    if ((opName == "and" || opName == "or" || opName == "xor") &&
        !allReduce.getType().isa<IntegerType>()) {
      return allReduce.emitError()
             << '`' << opName << '`'
             << " accumulator is only compatible with Integer type";
    }
  }
  return success();
}

/// Verifies gpu.shuffle: value operand and result must share a 32-bit
/// signless integer or float type.
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

/// Prints gpu.shuffle in the form:
///   gpu.shuffle %value, %offset, %width mode : type
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

/// Parses gpu.shuffle; see printShuffleOp for the expected form. The offset
/// and width operands are implicitly i32, the predicate result is i1.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

/// Prepends `token` to `op`'s operands as a new async dependency, updating
/// the operand-segment-size attribute when the op carries one.
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
  if (!sizeAttr)
    return; // Async dependencies is the only variadic operand.
  // Grow the first (async-dependency) segment by one.
  SmallVector<int32_t, 8> sizes;
  for (auto size : sizeAttr.getIntValues())
    sizes.push_back(size.getSExtValue());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.launch with the given grid/block sizes and an entry block
/// whose leading arguments are the kNumConfigRegionAttributes index values.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder.getIndexType()));
  kernelRegion->push_back(body);
}

// The config region arguments are laid out as: block ids (0-2), thread ids
// (3-5), grid sizes (6-8), block sizes (9-11). The accessors below slice
// that layout.

/// Returns the block-identifier region arguments (args 0..2).
KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

/// Returns the thread-identifier region arguments (args 3..5).
KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

/// Returns the grid-size region arguments (args 6..8).
KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

/// Returns the block-size region arguments (args 9..11).
KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

/// Returns the SSA operands supplying the grid size (operands 0..2).
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA operands supplying the block size (operands 3..5).
KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    if (op.body().getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(op.getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

/// Prints a gpu.launch op; see parseLaunchOp for the grammar.
static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
                      op.getThreadIds());

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                            region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
      LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.launch_func referring to `kernelFunc` through a nested symbol
/// reference rooted at its enclosing gpu.module.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY,
                         Value gridSizeZ, Value blockSizeX, Value blockSizeY,
                         Value blockSizeZ, ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>();
  auto kernelSymbol = builder.getSymbolRefAttr(
      kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())});
  result.addAttribute(getKernelAttrName(), kernelSymbol);
}

/// Convenience overload taking the launch configuration as KernelDim3s.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, ValueRange kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

/// Returns the (module-nested) symbol reference to the kernel function.
SymbolRefAttr LaunchFuncOp::kernel() {
  return getAttrOfType<SymbolRefAttr>(getKernelAttrName());
}

/// Returns the number of operands passed to the kernel itself, i.e. all
/// operands past the leading launch-configuration ones.
unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

/// Returns the name of the gpu.module containing the kernel (the root of the
/// nested symbol reference).
StringRef LaunchFuncOp::getKernelModuleName() {
  return kernel().getRootReference();
}

/// Returns the name of the kernel function (the leaf of the nested symbol
/// reference).
StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

/// Returns the i-th kernel operand, skipping the configuration operands.
Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

/// Returns the SSA operands supplying the grid size (operands 0..2).
KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA operands supplying the block size (operands 3..5).
KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

static LogicalResult verify(LaunchFuncOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("expected to belong to a module");

  // The surrounding module must be marked as a GPU container module so that
  // the cross-module symbol checks in verifyOperationAttribute apply.
  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return op.emitOpError(
        "expected the closest surrounding module to have the '" +
        GPUDialect::getContainerModuleAttrName() + "' attribute");

  auto kernelAttr = op.getAttrOfType<SymbolRefAttr>(op.getKernelAttrName());
  if (!kernelAttr)
    return op.emitOpError("symbol reference attribute '" +
                          op.getKernelAttrName() + "' must be specified");

  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
  // Bump the workgroup-attribution counter attribute, then insert the new
  // argument right after the existing workgroup attributions (which directly
  // follow the function arguments).
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = getAttrOfType<IntegerAttr>(attrName);
  setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
                                  type);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type);
}

/// Builds a gpu.func with the given signature and memory attributions. The
/// entry block receives one argument per function input, followed by the
/// workgroup then private attribution buffers.
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;
  entryBlock->addArguments(type.getInputs());
  entryBlock->addArguments(workgroupAttributions);
  entryBlock->addArguments(privateAttributions);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::OperandType arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
  SmallVector<NamedAttrList, 1> argAttrs;
  SmallVector<NamedAttrList, 1> resultAttrs;
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 4> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

/// Prints one memory-attribution list (`workgroup(...)` or `private(...)`),
/// omitting it entirely when empty.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Prints a GPU Func op.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
  p << GPUFuncOp::getOperationName() << ' ';
  p.printSymbolName(op.getName());

  FunctionType type = op.getType();
  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
                               /*isVariadic=*/false, type.getResults());

  printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
  printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
  if (op.isKernel())
    p << ' ' << op.getKernelKeyword();

  // The attribution count and kernel marker are printed via dedicated syntax
  // above, so elide them from the attribute dictionary.
  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
                                type.getNumResults(),
                                {op.getNumWorkgroupAttributionsAttrName(),
                                 GPUDialect::getKernelFuncAttrName()});
  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

/// Replaces the function type, dropping per-argument attributes for any
/// arguments that no longer exist. Changing the result count is unsupported.
void GPUFuncOp::setType(FunctionType newType) {
  auto oldType = getType();
  assert(newType.getNumResults() == oldType.getNumResults() &&
         "unimplemented: changes to the number of results");

  SmallVector<char, 16> nameBuf;
  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
    removeAttr(getArgAttrName(i, nameBuf));

  setAttr(getTypeAttrName(), TypeAttr::get(newType));
}

/// Hook for FunctionLike verifier.
LogicalResult GPUFuncOp::verifyType() {
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  // Kernel entry points cannot return values.
  if (isKernel() && getType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Checks that every attribution is a memref in the expected memory space.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  // The leading entry-block arguments must match the function signature.
  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

/// Parses gpu.return: an optional operand list with an optional colon-type
/// list.
static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
  llvm::SmallVector<OpAsmParser::OperandType, 4> operands;
  llvm::SmallVector<Type, 4> types;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalColonTypeList(types) ||
      parser.resolveOperands(operands, types, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  return success();
}

/// Verifies that gpu.return's operands match the enclosing gpu.func's result
/// types in number and type.
static LogicalResult verify(gpu::ReturnOp returnOp) {
  GPUFuncOp function = returnOp.getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getType();

  if (funType.getNumResults() != returnOp.operands().size())
    return returnOp.emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (auto pair : llvm::enumerate(
           llvm::zip(function.getType().getResults(), returnOp.operands()))) {
    Type type;
    Value operand;
    std::tie(type, operand) = pair.value();
    if (type != operand.getType())
      return returnOp.emitOpError() << "unexpected type `" << operand.getType()
                                    << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

/// Builds an empty, named gpu.module with a properly terminated body region.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

/// Parses gpu.module: a symbol name, optional attribute dictionary, and a
/// body region (terminator inserted if absent).
static ParseResult parseGPUModuleOp(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

/// Prints gpu.module; the symbol name is printed via dedicated syntax and
/// elided from the attribute dictionary.
static void print(OpAsmPrinter &p, GPUModuleOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     {SymbolTable::getSymbolAttrName()});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

/// Parses the optional `async` keyword plus an optional square-bracketed list
/// of async-token dependencies. `async` is only allowed on ops with a named
/// result, which then gets the !gpu.async.token type.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints the async-dependency clause; mirrors parseAsyncDependencies.
static void printAsyncDependencies(OpAsmPrinter &printer, Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async ";
  if (asyncDependencies.empty())
    return;
  printer << "[";
  llvm::interleaveComma(asyncDependencies, printer);
  printer << "]";
}

#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"