//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

StringRef GPUDialect::getDialectName() { return "gpu"; }

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!attr.second.isa<UnitAttr>() ||
      !attr.first.is(getContainerModuleAttrName()))
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<StringAttr>(
            LaunchFuncOp::getKernelAttrName()) ||
        !launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelModuleAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    StringRef kernelName = launchOp.kernel();
    Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << kernelName << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelLLVMFunction
                                        ? kernelLLVMFunction.getNumArguments()
                                        : kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Due to the ordering of the current impl of lowering and LLVMLowering,
    // type checks need to be temporarily disabled.
    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
    // to encode target module" has landed.
    // auto functionType = kernelFunc.getType();
    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    //   if (getKernelOperand(i).getType() != functionType.getInput(i)) {
    //     return emitOpError("type of function argument ")
    //            << i << " does not match";
    //   }
    // }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

template <typename T>
static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    if (allReduce.body().front().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().front().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  }
  return success();
}

static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}
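
// An illustrative instance of the custom form handled by the printer and
// parser below (all SSA names are hypothetical):
//   %shfl, %valid = gpu.shuffle %value, %offset, %width xor : f32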
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
                     Value gridSizeY, Value gridSizeZ, Value blockSizeX,
                     Value blockSizeY, Value blockSizeZ, ValueRange operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(llvm::to_vector<4>(operands.getTypes()));
  kernelRegion->push_back(body);
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
  return llvm::drop_begin(getOperands(), kNumConfigOperands);
}

LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
  auto args = body().getBlocks().front().getArguments();
  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}

static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    Block &entryBlock = op.body().front();
    if (entryBlock.getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
                 .emitError("expected 'gpu.terminator' or a terminator with "
                            "successors")
                 .attachNote(op.getLoc())
             << "in '" << LaunchOp::getOperationName() << "' body region";
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ValueRange operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands[0] << ", ";
  p << size.y << " = " << operands[1] << ", ";
  p << size.z << " = " << operands[2] << ')';
}

static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  ValueRange operands = op.getOperands();

  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
                      op.getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(LaunchOp::kNumConfigOperands);

  // Print the data argument remapping.
  if (!op.body().empty() && !operands.empty()) {
    p << ' ' << op.getArgsKeyword() << '(';
    Block *entryBlock = &op.body().front();
    interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
      p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
        << " = " << operands[i];
    });
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty())
    p << ": " << operands.getTypes();

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                            (`args` ssa-reassignment `:` type-list)?
//                            region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
      LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If the kernel argument renaming segment is present, parse it. When
  // present, the segment should have at least one element. If this segment is
  // present, so is the trailing type list. Parse it as well and use the parsed
  // types to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(LaunchOp::getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // The actual kernel arguments follow.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), LaunchOp::kNumConfigRegionAttributes,
                   index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = body().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
struct PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(launchOp);
    PatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&launchOp.body().front());

    // Traverse operands passed to kernel and check if some of them are known
    // constants. If so, clone the constant operation inside the kernel region
    // and use it instead of passing the value from the parent region. Perform
    // the traversal in the inverse order to simplify index arithmetic when
    // dropping arguments.
    auto operands = launchOp.getKernelOperandValues();
    auto kernelArgs = launchOp.getKernelArguments();
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value operand = operands[index];
      if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp()))
        continue;

      found = true;
      Value internalConstant =
          rewriter.clone(*operand.getDefiningOp())->getResult(0);
      Value kernelArg = *std::next(kernelArgs.begin(), index);
      kernelArg.replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }

    if (!found) {
      rewriter.cancelRootUpdate(launchOp);
      return matchFailure();
    }

    rewriter.finalizeRootUpdate(launchOp);
    return matchSuccess();
  }
};
} // end namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         GPUFuncOp kernelFunc, Value gridSizeX,
                         Value gridSizeY, Value gridSizeZ, Value blockSizeX,
                         Value blockSizeY, Value blockSizeZ,
                         ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getStringAttr(kernelFunc.getName()));
  auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>();
  result.addAttribute(getKernelModuleAttrName(),
                      builder->getSymbolRefAttr(kernelModule.getName()));
}

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, ValueRange kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

StringRef LaunchFuncOp::kernel() {
  return getAttrOfType<StringAttr>(getKernelAttrName()).getValue();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

StringRef LaunchFuncOp::getKernelModuleName() {
  return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
      .getRootReference();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

static LogicalResult verify(LaunchFuncOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("expected to belong to a module");

  if (!module.getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return op.emitOpError(
        "expected the closest surrounding module to have the '" +
        GPUDialect::getContainerModuleAttrName() + "' attribute");

  auto kernelAttr = op.getAttrOfType<StringAttr>(op.getKernelAttrName());
  if (!kernelAttr)
    return op.emitOpError("string attribute '" + op.getKernelAttrName() +
                          "' must be specified");

  auto kernelModuleAttr =
      op.getAttrOfType<SymbolRefAttr>(op.getKernelModuleAttrName());
  if (!kernelModuleAttr)
    return op.emitOpError("symbol reference attribute '" +
                          op.getKernelModuleAttrName() +
                          "' must be specified");

  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a workgroup attribution to "op" of the MemRef type with the given
/// shape and element type.
Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape,
                                         Type elementType) {
  unsigned pos = getNumFuncArguments() + getNumWorkgroupAttributions();
  Block &bodyBlock = body().front();
  Value attribution = bodyBlock.insertArgument(
      std::next(bodyBlock.args_begin(), pos),
      MemRefType::get(shape, elementType, /*affineMapComposition=*/{},
                      GPUDialect::getWorkgroupAddressSpace()));
  auto numWorkgroupBuffersAttr =
      getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName());
  setAttr(getNumWorkgroupAttributionsAttrName(),
          IntegerAttr::get(numWorkgroupBuffersAttr.getType(),
                           numWorkgroupBuffersAttr.getValue() + 1));
  return attribution;
}

void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
                      FunctionType type, ArrayRef<Type> workgroupAttributions,
                      ArrayRef<Type> privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder->getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder->getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;
  entryBlock->addArguments(type.getInputs());
  entryBlock->addArguments(workgroupAttributions);
  entryBlock->addArguments(privateAttributions);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::OperandType arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}
/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
  SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
  SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 4> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  interleaveComma(values, p,
                  [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Prints a GPU Func op.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
  p << GPUFuncOp::getOperationName() << ' ';
  p.printSymbolName(op.getName());

  FunctionType type = op.getType();
  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
                               /*isVariadic=*/false, type.getResults());

  printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
  printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
  if (op.isKernel())
    p << ' ' << op.getKernelKeyword();

  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
                                type.getNumResults(),
                                {op.getNumWorkgroupAttributionsAttrName(),
                                 GPUDialect::getKernelFuncAttrName()});
  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

void GPUFuncOp::setType(FunctionType newType) {
  auto oldType = getType();
  assert(newType.getNumResults() == oldType.getNumResults() &&
         "unimplemented: changes to the number of results");

  SmallVector<char, 16> nameBuf;
  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
    removeAttr(getArgAttrName(i, nameBuf));

  setAttr(getTypeAttrName(), TypeAttr::get(newType));
}

/// Hook for the FunctionLike verifier.
LogicalResult GPUFuncOp::verifyType() {
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");
  return success();
}

static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

void GPUModuleOp::build(Builder *builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), *builder, result.location);
  result.attributes.push_back(builder->getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)));
}

static ParseResult parseGPUModuleOp(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

static void print(OpAsmPrinter &p, GPUModuleOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     {SymbolTable::getSymbolAttrName()});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

// Namespace avoids ambiguous ReturnOpOperandAdaptor.
namespace mlir {
namespace gpu {
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
} // namespace gpu
} // namespace mlir