1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the GPU kernel-related dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/GPU/GPUDialect.h" 14 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinOps.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/FunctionImplementation.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 using namespace mlir; 28 using namespace mlir::gpu; 29 30 //===----------------------------------------------------------------------===// 31 // GPUDialect 32 //===----------------------------------------------------------------------===// 33 34 bool GPUDialect::isKernel(Operation *op) { 35 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()); 36 return static_cast<bool>(isKernelAttr); 37 } 38 39 void GPUDialect::initialize() { 40 addTypes<AsyncTokenType>(); 41 addOperations< 42 #define GET_OP_LIST 43 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 44 >(); 45 } 46 47 Type GPUDialect::parseType(DialectAsmParser &parser) const { 48 // Parse the main keyword for the type. 49 StringRef keyword; 50 if (parser.parseKeyword(&keyword)) 51 return Type(); 52 MLIRContext *context = getContext(); 53 54 // Handle 'async token' types. 55 if (keyword == "async.token") 56 return AsyncTokenType::get(context); 57 58 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword); 59 return Type(); 60 } 61 62 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { 63 TypeSwitch<Type>(type) 64 .Case<AsyncTokenType>([&](Type) { os << "async.token"; }) 65 .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); }); 66 } 67 68 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, 69 NamedAttribute attr) { 70 if (!attr.second.isa<UnitAttr>() || 71 attr.first != getContainerModuleAttrName()) 72 return success(); 73 74 auto module = dyn_cast<ModuleOp>(op); 75 if (!module) 76 return op->emitError("expected '") 77 << getContainerModuleAttrName() << "' attribute to be attached to '" 78 << ModuleOp::getOperationName() << '\''; 79 80 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult { 81 // Ignore launches that are nested more or less deep than functions in the 82 // module we are currently checking. 83 if (!launchOp->getParentOp() || 84 launchOp->getParentOp()->getParentOp() != module) 85 return success(); 86 87 // Ignore launch ops with missing attributes here. The errors will be 88 // reported by the verifiers of those ops. 89 if (!launchOp->getAttrOfType<SymbolRefAttr>( 90 LaunchFuncOp::getKernelAttrName())) 91 return success(); 92 93 // Check that `launch_func` refers to a well-formed GPU kernel module. 94 StringRef kernelModuleName = launchOp.getKernelModuleName(); 95 auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName); 96 if (!kernelModule) 97 return launchOp.emitOpError() 98 << "kernel module '" << kernelModuleName << "' is undefined"; 99 100 // Check that `launch_func` refers to a well-formed kernel function. 101 Operation *kernelFunc = module.lookupSymbol(launchOp.kernel()); 102 auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc); 103 auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc); 104 if (!kernelGPUFunction && !kernelLLVMFunction) 105 return launchOp.emitOpError("kernel function '") 106 << launchOp.kernel() << "' is undefined"; 107 if (!kernelFunc->getAttrOfType<mlir::UnitAttr>( 108 GPUDialect::getKernelFuncAttrName())) 109 return launchOp.emitOpError("kernel function is missing the '") 110 << GPUDialect::getKernelFuncAttrName() << "' attribute"; 111 112 // TODO: if the kernel function has been converted to 113 // the LLVM dialect but the caller hasn't (which happens during the 114 // separate compilation), do not check type correspondence as it would 115 // require the verifier to be aware of the LLVM type conversion. 116 if (kernelLLVMFunction) 117 return success(); 118 119 unsigned actualNumArguments = launchOp.getNumKernelOperands(); 120 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments(); 121 if (expectedNumArguments != actualNumArguments) 122 return launchOp.emitOpError("got ") 123 << actualNumArguments << " kernel operands but expected " 124 << expectedNumArguments; 125 126 auto functionType = kernelGPUFunction.getType(); 127 for (unsigned i = 0; i < expectedNumArguments; ++i) { 128 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) { 129 return launchOp.emitOpError("type of function argument ") 130 << i << " does not match"; 131 } 132 } 133 134 return success(); 135 }); 136 137 return walkResult.wasInterrupted() ? failure() : success(); 138 } 139 140 template <typename T> static LogicalResult verifyIndexOp(T op) { 141 auto dimension = op.dimension(); 142 if (dimension != "x" && dimension != "y" && dimension != "z") 143 return op.emitError("dimension \"") << dimension << "\" is invalid"; 144 return success(); 145 } 146 147 static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { 148 if (allReduce.body().empty() != allReduce.op().hasValue()) 149 return allReduce.emitError( 150 "expected either an op attribute or a non-empty body"); 151 if (!allReduce.body().empty()) { 152 if (allReduce.body().getNumArguments() != 2) 153 return allReduce.emitError("expected two region arguments"); 154 for (auto argument : allReduce.body().getArguments()) { 155 if (argument.getType() != allReduce.getType()) 156 return allReduce.emitError("incorrect region argument type"); 157 } 158 unsigned yieldCount = 0; 159 for (Block &block : allReduce.body()) { 160 if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) { 161 if (yield.getNumOperands() != 1) 162 return allReduce.emitError("expected one gpu.yield operand"); 163 if (yield.getOperand(0).getType() != allReduce.getType()) 164 return allReduce.emitError("incorrect gpu.yield type"); 165 ++yieldCount; 166 } 167 } 168 if (yieldCount == 0) 169 return allReduce.emitError("expected gpu.yield op in region"); 170 } else { 171 StringRef opName = *allReduce.op(); 172 if ((opName == "and" || opName == "or" || opName == "xor") && 173 !allReduce.getType().isa<IntegerType>()) { 174 return allReduce.emitError() 175 << '`' << opName << '`' 176 << " accumulator is only compatible with Integer type"; 177 } 178 } 179 return success(); 180 } 181 182 static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) { 183 auto type = shuffleOp.value().getType(); 184 if (shuffleOp.result().getType() != type) { 185 return shuffleOp.emitOpError() 186 << "requires the same type for value operand and result"; 187 } 188 if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) { 189 return shuffleOp.emitOpError() 190 << "requires value operand type to be f32 or i32"; 191 } 192 return success(); 193 } 194 195 static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) { 196 p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' ' 197 << op.mode() << " : " << op.value().getType(); 198 } 199 200 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { 201 SmallVector<OpAsmParser::OperandType, 3> operandInfo; 202 if (parser.parseOperandList(operandInfo, 3)) 203 return failure(); 204 205 StringRef mode; 206 if (parser.parseKeyword(&mode)) 207 return failure(); 208 state.addAttribute("mode", parser.getBuilder().getStringAttr(mode)); 209 210 Type valueType; 211 Type int32Type = parser.getBuilder().getIntegerType(32); 212 Type int1Type = parser.getBuilder().getI1Type(); 213 if (parser.parseColonType(valueType) || 214 parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type}, 215 parser.getCurrentLocation(), state.operands) || 216 parser.addTypesToList({valueType, int1Type}, state.types)) 217 return failure(); 218 return success(); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // AsyncOpInterface 223 //===----------------------------------------------------------------------===// 224 225 void gpu::addAsyncDependency(Operation *op, Value token) { 226 op->insertOperands(0, {token}); 227 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>()) 228 return; 229 auto attrName = 230 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(); 231 auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName); 232 if (!sizeAttr) 233 return; // Async dependencies is the only variadic operand. 234 SmallVector<int32_t, 8> sizes; 235 for (auto size : sizeAttr.getIntValues()) 236 sizes.push_back(size.getSExtValue()); 237 ++sizes.front(); 238 op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes)); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // LaunchOp 243 //===----------------------------------------------------------------------===// 244 245 void LaunchOp::build(OpBuilder &builder, OperationState &result, 246 Value gridSizeX, Value gridSizeY, Value gridSizeZ, 247 Value blockSizeX, Value blockSizeY, Value blockSizeZ) { 248 // Add grid and block sizes as op operands, followed by the data operands. 249 result.addOperands( 250 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 251 252 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 253 // where the first kNumConfigRegionAttributes arguments have `index` type and 254 // the rest have the same types as the data operands. 255 Region *kernelRegion = result.addRegion(); 256 Block *body = new Block(); 257 body->addArguments( 258 std::vector<Type>(kNumConfigRegionAttributes, builder.getIndexType())); 259 kernelRegion->push_back(body); 260 } 261 262 KernelDim3 LaunchOp::getBlockIds() { 263 assert(!body().empty() && "LaunchOp body must not be empty."); 264 auto args = body().getArguments(); 265 return KernelDim3{args[0], args[1], args[2]}; 266 } 267 268 KernelDim3 LaunchOp::getThreadIds() { 269 assert(!body().empty() && "LaunchOp body must not be empty."); 270 auto args = body().getArguments(); 271 return KernelDim3{args[3], args[4], args[5]}; 272 } 273 274 KernelDim3 LaunchOp::getGridSize() { 275 assert(!body().empty() && "LaunchOp body must not be empty."); 276 auto args = body().getArguments(); 277 return KernelDim3{args[6], args[7], args[8]}; 278 } 279 280 KernelDim3 LaunchOp::getBlockSize() { 281 assert(!body().empty() && "LaunchOp body must not be empty."); 282 auto args = body().getArguments(); 283 return KernelDim3{args[9], args[10], args[11]}; 284 } 285 286 KernelDim3 LaunchOp::getGridSizeOperandValues() { 287 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 288 } 289 290 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 291 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 292 } 293 294 static LogicalResult verify(LaunchOp op) { 295 // Kernel launch takes kNumConfigOperands leading operands for grid/block 296 // sizes and transforms them into kNumConfigRegionAttributes region arguments 297 // for block/thread identifiers and grid/block sizes. 298 if (!op.body().empty()) { 299 if (op.body().getNumArguments() != 300 LaunchOp::kNumConfigOperands + op.getNumOperands()) 301 return op.emitOpError("unexpected number of region arguments"); 302 } 303 304 // Block terminators without successors are expected to exit the kernel region 305 // and must be `gpu.terminator`. 306 for (Block &block : op.body()) { 307 if (block.empty()) 308 continue; 309 if (block.back().getNumSuccessors() != 0) 310 continue; 311 if (!isa<gpu::TerminatorOp>(&block.back())) { 312 return block.back() 313 .emitError() 314 .append("expected '", gpu::TerminatorOp::getOperationName(), 315 "' or a terminator with successors") 316 .attachNote(op.getLoc()) 317 .append("in '", LaunchOp::getOperationName(), "' body region"); 318 } 319 } 320 321 return success(); 322 } 323 324 // Pretty-print the kernel grid/block size assignment as 325 // (%iter-x, %iter-y, %iter-z) in 326 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 327 // where %size-* and %iter-* will correspond to the body region arguments. 328 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, 329 KernelDim3 operands, KernelDim3 ids) { 330 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in ("; 331 p << size.x << " = " << operands.x << ", "; 332 p << size.y << " = " << operands.y << ", "; 333 p << size.z << " = " << operands.z << ')'; 334 } 335 336 static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { 337 // Print the launch configuration. 338 p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword(); 339 printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(), 340 op.getBlockIds()); 341 p << ' ' << op.getThreadsKeyword(); 342 printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(), 343 op.getThreadIds()); 344 345 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 346 p.printOptionalAttrDict(op.getAttrs()); 347 } 348 349 // Parse the size assignment blocks for blocks and threads. These have the form 350 // (%region_arg, %region_arg, %region_arg) in 351 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 352 // where %region_arg are percent-identifiers for the region arguments to be 353 // introduced further (SSA defs), and %operand are percent-identifiers for the 354 // SSA value uses. 355 static ParseResult 356 parseSizeAssignment(OpAsmParser &parser, 357 MutableArrayRef<OpAsmParser::OperandType> sizes, 358 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 359 MutableArrayRef<OpAsmParser::OperandType> indices) { 360 assert(indices.size() == 3 && "space for three indices expected"); 361 SmallVector<OpAsmParser::OperandType, 3> args; 362 if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3, 363 OpAsmParser::Delimiter::Paren) || 364 parser.parseKeyword("in") || parser.parseLParen()) 365 return failure(); 366 std::move(args.begin(), args.end(), indices.begin()); 367 368 for (int i = 0; i < 3; ++i) { 369 if (i != 0 && parser.parseComma()) 370 return failure(); 371 if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() || 372 parser.parseOperand(sizes[i])) 373 return failure(); 374 } 375 376 return parser.parseRParen(); 377 } 378 379 // Parses a Launch operation. 380 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 381 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 382 // region attr-dict? 383 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 384 static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) { 385 // Sizes of the grid and block. 386 SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes( 387 LaunchOp::kNumConfigOperands); 388 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 389 390 // Actual (data) operands passed to the kernel. 391 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 392 393 // Region arguments to be created. 394 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 395 LaunchOp::kNumConfigRegionAttributes); 396 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 397 398 // Parse the size assignment segments: the first segment assigns grid sizes 399 // and defines values for block identifiers; the second segment assigns block 400 // sizes and defines values for thread identifiers. In the region argument 401 // list, identifiers precede sizes, and block-related values precede 402 // thread-related values. 403 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) || 404 parseSizeAssignment(parser, sizesRef.take_front(3), 405 regionArgsRef.slice(6, 3), 406 regionArgsRef.slice(0, 3)) || 407 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) || 408 parseSizeAssignment(parser, sizesRef.drop_front(3), 409 regionArgsRef.slice(9, 3), 410 regionArgsRef.slice(3, 3)) || 411 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(), 412 result.operands)) 413 return failure(); 414 415 // Introduce the body region and parse it. The region has 416 // kNumConfigRegionAttributes arguments that correspond to 417 // block/thread identifiers and grid/block sizes, all of the `index` type. 418 Type index = parser.getBuilder().getIndexType(); 419 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes( 420 LaunchOp::kNumConfigRegionAttributes, index); 421 Region *body = result.addRegion(); 422 return failure(parser.parseRegion(*body, regionArgs, dataTypes) || 423 parser.parseOptionalAttrDict(result.attributes)); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // LaunchFuncOp 428 //===----------------------------------------------------------------------===// 429 430 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, 431 GPUFuncOp kernelFunc, KernelDim3 gridSize, 432 KernelDim3 blockSize, ValueRange kernelOperands) { 433 // Add grid and block sizes as op operands, followed by the data operands. 434 result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x, 435 blockSize.y, blockSize.z}); 436 result.addOperands(kernelOperands); 437 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>(); 438 auto kernelSymbol = builder.getSymbolRefAttr( 439 kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())}); 440 result.addAttribute(getKernelAttrName(), kernelSymbol); 441 SmallVector<int32_t, 8> segmentSizes(8, 1); 442 segmentSizes.front() = 0; // Initially no async dependencies. 443 segmentSizes.back() = static_cast<int32_t>(kernelOperands.size()); 444 result.addAttribute(getOperandSegmentSizeAttr(), 445 builder.getI32VectorAttr(segmentSizes)); 446 } 447 448 unsigned LaunchFuncOp::getNumKernelOperands() { 449 return getNumOperands() - asyncDependencies().size() - kNumConfigOperands; 450 } 451 452 StringRef LaunchFuncOp::getKernelModuleName() { 453 return kernel().getRootReference(); 454 } 455 456 StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); } 457 458 Value LaunchFuncOp::getKernelOperand(unsigned i) { 459 return getOperand(asyncDependencies().size() + kNumConfigOperands + i); 460 } 461 462 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 463 auto operands = getOperands().drop_front(asyncDependencies().size()); 464 return KernelDim3{operands[0], operands[1], operands[2]}; 465 } 466 467 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 468 auto operands = getOperands().drop_front(asyncDependencies().size()); 469 return KernelDim3{operands[3], operands[4], operands[5]}; 470 } 471 472 static LogicalResult verify(LaunchFuncOp op) { 473 auto module = op->getParentOfType<ModuleOp>(); 474 if (!module) 475 return op.emitOpError("expected to belong to a module"); 476 477 if (!module->getAttrOfType<UnitAttr>( 478 GPUDialect::getContainerModuleAttrName())) 479 return op.emitOpError( 480 "expected the closest surrounding module to have the '" + 481 GPUDialect::getContainerModuleAttrName() + "' attribute"); 482 483 auto kernelAttr = op->getAttrOfType<SymbolRefAttr>(op.getKernelAttrName()); 484 if (!kernelAttr) 485 return op.emitOpError("symbol reference attribute '" + 486 op.getKernelAttrName() + "' must be specified"); 487 488 return success(); 489 } 490 491 static ParseResult 492 parseLaunchFuncOperands(OpAsmParser &parser, 493 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 494 SmallVectorImpl<Type> &argTypes) { 495 if (parser.parseOptionalKeyword("args")) 496 return success(); 497 SmallVector<NamedAttrList, 4> argAttrs; 498 bool isVariadic = false; 499 return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false, 500 /*allowVariadic=*/false, argNames, 501 argTypes, argAttrs, isVariadic); 502 } 503 504 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, 505 OperandRange operands, TypeRange types) { 506 if (operands.empty()) 507 return; 508 printer << "args("; 509 llvm::interleaveComma(llvm::zip(operands, types), printer, 510 [&](const auto &pair) { 511 printer.printOperand(std::get<0>(pair)); 512 printer << " : "; 513 printer.printType(std::get<1>(pair)); 514 }); 515 printer << ")"; 516 } 517 518 //===----------------------------------------------------------------------===// 519 // GPUFuncOp 520 //===----------------------------------------------------------------------===// 521 522 /// Adds a new block argument that corresponds to buffers located in 523 /// workgroup memory. 524 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) { 525 auto attrName = getNumWorkgroupAttributionsAttrName(); 526 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName); 527 setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1)); 528 return getBody().insertArgument(getType().getNumInputs() + attr.getInt(), 529 type); 530 } 531 532 /// Adds a new block argument that corresponds to buffers located in 533 /// private memory. 534 BlockArgument GPUFuncOp::addPrivateAttribution(Type type) { 535 // Buffers on the private memory always come after buffers on the workgroup 536 // memory. 537 return getBody().addArgument(type); 538 } 539 540 void GPUFuncOp::build(OpBuilder &builder, OperationState &result, 541 StringRef name, FunctionType type, 542 TypeRange workgroupAttributions, 543 TypeRange privateAttributions, 544 ArrayRef<NamedAttribute> attrs) { 545 result.addAttribute(SymbolTable::getSymbolAttrName(), 546 builder.getStringAttr(name)); 547 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 548 result.addAttribute(getNumWorkgroupAttributionsAttrName(), 549 builder.getI64IntegerAttr(workgroupAttributions.size())); 550 result.addAttributes(attrs); 551 Region *body = result.addRegion(); 552 Block *entryBlock = new Block; 553 entryBlock->addArguments(type.getInputs()); 554 entryBlock->addArguments(workgroupAttributions); 555 entryBlock->addArguments(privateAttributions); 556 557 body->getBlocks().push_back(entryBlock); 558 } 559 560 /// Parses a GPU function memory attribution. 561 /// 562 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? 563 /// (`private` `(` ssa-id-and-type-list `)`)? 564 /// 565 /// Note that this function parses only one of the two similar parts, with the 566 /// keyword provided as argument. 567 static ParseResult 568 parseAttributions(OpAsmParser &parser, StringRef keyword, 569 SmallVectorImpl<OpAsmParser::OperandType> &args, 570 SmallVectorImpl<Type> &argTypes) { 571 // If we could not parse the keyword, just assume empty list and succeed. 572 if (failed(parser.parseOptionalKeyword(keyword))) 573 return success(); 574 575 if (failed(parser.parseLParen())) 576 return failure(); 577 578 // Early exit for an empty list. 579 if (succeeded(parser.parseOptionalRParen())) 580 return success(); 581 582 do { 583 OpAsmParser::OperandType arg; 584 Type type; 585 586 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 587 return failure(); 588 589 args.push_back(arg); 590 argTypes.push_back(type); 591 } while (succeeded(parser.parseOptionalComma())); 592 593 return parser.parseRParen(); 594 } 595 596 /// Parses a GPU function. 597 /// 598 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)` 599 /// (`->` function-result-list)? memory-attribution `kernel`? 600 /// function-attributes? region 601 static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { 602 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 603 SmallVector<NamedAttrList, 1> argAttrs; 604 SmallVector<NamedAttrList, 1> resultAttrs; 605 SmallVector<Type, 8> argTypes; 606 SmallVector<Type, 4> resultTypes; 607 bool isVariadic; 608 609 // Parse the function name. 610 StringAttr nameAttr; 611 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 612 result.attributes)) 613 return failure(); 614 615 auto signatureLocation = parser.getCurrentLocation(); 616 if (failed(impl::parseFunctionSignature( 617 parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, 618 isVariadic, resultTypes, resultAttrs))) 619 return failure(); 620 621 if (entryArgs.empty() && !argTypes.empty()) 622 return parser.emitError(signatureLocation) 623 << "gpu.func requires named arguments"; 624 625 // Construct the function type. More types will be added to the region, but 626 // not to the function type. 627 Builder &builder = parser.getBuilder(); 628 auto type = builder.getFunctionType(argTypes, resultTypes); 629 result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); 630 631 // Parse workgroup memory attributions. 632 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), 633 entryArgs, argTypes))) 634 return failure(); 635 636 // Store the number of operands we just parsed as the number of workgroup 637 // memory attributions. 638 unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs(); 639 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(), 640 builder.getI64IntegerAttr(numWorkgroupAttrs)); 641 642 // Parse private memory attributions. 643 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), 644 entryArgs, argTypes))) 645 return failure(); 646 647 // Parse the kernel attribute if present. 648 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword()))) 649 result.addAttribute(GPUDialect::getKernelFuncAttrName(), 650 builder.getUnitAttr()); 651 652 // Parse attributes. 653 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 654 return failure(); 655 mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 656 657 // Parse the region. If no argument names were provided, take all names 658 // (including those of attributions) from the entry block. 659 auto *body = result.addRegion(); 660 return parser.parseRegion(*body, entryArgs, argTypes); 661 } 662 663 static void printAttributions(OpAsmPrinter &p, StringRef keyword, 664 ArrayRef<BlockArgument> values) { 665 if (values.empty()) 666 return; 667 668 p << ' ' << keyword << '('; 669 llvm::interleaveComma( 670 values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); }); 671 p << ')'; 672 } 673 674 /// Prints a GPU Func op. 675 static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { 676 p << GPUFuncOp::getOperationName() << ' '; 677 p.printSymbolName(op.getName()); 678 679 FunctionType type = op.getType(); 680 impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), 681 /*isVariadic=*/false, type.getResults()); 682 683 printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); 684 printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); 685 if (op.isKernel()) 686 p << ' ' << op.getKernelKeyword(); 687 688 impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), 689 type.getNumResults(), 690 {op.getNumWorkgroupAttributionsAttrName(), 691 GPUDialect::getKernelFuncAttrName()}); 692 p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); 693 } 694 695 void GPUFuncOp::setType(FunctionType newType) { 696 auto oldType = getType(); 697 assert(newType.getNumResults() == oldType.getNumResults() && 698 "unimplemented: changes to the number of results"); 699 700 SmallVector<char, 16> nameBuf; 701 for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) 702 removeAttr(getArgAttrName(i, nameBuf)); 703 704 setAttr(getTypeAttrName(), TypeAttr::get(newType)); 705 } 706 707 /// Hook for FunctionLike verifier. 708 LogicalResult GPUFuncOp::verifyType() { 709 Type type = getTypeAttr().getValue(); 710 if (!type.isa<FunctionType>()) 711 return emitOpError("requires '" + getTypeAttrName() + 712 "' attribute of function type"); 713 714 if (isKernel() && getType().getNumResults() != 0) 715 return emitOpError() << "expected void return type for kernel function"; 716 717 return success(); 718 } 719 720 static LogicalResult verifyAttributions(Operation *op, 721 ArrayRef<BlockArgument> attributions, 722 unsigned memorySpace) { 723 for (Value v : attributions) { 724 auto type = v.getType().dyn_cast<MemRefType>(); 725 if (!type) 726 return op->emitOpError() << "expected memref type in attribution"; 727 728 if (type.getMemorySpace() != memorySpace) { 729 return op->emitOpError() 730 << "expected memory space " << memorySpace << " in attribution"; 731 } 732 } 733 return success(); 734 } 735 736 /// Verifies the body of the function. 737 LogicalResult GPUFuncOp::verifyBody() { 738 unsigned numFuncArguments = getNumArguments(); 739 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); 740 unsigned numBlockArguments = front().getNumArguments(); 741 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) 742 return emitOpError() << "expected at least " 743 << numFuncArguments + numWorkgroupAttributions 744 << " arguments to body region"; 745 746 ArrayRef<Type> funcArgTypes = getType().getInputs(); 747 for (unsigned i = 0; i < numFuncArguments; ++i) { 748 Type blockArgType = front().getArgument(i).getType(); 749 if (funcArgTypes[i] != blockArgType) 750 return emitOpError() << "expected body region argument #" << i 751 << " to be of type " << funcArgTypes[i] << ", got " 752 << blockArgType; 753 } 754 755 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(), 756 GPUDialect::getWorkgroupAddressSpace())) || 757 failed(verifyAttributions(getOperation(), getPrivateAttributions(), 758 GPUDialect::getPrivateAddressSpace()))) 759 return failure(); 760 761 return success(); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // ReturnOp 766 //===----------------------------------------------------------------------===// 767 768 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 769 llvm::SmallVector<OpAsmParser::OperandType, 4> operands; 770 llvm::SmallVector<Type, 4> types; 771 if (parser.parseOperandList(operands) || 772 parser.parseOptionalColonTypeList(types) || 773 parser.resolveOperands(operands, types, parser.getCurrentLocation(), 774 result.operands)) 775 return failure(); 776 777 return success(); 778 } 779 780 static LogicalResult verify(gpu::ReturnOp returnOp) { 781 GPUFuncOp function = returnOp->getParentOfType<GPUFuncOp>(); 782 783 FunctionType funType = function.getType(); 784 785 if (funType.getNumResults() != returnOp.operands().size()) 786 return returnOp.emitOpError() 787 .append("expected ", funType.getNumResults(), " result operands") 788 .attachNote(function.getLoc()) 789 .append("return type declared here"); 790 791 for (auto pair : llvm::enumerate( 792 llvm::zip(function.getType().getResults(), returnOp.operands()))) { 793 Type type; 794 Value operand; 795 std::tie(type, operand) = pair.value(); 796 if (type != operand.getType()) 797 return returnOp.emitOpError() << "unexpected type `" << operand.getType() 798 << "' for operand #" << pair.index(); 799 } 800 return success(); 801 } 802 803 //===----------------------------------------------------------------------===// 804 // GPUModuleOp 805 //===----------------------------------------------------------------------===// 806 807 void GPUModuleOp::build(OpBuilder &builder, OperationState &result, 808 StringRef name) { 809 ensureTerminator(*result.addRegion(), builder, result.location); 810 result.attributes.push_back(builder.getNamedAttr( 811 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 812 } 813 814 static ParseResult parseGPUModuleOp(OpAsmParser &parser, 815 OperationState &result) { 816 StringAttr nameAttr; 817 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 818 result.attributes)) 819 return failure(); 820 821 // If module attributes are present, parse them. 822 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 823 return failure(); 824 825 // Parse the module body. 826 auto *body = result.addRegion(); 827 if (parser.parseRegion(*body, None, None)) 828 return failure(); 829 830 // Ensure that this module has a valid terminator. 831 GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); 832 return success(); 833 } 834 835 static void print(OpAsmPrinter &p, GPUModuleOp op) { 836 p << op.getOperationName() << ' '; 837 p.printSymbolName(op.getName()); 838 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 839 {SymbolTable::getSymbolAttrName()}); 840 p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, 841 /*printBlockTerminators=*/false); 842 } 843 844 static ParseResult parseAsyncDependencies( 845 OpAsmParser &parser, Type &asyncTokenType, 846 SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) { 847 auto loc = parser.getCurrentLocation(); 848 if (succeeded(parser.parseOptionalKeyword("async"))) { 849 if (parser.getNumResults() == 0) 850 return parser.emitError(loc, "needs to be named when marked 'async'"); 851 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>(); 852 } 853 return parser.parseOperandList(asyncDependencies, 854 OpAsmParser::Delimiter::OptionalSquare); 855 } 856 857 static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, 858 Type asyncTokenType, 859 OperandRange asyncDependencies) { 860 if (asyncTokenType) 861 printer << "async "; 862 if (asyncDependencies.empty()) 863 return; 864 printer << "["; 865 llvm::interleaveComma(asyncDependencies, printer); 866 printer << "]"; 867 } 868 869 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc" 870 871 #define GET_OP_CLASSES 872 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 873