//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// Returns the dialect namespace ("gpu") used to prefix all GPU operations.
StringRef GPUDialect::getDialectName() { return "gpu"; }

/// Returns true if `function` carries the unit attribute (named by
/// getKernelFuncAttrName()) that marks it as a GPU kernel entry point.
bool GPUDialect::isKernel(FuncOp function) {
  UnitAttr isKernelAttr =
      function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

/// Constructs the GPU dialect and registers its operations: the hand-written
/// LaunchOp and LaunchFuncOp, plus every op pulled in from the ODS-generated
/// list (GET_OP_LIST include below).
GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<LaunchOp, LaunchFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
                >();
}

/// Verifies GPU dialect attributes attached to non-GPU operations. Only the
/// `container_module` unit attribute is checked here: it must be attached to
/// a ModuleOp, and every `gpu.launch_func` directly nested in a function of
/// that module must refer to a well-formed kernel module and kernel function.
/// Any other attribute is accepted without verification.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Attributes other than the `container_module` unit attribute are not
  // subject to verification here.
  if (!attr.second.isa<UnitAttr>() ||
      !attr.first.is(getContainerModuleAttrName()))
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  // Walk all launch ops in the module; the walk is interrupted (and the
  // verification fails) as soon as one launch op emits an error.
  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<StringAttr>(
            LaunchFuncOp::getKernelAttrName()) ||
        !launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelModuleAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module: the
    // named module must exist as a symbol and carry the kernel-module marker
    // attribute.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";
    if (!kernelModule.getAttrOfType<UnitAttr>(
            GPUDialect::getKernelModuleAttrName()))
      return launchOp.emitOpError("module '")
             << kernelModuleName << "' is missing the '"
             << GPUDialect::getKernelModuleAttrName() << "' attribute";

    // Check that `launch_func` refers to a well-formed kernel function: it
    // must exist in the kernel module, be marked as a kernel, and take as
    // many arguments as the launch op passes kernel operands.
    StringRef kernelName = launchOp.kernel();
    auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(kernelName);
    if (!kernelFunction)
      return launchOp.emitOpError("kernel function '")
             << kernelName << "' is undefined";
    if (!kernelFunction.getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";
    if (launchOp.getNumKernelOperands() != kernelFunction.getNumArguments())
      return launchOp.emitOpError("got ")
             << launchOp.getNumKernelOperands()
             << " kernel operands but expected "
             << kernelFunction.getNumArguments();

    // Due to the ordering of the current impl of lowering and LLVMLowering,
    // type checks need to be temporarily disabled.
    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
    // to encode target module" has landed.
    // auto functionType = kernelFunc.getType();
    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    //   if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
    //     return emitOpError("type of function argument ")
    //            << i << " does not match";
    //   }
    // }

    return success();
  });

  // wasInterrupted() is true iff some launch op reported an error above.
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Shared verifier for index ops (thread/block id and dim ops): the
/// `dimension` attribute must be one of "x", "y" or "z".
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

/// Returns the types of the given values, in the same order.
static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
  SmallVector<Type, 4> types;
  types.reserve(values.size());
  for (Value *v : values)
    types.push_back(v->getType());
  return types;
}

/// Builds a LaunchOp: the six grid/block size values become the leading
/// operands, followed by the data operands, and a single-block body region is
/// created whose leading kNumConfigRegionAttributes arguments (ids and sizes,
/// all of `index` type) are followed by one argument per data operand.
void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX,
                     Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
                     Value *blockSizeY, Value *blockSizeZ,
                     ArrayRef<Value *> operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(getValueTypes(operands));
  kernelRegion->push_back(body);
}

/// Returns the kernel body region (the op's only region).
Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }

// The region argument layout established by build() is:
//   [0-2]  block ids, [3-5] thread ids, [6-8] grid sizes, [9-11] block sizes,
//   [12..] kernel data arguments.

/// Returns the block-id region arguments (region args 0-2).
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

/// Returns the thread-id region arguments (region args 3-5).
KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

/// Returns the grid-size region arguments (region args 6-8).
KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

/// Returns the block-size region arguments (region args 9-11).
KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

/// Returns the data operands, i.e. all operands after the six launch
/// configuration (grid/block size) operands.
LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
  return llvm::drop_begin(getOperands(), kNumConfigOperands);
}

/// Returns the types of the data operands (see getKernelOperandValues).
LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}

/// Returns the SSA values passed as grid sizes (operands 0-2).
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA values passed as block sizes (operands 3-5).
KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

/// Returns the region arguments that correspond to the kernel data operands,
/// i.e. everything after the kNumConfigRegionAttributes leading id/size args.
llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
  auto args = getBody().getBlocks().front().getArguments();
  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}

LogicalResult LaunchOp::verify() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes. Hence the entry block
  // must have kNumConfigRegionAttributes + <num data operands> arguments,
  // which equals kNumConfigOperands + getNumOperands().
  if (!getBody().empty()) {
    Block &entryBlock = getBody().front();
    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be the gpu::Return terminator op. (NOTE(review): the
  // diagnostic text says 'gpu.terminator' while the check is isa<gpu::Return>
  // — confirm which spelling is intended.)
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::Return>(&block.back())) {
      return block.back()
          .emitError("expected 'gpu.terminator' or a terminator with "
                     "successors")
          .attachNote(getLoc())
          << "in '" << getOperationName() << "' body region";
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ArrayRef<Value *> operands, KernelDim3 ids) {
  p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
  p << *size.x << " = " << *operands[0] << ", ";
  p << *size.y << " = " << *operands[1] << ", ";
  p << *size.z << " = " << *operands[2] << ')';
}

/// Prints a LaunchOp; the inverse of LaunchOp::parse below.
void LaunchOp::print(OpAsmPrinter &p) {
  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
  ArrayRef<Value *> operands(operandContainer);

  // Print the launch configuration: grid sizes bound to block ids, then block
  // sizes bound to thread ids.
  p << getOperationName() << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(kNumConfigOperands);

  // Print the data argument remapping: each remaining region argument paired
  // with the operand it captures.
  if (!getBody().empty() && !operands.empty()) {
    p << ' ' << getArgsKeyword() << '(';
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
        << " = " << *operands[i];
    }
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty()) {
    p << ": ";
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[i]->getType();
    }
  }

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(getAttrs());
}

// Parse the size assignment blocks for blocks and threads.  These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
/// Parses one size-assignment segment (see the grammar comment above),
/// writing the parsed SSA uses into `sizes` and the parsed region-argument
/// names into `regionSizes` and `indices`. All three output arrays must hold
/// exactly three slots.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  // "(%id, %id, %id) in (" — the three iteration-variable region arguments.
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  // "%size-arg = %operand" triples, comma-separated.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                              (`args` ssa-reassignment `:` type-list)?
//                              region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
      kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers.  In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values (ids at slots 0-5, sizes at slots 6-11).
  if (parser.parseKeyword(getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If kernel argument renaming segment is present, parse it.  When present,
  // the segment should have at least one element.  If this segment is present,
  // so is the trailing type list.  Parse it as well and use the parsed types
  // to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    // First "%region_arg = %operand" pair is mandatory after `args (`.
    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    // Subsequent pairs are comma-separated.
    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it.  The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // Follow the actual kernel arguments.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttributeDict(result.attributes));
}

/// Removes the `index`-th kernel data argument: erases both the corresponding
/// entry-block argument and the matching op operand. `index` is relative to
/// the data arguments, i.e. 0 is the first argument after the configuration
/// region attributes/operands.
void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = getBody().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  // Rewrites a single LaunchOp in place: constant operands are cloned into
  // the kernel body and the now-redundant kernel arguments erased.
  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    // Clones are inserted at the start of the kernel body; the insertion
    // point is saved and restored so the rewriter state is unchanged.
    auto origInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(&launchOp.getBody().front());

    // Traverse operands passed to kernel and check if some of them are known
    // constants.  If so, clone the constant operation inside the kernel region
    // and use it instead of passing the value from the parent region.  Perform
    // the traversal in the inverse order to simplify index arithmetics when
    // dropping arguments.
    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
                                     launchOp.getKernelOperandValues().end());
    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
                                       launchOp.getKernelArguments().end());
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value *operand = operands[index];
      // Only operands defined by a ConstantOp can be cloned safely.
      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
        continue;
      }

      found = true;
      Value *internalConstant =
          rewriter.clone(*operand->getDefiningOp())->getResult(0);
      Value *kernelArg = kernelArgs[index];
      kernelArg->replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }
    rewriter.restoreInsertionPoint(origInsertionPoint);

    // No constant operand: this pattern did not apply.
    if (!found)
      return matchFailure();

    // Notify the rewriter that launchOp was modified without being replaced.
    rewriter.updatedRootInPlace(launchOp);
    return matchSuccess();
  }
};
} // end namespace

/// Registers LaunchOp canonicalization patterns (constant propagation into
/// the kernel body).
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

/// Builds a LaunchFuncOp from explicit per-dimension grid/block sizes.  The
/// kernel function's name is recorded in the kernel attribute and, when the
/// kernel's enclosing module is named, that name is recorded as a symbol
/// reference in the kernel-module attribute.
void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
                         Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
                         Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getStringAttr(kernelFunc.getName()));
  auto kernelModule = kernelFunc.getParentOfType<ModuleOp>();
  if (Optional<StringRef> kernelModuleName = kernelModule.getName())
    result.addAttribute(getKernelModuleAttrName(),
                        builder->getSymbolRefAttr(*kernelModuleName));
}

/// Convenience builder taking the grid/block sizes as KernelDim3 triples;
/// forwards to the per-dimension builder above.
void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize,
                         ArrayRef<Value *> kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

/// Returns the name of the kernel function this op launches.
StringRef LaunchFuncOp::kernel() {
  return getAttrOfType<StringAttr>(getKernelAttrName()).getValue();
}

/// Returns the number of operands passed to the kernel (all operands after
/// the six launch-configuration operands).
unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

/// Returns the name of the module containing the kernel function.
StringRef LaunchFuncOp::getKernelModuleName() {
  return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
}

/// Returns the `i`-th kernel operand (0-based, skipping the six
/// launch-configuration operands).
Value *LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

/// Returns the SSA values passed as grid sizes (operands 0-2).
KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

/// Returns the SSA values passed as block sizes (operands 3-5).
KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

/// Verifies a `gpu.launch_func`: the op must be nested in a module carrying
/// the container-module marker attribute, and must itself carry both the
/// kernel-name string attribute and the kernel-module symbol reference
/// attribute. Cross-module checks (that the referenced kernel exists, is
/// marked as a kernel, etc.) are performed by
/// GPUDialect::verifyOperationAttribute instead.
LogicalResult LaunchFuncOp::verify() {
  // Locate the closest enclosing module; the op is invalid outside of one.
  auto containingModule = getParentOfType<ModuleOp>();
  if (!containingModule)
    return emitOpError("expected to belong to a module");

  // That module must be explicitly marked as a GPU container module.
  if (!containingModule.getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  // The launched kernel's name must be present as a string attribute.
  if (!getAttrOfType<StringAttr>(getKernelAttrName()))
    return emitOpError("string attribute '" + getKernelAttrName() +
                       "' must be specified");

  // The kernel's module must be referenced via a symbol reference attribute.
  if (!getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()))
    return emitOpError("symbol reference attribute '" +
                       getKernelModuleAttrName() + "' must be specified");

  return success();
}