//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// Returns true if `op` carries the unit attribute that marks GPU kernel
/// functions (see getKernelFuncAttrName()).
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

/// Registers the dialect's types and its ODS-generated operations.
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

/// Parses a type registered to this dialect. The only gpu type currently
/// handled is `gpu.async.token`.
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

/// Prints a type registered to this dialect.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

/// Verifies the `gpu.container_module` unit attribute: it must be attached to
/// a builtin module, and every `gpu.launch_func` nested exactly one region
/// level below that module must reference a well-formed GPU kernel module and
/// a kernel function whose signature matches the launch operands.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Only the container-module unit attribute is verified here; any other
  // attribute is accepted unchanged.
  if (!attr.second.isa<UnitAttr>() ||
      attr.first != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: if the kernel function has been converted to
    // the LLVM dialect but the caller hasn't (which happens during the
    // separate compilation), do not check type correspondence as it would
    // require the verifier to be aware of the LLVM type conversion.
    if (kernelLLVMFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Kernel operand types must match the kernel signature element-wise.
    auto functionType = kernelGPUFunction.getType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  // Any diagnostic emitted inside the walk interrupts it.
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Shared verifier for the index operations: the `dimension` attribute must
/// be one of "x", "y" or "z".
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

/// Verifies a gpu.all_reduce. Exactly one of the `op` attribute or a
/// non-empty reduction body must be present; the body (or the attribute) must
/// be consistent with the reduced value's type.
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    // The body combines two partial results, so it takes two arguments of the
    // result type and must yield a single value of that type.
    if (allReduce.body().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    // Bitwise accumulators only make sense for integer element types.
    StringRef opName = *allReduce.op();
    if ((opName == "and" || opName == "or" || opName == "xor") &&
        !allReduce.getType().isa<IntegerType>()) {
      return allReduce.emitError()
             << '`' << opName << '`'
             << " accumulator is only compatible with Integer type";
    }
  }
  return success();
}

/// Verifies a gpu.shuffle: the value operand and the result must share the
/// same type, which must be a 32-bit signless integer or float.
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

/// Prints a gpu.shuffle as:
///   gpu.shuffle %operands mode : value-type
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

/// Parses a gpu.shuffle. The second and third operands are implicitly i32;
/// the results are the value type plus an i1 validity flag.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

/// Prepends `token` to the async-dependency operands of `op`; if the op uses
/// the AttrSizedOperandSegments trait, the first segment size is incremented
/// so the segment attribute stays consistent with the new operand list.
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
  if (!sizeAttr)
    return; // Async dependencies is the only variadic operand.
  SmallVector<int32_t, 8> sizes;
  for (auto size : sizeAttr.getIntValues())
    sizes.push_back(size.getSExtValue());
  // Async dependencies are the leading segment.
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder.getIndexType()));
  kernelRegion->push_back(body);
}

// The region arguments are laid out as: block ids (0-2), thread ids (3-5),
// grid sizes (6-8) and block sizes (9-11); the accessors below slice that
// layout.
KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

// Operands 0-2 are the grid sizes, operands 3-5 the block sizes (see build).
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    if (op.body().getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(op.getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
                      op.getThreadIds());

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
/// Parses one `(ids) in (id = operand, ...)` size-assignment segment, filling
/// `indices` with the three leading region arguments, and `regionSizes` /
/// `sizes` with the region arguments and SSA operands of the assignments.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  // Parse the three `%region_arg = %operand` assignments.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                            region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
      LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
                      blockSize.y, blockSize.z});
  result.addOperands(kernelOperands);
  // The kernel is referenced by a nested symbol: @kernel_module::@kernel_func.
  auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>();
  auto kernelSymbol = builder.getSymbolRefAttr(
      kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())});
  result.addAttribute(getKernelAttrName(), kernelSymbol);
  // Operand segments: async deps, 6 config operands, kernel operands.
  SmallVector<int32_t, 8> segmentSizes(8, 1);
  segmentSizes.front() = 0; // Initially no async dependencies.
  segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr(segmentSizes));
}

/// Returns the number of operands passed to the kernel, i.e. everything after
/// the async dependencies and the grid/block size configuration operands.
unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - asyncDependencies().size() - kNumConfigOperands;
}

/// Returns the name of the kernel module the nested symbol reference points
/// into.
StringRef LaunchFuncOp::getKernelModuleName() {
  return kernel().getRootReference();
}

/// Returns the name of the kernel function inside the kernel module.
StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

/// Returns the i-th kernel operand, skipping async dependencies and the
/// grid/block size configuration operands.
Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperand(asyncDependencies().size() + kNumConfigOperands + i);
}

/// Returns the grid size operands (the first three config operands after the
/// async dependencies).
KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

/// Returns the block size operands (config operands 3-5 after the async
/// dependencies).
KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

static LogicalResult verify(LaunchFuncOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("expected to belong to a module");

  // The surrounding module must be flagged as a GPU container module; the
  // dialect-level attribute verifier performs the cross-module checks.
  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return op.emitOpError(
        "expected the closest surrounding module to have the '" +
        GPUDialect::getContainerModuleAttrName() + "' attribute");

  auto kernelAttr = op.getAttrOfType<SymbolRefAttr>(op.getKernelAttrName());
  if (!kernelAttr)
    return op.emitOpError("symbol reference attribute '" +
                          op.getKernelAttrName() + "' must be specified");

  return success();
}

/// Parses the optional `args(%arg : type, ...)` clause of gpu.launch_func.
/// Succeeds without consuming anything when the clause is absent.
static ParseResult
parseLaunchFuncOperands(OpAsmParser &parser,
                        SmallVectorImpl<OpAsmParser::OperandType> &argNames,
                        SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();
  SmallVector<NamedAttrList, 4> argAttrs;
  bool isVariadic = false;
  return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false,
                                         /*allowVariadic=*/false, argNames,
                                         argTypes, argAttrs, isVariadic);
}

/// Prints the kernel operands as `args(%op : type, ...)`; prints nothing when
/// there are no operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = getAttrOfType<IntegerAttr>(attrName);
  setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  // `attr` still holds the pre-increment count, so the new argument is
  // inserted right after the existing workgroup attributions.
  return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
                                  type);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  // Entry block arguments: function inputs, then workgroup attributions, then
  // private attributions.
  Region *body = result.addRegion();
  Block *entryBlock = new Block;
  entryBlock->addArguments(type.getInputs());
  entryBlock->addArguments(workgroupAttributions);
  entryBlock->addArguments(privateAttributions);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::OperandType arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
  SmallVector<NamedAttrList, 1> argAttrs;
  SmallVector<NamedAttrList, 1> resultAttrs;
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 4> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

/// Prints one memory-attribution clause, e.g. `workgroup(%arg : type, ...)`;
/// prints nothing when the list is empty.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Prints a GPU Func op.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
  p << GPUFuncOp::getOperationName() << ' ';
  p.printSymbolName(op.getName());

  FunctionType type = op.getType();
  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
                               /*isVariadic=*/false, type.getResults());

  printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
  printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
  if (op.isKernel())
    p << ' ' << op.getKernelKeyword();

  // Attributes printed via dedicated syntax above are elided from the dict.
  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
                                type.getNumResults(),
                                {op.getNumWorkgroupAttributionsAttrName(),
                                 GPUDialect::getKernelFuncAttrName()});
  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

/// Replaces the function type; the number of results must stay unchanged.
/// Argument attributes of inputs removed by the new type are dropped.
void GPUFuncOp::setType(FunctionType newType) {
  auto oldType = getType();
  assert(newType.getNumResults() == oldType.getNumResults() &&
         "unimplemented: changes to the number of results");

  SmallVector<char, 16> nameBuf;
  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
    removeAttr(getArgAttrName(i, nameBuf));

  setAttr(getTypeAttrName(), TypeAttr::get(newType));
}

/// Hook for FunctionLike verifier.
LogicalResult GPUFuncOp::verifyType() {
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  // Kernels cannot return values.
  if (isKernel() && getType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Checks that every attribution is a memref in the given memory space.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  // Leading entry-block arguments must match the function signature types.
  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
  llvm::SmallVector<OpAsmParser::OperandType, 4> operands;
  llvm::SmallVector<Type, 4> types;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalColonTypeList(types) ||
      parser.resolveOperands(operands, types, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  return success();
}

/// Verifies a gpu.return: its operands must match the result count and types
/// of the enclosing gpu.func.
static LogicalResult verify(gpu::ReturnOp returnOp) {
  GPUFuncOp function = returnOp.getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getType();

  if (funType.getNumResults() != returnOp.operands().size())
    return returnOp.emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (auto pair : llvm::enumerate(
           llvm::zip(function.getType().getResults(), returnOp.operands()))) {
    Type type;
    Value operand;
    std::tie(type, operand) = pair.value();
    if (type != operand.getType())
      return returnOp.emitOpError() << "unexpected type `" << operand.getType()
                                    << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

static ParseResult parseGPUModuleOp(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

static void print(OpAsmPrinter &p, GPUModuleOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name is printed via dedicated syntax above, so elide it.
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     {SymbolTable::getSymbolAttrName()});
  p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

/// Parses the optional `async` keyword plus the optional square-bracketed list
/// of async dependencies. `async` is only valid on ops with a named result,
/// in which case the async token result type is set.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints the `async` keyword (when the op produces a token) followed by the
/// bracketed async dependency list (when non-empty).
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async ";
  if (asyncDependencies.empty())
    return;
  printer << "[";
  llvm::interleaveComma(asyncDependencies, printer);
  printer << "]";
}

#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"