1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the GPU kernel-related dialect and its operations. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/GPU/GPUDialect.h" 23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 24 #include "mlir/Dialect/StandardOps/Ops.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Function.h" 27 #include "mlir/IR/Module.h" 28 #include "mlir/IR/OpImplementation.h" 29 #include "mlir/IR/PatternMatch.h" 30 #include "mlir/IR/StandardTypes.h" 31 32 using namespace mlir; 33 using namespace mlir::gpu; 34 35 //===----------------------------------------------------------------------===// 36 // GPUDialect 37 //===----------------------------------------------------------------------===// 38 39 StringRef GPUDialect::getDialectName() { return "gpu"; } 40 41 bool GPUDialect::isKernel(Operation *op) { 42 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()); 43 return static_cast<bool>(isKernelAttr); 44 } 45 46 GPUDialect::GPUDialect(MLIRContext *context) 47 : Dialect(getDialectName(), context) { 48 addOperations<LaunchOp, LaunchFuncOp, 49 #define GET_OP_LIST 50 #include 
"mlir/Dialect/GPU/GPUOps.cpp.inc" 51 >(); 52 } 53 54 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, 55 NamedAttribute attr) { 56 if (!attr.second.isa<UnitAttr>() || 57 !attr.first.is(getContainerModuleAttrName())) 58 return success(); 59 60 auto module = dyn_cast<ModuleOp>(op); 61 if (!module) 62 return op->emitError("expected '") 63 << getContainerModuleAttrName() << "' attribute to be attached to '" 64 << ModuleOp::getOperationName() << '\''; 65 66 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult { 67 // Ignore launches that are nested more or less deep than functions in the 68 // module we are currently checking. 69 if (!launchOp.getParentOp() || 70 launchOp.getParentOp()->getParentOp() != module) 71 return success(); 72 73 // Ignore launch ops with missing attributes here. The errors will be 74 // reported by the verifiers of those ops. 75 if (!launchOp.getAttrOfType<StringAttr>( 76 LaunchFuncOp::getKernelAttrName()) || 77 !launchOp.getAttrOfType<SymbolRefAttr>( 78 LaunchFuncOp::getKernelModuleAttrName())) 79 return success(); 80 81 // Check that `launch_func` refers to a well-formed GPU kernel module. 82 StringRef kernelModuleName = launchOp.getKernelModuleName(); 83 auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName); 84 if (!kernelModule) 85 return launchOp.emitOpError() 86 << "kernel module '" << kernelModuleName << "' is undefined"; 87 if (!kernelModule.getAttrOfType<UnitAttr>( 88 GPUDialect::getKernelModuleAttrName())) 89 return launchOp.emitOpError("module '") 90 << kernelModuleName << "' is missing the '" 91 << GPUDialect::getKernelModuleAttrName() << "' attribute"; 92 93 // Check that `launch_func` refers to a well-formed kernel function. 
94 StringRef kernelName = launchOp.kernel(); 95 Operation *kernelFunc = kernelModule.lookupSymbol(kernelName); 96 auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc); 97 auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc); 98 if (!kernelStdFunction && !kernelLLVMFunction) 99 return launchOp.emitOpError("kernel function '") 100 << kernelName << "' is undefined"; 101 if (!kernelFunc->getAttrOfType<mlir::UnitAttr>( 102 GPUDialect::getKernelFuncAttrName())) 103 return launchOp.emitOpError("kernel function is missing the '") 104 << GPUDialect::getKernelFuncAttrName() << "' attribute"; 105 106 unsigned actualNumArguments = launchOp.getNumKernelOperands(); 107 unsigned expectedNumArguments = kernelLLVMFunction 108 ? kernelLLVMFunction.getNumArguments() 109 : kernelStdFunction.getNumArguments(); 110 if (expectedNumArguments != actualNumArguments) 111 return launchOp.emitOpError("got ") 112 << actualNumArguments << " kernel operands but expected " 113 << expectedNumArguments; 114 115 // Due to the ordering of the current impl of lowering and LLVMLowering, 116 // type checks need to be temporarily disabled. 117 // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc 118 // to encode target module" has landed. 119 // auto functionType = kernelFunc.getType(); 120 // for (unsigned i = 0; i < numKernelFuncArgs; ++i) { 121 // if (getKernelOperand(i)->getType() != functionType.getInput(i)) { 122 // return emitOpError("type of function argument ") 123 // << i << " does not match"; 124 // } 125 // } 126 127 return success(); 128 }); 129 130 return walkResult.wasInterrupted() ? 
failure() : success(); 131 } 132 133 template <typename T> static LogicalResult verifyIndexOp(T op) { 134 auto dimension = op.dimension(); 135 if (dimension != "x" && dimension != "y" && dimension != "z") 136 return op.emitError("dimension \"") << dimension << "\" is invalid"; 137 return success(); 138 } 139 140 #define GET_OP_CLASSES 141 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 142 143 //===----------------------------------------------------------------------===// 144 // LaunchOp 145 //===----------------------------------------------------------------------===// 146 147 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) { 148 SmallVector<Type, 4> types; 149 types.reserve(values.size()); 150 for (Value *v : values) 151 types.push_back(v->getType()); 152 return types; 153 } 154 155 void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX, 156 Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, 157 Value *blockSizeY, Value *blockSizeZ, 158 ArrayRef<Value *> operands) { 159 // Add grid and block sizes as op operands, followed by the data operands. 160 result.addOperands( 161 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 162 result.addOperands(operands); 163 164 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 165 // where the first kNumConfigRegionAttributes arguments have `index` type and 166 // the rest have the same types as the data operands. 
167 Region *kernelRegion = result.addRegion(); 168 Block *body = new Block(); 169 body->addArguments( 170 std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType())); 171 body->addArguments(getValueTypes(operands)); 172 kernelRegion->push_back(body); 173 } 174 175 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); } 176 177 KernelDim3 LaunchOp::getBlockIds() { 178 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 179 auto args = getBody().getBlocks().front().getArguments(); 180 return KernelDim3{args[0], args[1], args[2]}; 181 } 182 183 KernelDim3 LaunchOp::getThreadIds() { 184 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 185 auto args = getBody().getBlocks().front().getArguments(); 186 return KernelDim3{args[3], args[4], args[5]}; 187 } 188 189 KernelDim3 LaunchOp::getGridSize() { 190 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 191 auto args = getBody().getBlocks().front().getArguments(); 192 return KernelDim3{args[6], args[7], args[8]}; 193 } 194 195 KernelDim3 LaunchOp::getBlockSize() { 196 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 197 auto args = getBody().getBlocks().front().getArguments(); 198 return KernelDim3{args[9], args[10], args[11]}; 199 } 200 201 LaunchOp::operand_range LaunchOp::getKernelOperandValues() { 202 return llvm::drop_begin(getOperands(), kNumConfigOperands); 203 } 204 205 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() { 206 return llvm::drop_begin(getOperandTypes(), kNumConfigOperands); 207 } 208 209 KernelDim3 LaunchOp::getGridSizeOperandValues() { 210 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 211 } 212 213 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 214 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 215 } 216 217 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { 218 auto args = 
getBody().getBlocks().front().getArguments(); 219 return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes); 220 } 221 222 LogicalResult LaunchOp::verify() { 223 // Kernel launch takes kNumConfigOperands leading operands for grid/block 224 // sizes and transforms them into kNumConfigRegionAttributes region arguments 225 // for block/thread identifiers and grid/block sizes. 226 if (!getBody().empty()) { 227 Block &entryBlock = getBody().front(); 228 if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands()) 229 return emitOpError("unexpected number of region arguments"); 230 } 231 232 // Block terminators without successors are expected to exit the kernel region 233 // and must be `gpu.launch`. 234 for (Block &block : getBody()) { 235 if (block.empty()) 236 continue; 237 if (block.back().getNumSuccessors() != 0) 238 continue; 239 if (!isa<gpu::Return>(&block.back())) { 240 return block.back() 241 .emitError("expected 'gpu.terminator' or a terminator with " 242 "successors") 243 .attachNote(getLoc()) 244 << "in '" << getOperationName() << "' body region"; 245 } 246 } 247 248 return success(); 249 } 250 251 // Pretty-print the kernel grid/block size assignment as 252 // (%iter-x, %iter-y, %iter-z) in 253 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 254 // where %size-* and %iter-* will correspond to the body region arguments. 255 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, 256 ArrayRef<Value *> operands, KernelDim3 ids) { 257 p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in ("; 258 p << *size.x << " = " << *operands[0] << ", "; 259 p << *size.y << " = " << *operands[1] << ", "; 260 p << *size.z << " = " << *operands[2] << ')'; 261 } 262 263 void LaunchOp::print(OpAsmPrinter &p) { 264 SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end()); 265 ArrayRef<Value *> operands(operandContainer); 266 267 // Print the launch configuration. 
268 p << getOperationName() << ' ' << getBlocksKeyword(); 269 printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds()); 270 p << ' ' << getThreadsKeyword(); 271 printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds()); 272 273 // From now on, the first kNumConfigOperands operands corresponding to grid 274 // and block sizes are irrelevant, so we can drop them. 275 operands = operands.drop_front(kNumConfigOperands); 276 277 // Print the data argument remapping. 278 if (!getBody().empty() && !operands.empty()) { 279 p << ' ' << getArgsKeyword() << '('; 280 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 281 if (i != 0) 282 p << ", "; 283 p << *getBody().front().getArgument(kNumConfigRegionAttributes + i) 284 << " = " << *operands[i]; 285 } 286 p << ") "; 287 } 288 289 // Print the types of data arguments. 290 if (!operands.empty()) { 291 p << ": "; 292 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 293 if (i != 0) 294 p << ", "; 295 p << operands[i]->getType(); 296 } 297 } 298 299 p.printRegion(getBody(), /*printEntryBlockArgs=*/false); 300 p.printOptionalAttrDict(getAttrs()); 301 } 302 303 // Parse the size assignment blocks for blocks and threads. These have the form 304 // (%region_arg, %region_arg, %region_arg) in 305 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 306 // where %region_arg are percent-identifiers for the region arguments to be 307 // introduced further (SSA defs), and %operand are percent-identifiers for the 308 // SSA value uses. 
309 static ParseResult 310 parseSizeAssignment(OpAsmParser &parser, 311 MutableArrayRef<OpAsmParser::OperandType> sizes, 312 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 313 MutableArrayRef<OpAsmParser::OperandType> indices) { 314 assert(indices.size() == 3 && "space for three indices expected"); 315 SmallVector<OpAsmParser::OperandType, 3> args; 316 if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3, 317 OpAsmParser::Delimiter::Paren) || 318 parser.parseKeyword("in") || parser.parseLParen()) 319 return failure(); 320 std::move(args.begin(), args.end(), indices.begin()); 321 322 for (int i = 0; i < 3; ++i) { 323 if (i != 0 && parser.parseComma()) 324 return failure(); 325 if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() || 326 parser.parseOperand(sizes[i])) 327 return failure(); 328 } 329 330 return parser.parseRParen(); 331 } 332 333 // Parses a Launch operation. 334 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 335 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 336 // (`args` ssa-reassignment `:` type-list)? 337 // region attr-dict? 338 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 339 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) { 340 // Sizes of the grid and block. 341 SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes( 342 kNumConfigOperands); 343 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 344 345 // Actual (data) operands passed to the kernel. 346 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 347 348 // Region arguments to be created. 
349 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 350 kNumConfigRegionAttributes); 351 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 352 353 // Parse the size assignment segments: the first segment assigns grid sizes 354 // and defines values for block identifiers; the second segment assigns block 355 // sizes and defines values for thread identifiers. In the region argument 356 // list, identifiers precede sizes, and block-related values precede 357 // thread-related values. 358 if (parser.parseKeyword(getBlocksKeyword().data()) || 359 parseSizeAssignment(parser, sizesRef.take_front(3), 360 regionArgsRef.slice(6, 3), 361 regionArgsRef.slice(0, 3)) || 362 parser.parseKeyword(getThreadsKeyword().data()) || 363 parseSizeAssignment(parser, sizesRef.drop_front(3), 364 regionArgsRef.slice(9, 3), 365 regionArgsRef.slice(3, 3)) || 366 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(), 367 result.operands)) 368 return failure(); 369 370 // If kernel argument renaming segment is present, parse it. When present, 371 // the segment should have at least one element. If this segment is present, 372 // so is the trailing type list. Parse it as well and use the parsed types 373 // to resolve the operands passed to the kernel arguments. 
374 SmallVector<Type, 4> dataTypes; 375 if (!parser.parseOptionalKeyword(getArgsKeyword())) { 376 llvm::SMLoc argsLoc = parser.getCurrentLocation(); 377 378 regionArgs.push_back({}); 379 dataOperands.push_back({}); 380 if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) || 381 parser.parseEqual() || parser.parseOperand(dataOperands.back())) 382 return failure(); 383 384 while (!parser.parseOptionalComma()) { 385 regionArgs.push_back({}); 386 dataOperands.push_back({}); 387 if (parser.parseRegionArgument(regionArgs.back()) || 388 parser.parseEqual() || parser.parseOperand(dataOperands.back())) 389 return failure(); 390 } 391 392 if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) || 393 parser.resolveOperands(dataOperands, dataTypes, argsLoc, 394 result.operands)) 395 return failure(); 396 } 397 398 // Introduce the body region and parse it. The region has 399 // kNumConfigRegionAttributes leading arguments that correspond to 400 // block/thread identifiers and grid/block sizes, all of the `index` type. 401 // Follow the actual kernel arguments. 402 Type index = parser.getBuilder().getIndexType(); 403 dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index); 404 Region *body = result.addRegion(); 405 return failure(parser.parseRegion(*body, regionArgs, dataTypes) || 406 parser.parseOptionalAttributeDict(result.attributes)); 407 } 408 409 void LaunchOp::eraseKernelArgument(unsigned index) { 410 Block &entryBlock = getBody().front(); 411 assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes && 412 "kernel argument index overflow"); 413 entryBlock.eraseArgument(kNumConfigRegionAttributes + index); 414 getOperation()->eraseOperand(kNumConfigOperands + index); 415 } 416 417 namespace { 418 // Clone any known constants passed as operands to the kernel into its body. 
419 class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { 420 using OpRewritePattern<LaunchOp>::OpRewritePattern; 421 422 PatternMatchResult matchAndRewrite(LaunchOp launchOp, 423 PatternRewriter &rewriter) const override { 424 auto origInsertionPoint = rewriter.saveInsertionPoint(); 425 rewriter.setInsertionPointToStart(&launchOp.getBody().front()); 426 427 // Traverse operands passed to kernel and check if some of them are known 428 // constants. If so, clone the constant operation inside the kernel region 429 // and use it instead of passing the value from the parent region. Perform 430 // the traversal in the inverse order to simplify index arithmetics when 431 // dropping arguments. 432 SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(), 433 launchOp.getKernelOperandValues().end()); 434 SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(), 435 launchOp.getKernelArguments().end()); 436 bool found = false; 437 for (unsigned i = operands.size(); i > 0; --i) { 438 unsigned index = i - 1; 439 Value *operand = operands[index]; 440 if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { 441 continue; 442 } 443 444 found = true; 445 Value *internalConstant = 446 rewriter.clone(*operand->getDefiningOp())->getResult(0); 447 Value *kernelArg = kernelArgs[index]; 448 kernelArg->replaceAllUsesWith(internalConstant); 449 launchOp.eraseKernelArgument(index); 450 } 451 rewriter.restoreInsertionPoint(origInsertionPoint); 452 453 if (!found) 454 return matchFailure(); 455 456 rewriter.updatedRootInPlace(launchOp); 457 return matchSuccess(); 458 } 459 }; 460 } // end namespace 461 462 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 463 MLIRContext *context) { 464 results.insert<PropagateConstantBounds>(context); 465 } 466 467 //===----------------------------------------------------------------------===// 468 // LaunchFuncOp 469 
//===----------------------------------------------------------------------===// 470 471 void LaunchFuncOp::build(Builder *builder, OperationState &result, 472 FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, 473 Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, 474 Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { 475 // Add grid and block sizes as op operands, followed by the data operands. 476 result.addOperands( 477 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 478 result.addOperands(kernelOperands); 479 result.addAttribute(getKernelAttrName(), 480 builder->getStringAttr(kernelFunc.getName())); 481 auto kernelModule = kernelFunc.getParentOfType<ModuleOp>(); 482 if (Optional<StringRef> kernelModuleName = kernelModule.getName()) 483 result.addAttribute(getKernelModuleAttrName(), 484 builder->getSymbolRefAttr(*kernelModuleName)); 485 } 486 487 void LaunchFuncOp::build(Builder *builder, OperationState &result, 488 FuncOp kernelFunc, KernelDim3 gridSize, 489 KernelDim3 blockSize, 490 ArrayRef<Value *> kernelOperands) { 491 build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, 492 blockSize.x, blockSize.y, blockSize.z, kernelOperands); 493 } 494 495 StringRef LaunchFuncOp::kernel() { 496 return getAttrOfType<StringAttr>(getKernelAttrName()).getValue(); 497 } 498 499 unsigned LaunchFuncOp::getNumKernelOperands() { 500 return getNumOperands() - kNumConfigOperands; 501 } 502 503 StringRef LaunchFuncOp::getKernelModuleName() { 504 return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue(); 505 } 506 507 Value *LaunchFuncOp::getKernelOperand(unsigned i) { 508 return getOperation()->getOperand(i + kNumConfigOperands); 509 } 510 511 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 512 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 513 } 514 515 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 516 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 517 } 518 519 
/// Verifies that the op sits in a module carrying the container-module
/// attribute and that both the kernel-name string attribute and the
/// kernel-module symbol reference attribute are present. Cross-module checks
/// (kernel existence, argument count) are done by
/// GPUDialect::verifyOperationAttribute.
LogicalResult LaunchFuncOp::verify() {
  auto module = getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (!getAttrOfType<StringAttr>(getKernelAttrName()))
    return emitOpError("string attribute '" + getKernelAttrName() +
                       "' must be specified");

  if (!getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()))
    return emitOpError("symbol reference attribute '" +
                       getKernelModuleAttrName() + "' must be specified");

  return success();
}