//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// Returns true if `op` carries the unit attribute that marks a function as a
/// GPU kernel (see getKernelFuncAttrName()).
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

/// Registers the dialect's types and its ODS-generated operation list.
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

/// Parses a type registered to this dialect. The only custom type currently
/// supported is `!gpu.async.token`.
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

/// Prints a type registered to this dialect; must stay in sync with parseType.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

/// Verifies the `gpu.container_module` unit attribute: it must be attached to
/// a ModuleOp, and every `gpu.launch_func` directly nested in a function of
/// that module must reference a well-formed kernel in a visible GPU module.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Only the container-module unit attribute is checked here; anything else is
  // accepted as-is.
  if (!attr.second.isa<UnitAttr>() ||
      attr.first != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function. The
    // kernel may be either a gpu.func or an already-lowered llvm.func.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: if the kernel function has been converted to
    // the LLVM dialect but the caller hasn't (which happens during the
    // separate compilation), do not check type correspondence as it would
    // require the verifier to be aware of the LLVM type conversion.
    if (kernelLLVMFunction)
      return success();

    // Operand count and types of the launch must match the kernel signature.
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  // The walk is interrupted as soon as any launch fails verification.
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Shared verifier for index operations (block/thread id and dim ops): the
/// `dimension` attribute must be one of "x", "y" or "z".
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

/// Verifies gpu.all_reduce: the reduction is given either as an `op` attribute
/// or as a non-empty body region, never both and never neither.
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    // Body form: two region arguments of the result type and at least one
    // gpu.yield returning a value of that type.
    if (allReduce.body().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    // Attribute form: bitwise accumulators only make sense on integers.
    StringRef opName = *allReduce.op();
    if ((opName == "and" || opName == "or" || opName == "xor") &&
        !allReduce.getType().isa<IntegerType>()) {
      return allReduce.emitError()
             << '`' << opName << '`'
             << " accumulator is only compatible with Integer type";
    }
  }
  return success();
}

/// Verifies gpu.shuffle: value and result must share a 32-bit signless
/// integer or float type.
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

/// Prints gpu.shuffle as: `gpu.shuffle %value, %offset, %width mode : type`.
/// Must stay in sync with parseShuffleOp below.
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

/// Parses gpu.shuffle; the offset and width operands are implicitly i32 and
/// the second result (validity flag) is implicitly i1.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

/// Prepends `token` to the async-dependency operands of `op`, updating the
/// operand-segment attribute when the op has more than one variadic group.
void gpu::addAsyncDependency(Operation *op, Value token) {
  // Async dependencies are always the leading operand group.
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
  if (!sizeAttr)
    return; // Async dependencies is the only variadic operand.
  // Grow the first (async-dependency) segment by one.
  SmallVector<int32_t, 8> sizes;
  for (auto size : sizeAttr.getIntValues())
    sizes.push_back(size.getSExtValue());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

/// Builds a gpu.launch with the given grid/block size operands and an entry
/// block whose leading arguments model the block/thread ids and sizes.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder.getIndexType()));
  kernelRegion->push_back(body);
}

// The first twelve region arguments are, in order: block ids (0-2), thread
// ids (3-5), grid sizes (6-8) and block sizes (9-11). The accessors below
// slice that fixed layout.

/// Returns the block-identifier region arguments (x, y, z).
KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

/// Returns the thread-identifier region arguments (x, y, z).
KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

/// Returns the grid-size region arguments (x, y, z).
KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

/// Returns the block-size region arguments (x, y, z).
KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

/// Returns the grid-size SSA operands (the first three op operands).
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the block-size SSA operands (op operands three to five).
KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

/// Verifies gpu.launch: region argument count matches the configuration, and
/// every exiting block terminator is a gpu.terminator.
static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    if (op.body().getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(op.getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

/// Prints gpu.launch; must stay in sync with parseLaunchOp.
static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
                      op.getThreadIds());

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
355 static ParseResult 356 parseSizeAssignment(OpAsmParser &parser, 357 MutableArrayRef<OpAsmParser::OperandType> sizes, 358 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 359 MutableArrayRef<OpAsmParser::OperandType> indices) { 360 assert(indices.size() == 3 && "space for three indices expected"); 361 SmallVector<OpAsmParser::OperandType, 3> args; 362 if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3, 363 OpAsmParser::Delimiter::Paren) || 364 parser.parseKeyword("in") || parser.parseLParen()) 365 return failure(); 366 std::move(args.begin(), args.end(), indices.begin()); 367 368 for (int i = 0; i < 3; ++i) { 369 if (i != 0 && parser.parseComma()) 370 return failure(); 371 if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() || 372 parser.parseOperand(sizes[i])) 373 return failure(); 374 } 375 376 return parser.parseRParen(); 377 } 378 379 // Parses a Launch operation. 380 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 381 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 382 // region attr-dict? 383 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 384 static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) { 385 // Sizes of the grid and block. 386 SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes( 387 LaunchOp::kNumConfigOperands); 388 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 389 390 // Actual (data) operands passed to the kernel. 391 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 392 393 // Region arguments to be created. 
394 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 395 LaunchOp::kNumConfigRegionAttributes); 396 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 397 398 // Parse the size assignment segments: the first segment assigns grid sizes 399 // and defines values for block identifiers; the second segment assigns block 400 // sizes and defines values for thread identifiers. In the region argument 401 // list, identifiers precede sizes, and block-related values precede 402 // thread-related values. 403 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) || 404 parseSizeAssignment(parser, sizesRef.take_front(3), 405 regionArgsRef.slice(6, 3), 406 regionArgsRef.slice(0, 3)) || 407 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) || 408 parseSizeAssignment(parser, sizesRef.drop_front(3), 409 regionArgsRef.slice(9, 3), 410 regionArgsRef.slice(3, 3)) || 411 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(), 412 result.operands)) 413 return failure(); 414 415 // Introduce the body region and parse it. The region has 416 // kNumConfigRegionAttributes arguments that correspond to 417 // block/thread identifiers and grid/block sizes, all of the `index` type. 
418 Type index = parser.getBuilder().getIndexType(); 419 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes( 420 LaunchOp::kNumConfigRegionAttributes, index); 421 Region *body = result.addRegion(); 422 return failure(parser.parseRegion(*body, regionArgs, dataTypes) || 423 parser.parseOptionalAttrDict(result.attributes)); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // LaunchFuncOp 428 //===----------------------------------------------------------------------===// 429 430 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, 431 GPUFuncOp kernelFunc, KernelDim3 gridSize, 432 KernelDim3 blockSize, ValueRange kernelOperands) { 433 // Add grid and block sizes as op operands, followed by the data operands. 434 result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x, 435 blockSize.y, blockSize.z}); 436 result.addOperands(kernelOperands); 437 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>(); 438 auto kernelSymbol = builder.getSymbolRefAttr( 439 kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())}); 440 result.addAttribute(getKernelAttrName(), kernelSymbol); 441 SmallVector<int32_t, 8> segmentSizes(8, 1); 442 segmentSizes.front() = 0; // Initially no async dependencies. 
443 segmentSizes.back() = static_cast<int32_t>(kernelOperands.size()); 444 result.addAttribute(getOperandSegmentSizeAttr(), 445 builder.getI32VectorAttr(segmentSizes)); 446 } 447 448 unsigned LaunchFuncOp::getNumKernelOperands() { 449 return getNumOperands() - asyncDependencies().size() - kNumConfigOperands; 450 } 451 452 StringRef LaunchFuncOp::getKernelModuleName() { 453 return kernel().getRootReference(); 454 } 455 456 StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); } 457 458 Value LaunchFuncOp::getKernelOperand(unsigned i) { 459 return getOperand(asyncDependencies().size() + kNumConfigOperands + i); 460 } 461 462 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 463 auto operands = getOperands().drop_front(asyncDependencies().size()); 464 return KernelDim3{operands[0], operands[1], operands[2]}; 465 } 466 467 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 468 auto operands = getOperands().drop_front(asyncDependencies().size()); 469 return KernelDim3{operands[3], operands[4], operands[5]}; 470 } 471 472 static LogicalResult verify(LaunchFuncOp op) { 473 auto module = op->getParentOfType<ModuleOp>(); 474 if (!module) 475 return op.emitOpError("expected to belong to a module"); 476 477 if (!module->getAttrOfType<UnitAttr>( 478 GPUDialect::getContainerModuleAttrName())) 479 return op.emitOpError( 480 "expected the closest surrounding module to have the '" + 481 GPUDialect::getContainerModuleAttrName() + "' attribute"); 482 483 auto kernelAttr = op->getAttrOfType<SymbolRefAttr>(op.getKernelAttrName()); 484 if (!kernelAttr) 485 return op.emitOpError("symbol reference attribute '" + 486 op.getKernelAttrName() + "' must be specified"); 487 488 return success(); 489 } 490 491 static ParseResult 492 parseLaunchFuncOperands(OpAsmParser &parser, 493 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 494 SmallVectorImpl<Type> &argTypes) { 495 if (parser.parseOptionalKeyword("args")) 496 return success(); 497 
SmallVector<NamedAttrList, 4> argAttrs; 498 bool isVariadic = false; 499 return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false, 500 /*allowVariadic=*/false, argNames, 501 argTypes, argAttrs, isVariadic); 502 } 503 504 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, 505 OperandRange operands, TypeRange types) { 506 if (operands.empty()) 507 return; 508 printer << "args("; 509 llvm::interleaveComma(llvm::zip(operands, types), printer, 510 [&](const auto &pair) { 511 printer.printOperand(std::get<0>(pair)); 512 printer << " : "; 513 printer.printType(std::get<1>(pair)); 514 }); 515 printer << ")"; 516 } 517 518 //===----------------------------------------------------------------------===// 519 // GPUFuncOp 520 //===----------------------------------------------------------------------===// 521 522 /// Adds a new block argument that corresponds to buffers located in 523 /// workgroup memory. 524 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) { 525 auto attrName = getNumWorkgroupAttributionsAttrName(); 526 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName); 527 (*this)->setAttr(attrName, 528 IntegerAttr::get(attr.getType(), attr.getValue() + 1)); 529 return getBody().insertArgument(getType().getNumInputs() + attr.getInt(), 530 type); 531 } 532 533 /// Adds a new block argument that corresponds to buffers located in 534 /// private memory. 535 BlockArgument GPUFuncOp::addPrivateAttribution(Type type) { 536 // Buffers on the private memory always come after buffers on the workgroup 537 // memory. 
538 return getBody().addArgument(type); 539 } 540 541 void GPUFuncOp::build(OpBuilder &builder, OperationState &result, 542 StringRef name, FunctionType type, 543 TypeRange workgroupAttributions, 544 TypeRange privateAttributions, 545 ArrayRef<NamedAttribute> attrs) { 546 result.addAttribute(SymbolTable::getSymbolAttrName(), 547 builder.getStringAttr(name)); 548 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 549 result.addAttribute(getNumWorkgroupAttributionsAttrName(), 550 builder.getI64IntegerAttr(workgroupAttributions.size())); 551 result.addAttributes(attrs); 552 Region *body = result.addRegion(); 553 Block *entryBlock = new Block; 554 entryBlock->addArguments(type.getInputs()); 555 entryBlock->addArguments(workgroupAttributions); 556 entryBlock->addArguments(privateAttributions); 557 558 body->getBlocks().push_back(entryBlock); 559 } 560 561 /// Parses a GPU function memory attribution. 562 /// 563 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? 564 /// (`private` `(` ssa-id-and-type-list `)`)? 565 /// 566 /// Note that this function parses only one of the two similar parts, with the 567 /// keyword provided as argument. 568 static ParseResult 569 parseAttributions(OpAsmParser &parser, StringRef keyword, 570 SmallVectorImpl<OpAsmParser::OperandType> &args, 571 SmallVectorImpl<Type> &argTypes) { 572 // If we could not parse the keyword, just assume empty list and succeed. 573 if (failed(parser.parseOptionalKeyword(keyword))) 574 return success(); 575 576 if (failed(parser.parseLParen())) 577 return failure(); 578 579 // Early exit for an empty list. 
580 if (succeeded(parser.parseOptionalRParen())) 581 return success(); 582 583 do { 584 OpAsmParser::OperandType arg; 585 Type type; 586 587 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 588 return failure(); 589 590 args.push_back(arg); 591 argTypes.push_back(type); 592 } while (succeeded(parser.parseOptionalComma())); 593 594 return parser.parseRParen(); 595 } 596 597 /// Parses a GPU function. 598 /// 599 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)` 600 /// (`->` function-result-list)? memory-attribution `kernel`? 601 /// function-attributes? region 602 static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { 603 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 604 SmallVector<NamedAttrList, 1> argAttrs; 605 SmallVector<NamedAttrList, 1> resultAttrs; 606 SmallVector<Type, 8> argTypes; 607 SmallVector<Type, 4> resultTypes; 608 bool isVariadic; 609 610 // Parse the function name. 611 StringAttr nameAttr; 612 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 613 result.attributes)) 614 return failure(); 615 616 auto signatureLocation = parser.getCurrentLocation(); 617 if (failed(impl::parseFunctionSignature( 618 parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, 619 isVariadic, resultTypes, resultAttrs))) 620 return failure(); 621 622 if (entryArgs.empty() && !argTypes.empty()) 623 return parser.emitError(signatureLocation) 624 << "gpu.func requires named arguments"; 625 626 // Construct the function type. More types will be added to the region, but 627 // not to the function type. 628 Builder &builder = parser.getBuilder(); 629 auto type = builder.getFunctionType(argTypes, resultTypes); 630 result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); 631 632 // Parse workgroup memory attributions. 
633 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), 634 entryArgs, argTypes))) 635 return failure(); 636 637 // Store the number of operands we just parsed as the number of workgroup 638 // memory attributions. 639 unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs(); 640 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(), 641 builder.getI64IntegerAttr(numWorkgroupAttrs)); 642 643 // Parse private memory attributions. 644 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), 645 entryArgs, argTypes))) 646 return failure(); 647 648 // Parse the kernel attribute if present. 649 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword()))) 650 result.addAttribute(GPUDialect::getKernelFuncAttrName(), 651 builder.getUnitAttr()); 652 653 // Parse attributes. 654 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 655 return failure(); 656 mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 657 658 // Parse the region. If no argument names were provided, take all names 659 // (including those of attributions) from the entry block. 660 auto *body = result.addRegion(); 661 return parser.parseRegion(*body, entryArgs, argTypes); 662 } 663 664 static void printAttributions(OpAsmPrinter &p, StringRef keyword, 665 ArrayRef<BlockArgument> values) { 666 if (values.empty()) 667 return; 668 669 p << ' ' << keyword << '('; 670 llvm::interleaveComma( 671 values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); }); 672 p << ')'; 673 } 674 675 /// Prints a GPU Func op. 
/// Prints gpu.func; must stay in sync with parseGPUFuncOp. The attribution
/// count and kernel marker attributes are elided from the attribute dictionary
/// because they are printed via dedicated syntax.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
  p << GPUFuncOp::getOperationName() << ' ';
  p.printSymbolName(op.getName());

  FunctionType type = op.getType();
  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
                               /*isVariadic=*/false, type.getResults());

  printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
  printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
  if (op.isKernel())
    p << ' ' << op.getKernelKeyword();

  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
                                type.getNumResults(),
                                {op.getNumWorkgroupAttributionsAttrName(),
                                 GPUDialect::getKernelFuncAttrName()});
  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

/// Replaces the function type of this gpu.func. Changing the number of
/// results is unsupported; argument attributes for dropped arguments are
/// removed.
void GPUFuncOp::setType(FunctionType newType) {
  auto oldType = getType();
  assert(newType.getNumResults() == oldType.getNumResults() &&
         "unimplemented: changes to the number of results");

  // Drop per-argument attributes of any arguments beyond the new input count.
  SmallVector<char, 16> nameBuf;
  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
    removeAttr(getArgAttrName(i, nameBuf));

  (*this)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
}

/// Hook for FunctionLike verifier.
LogicalResult GPUFuncOp::verifyType() {
  // The type attribute must hold a FunctionType.
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  // Kernels cannot return values.
  if (isKernel() && getType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Checks that all memory attributions are memrefs in the given memory space.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  // The entry block must carry an argument for every function input plus one
  // per workgroup attribution (private attributions come after those).
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  // Leading block arguments must match the function signature exactly.
  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

/// Parses gpu.return: an optional operand list with an optional type list.
static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
  llvm::SmallVector<OpAsmParser::OperandType, 4> operands;
  llvm::SmallVector<Type, 4> types;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalColonTypeList(types) ||
      parser.resolveOperands(operands, types, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  return success();
}

/// Verifies gpu.return: operand count and types must match the result types
/// of the enclosing gpu.func.
static LogicalResult verify(gpu::ReturnOp returnOp) {
  GPUFuncOp function = returnOp->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getType();

  if (funType.getNumResults() != returnOp.operands().size())
    return returnOp.emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (auto pair : llvm::enumerate(
           llvm::zip(function.getType().getResults(), returnOp.operands()))) {
    Type type;
    Value operand;
    std::tie(type, operand) = pair.value();
    if (type != operand.getType())
      return returnOp.emitOpError() << "unexpected type `" << operand.getType()
                                    << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

/// Builds an empty gpu.module with the given symbol name and an implicit
/// terminator in its body region.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

/// Parses gpu.module: a symbol name, an optional attribute dictionary and a
/// body region without block arguments.
static ParseResult parseGPUModuleOp(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

/// Prints gpu.module; must stay in sync with parseGPUModuleOp. The symbol
/// name attribute is printed inline and elided from the dictionary, and the
/// implicit terminator is not printed.
static void print(OpAsmPrinter &p, GPUModuleOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     {SymbolTable::getSymbolAttrName()});
  p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

/// Parses the optional `async` keyword and `[%dep, ...]` dependency list of
/// async ops. An op may only be marked `async` if it produces a result (the
/// async token).
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints the async keyword and dependency list; must stay in sync with
/// parseAsyncDependencies.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async ";
  if (asyncDependencies.empty())
    return;
  printer << "[";
  llvm::interleaveComma(asyncDependencies, printer);
  printer << "]";
}

#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"