1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the GPU kernel-related dialect and its operations. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/GPU/GPUDialect.h" 23 #include "mlir/Dialect/StandardOps/Ops.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/OpImplementation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 31 using namespace mlir; 32 using namespace mlir::gpu; 33 34 StringRef GPUDialect::getDialectName() { return "gpu"; } 35 36 bool GPUDialect::isKernel(FuncOp function) { 37 UnitAttr isKernelAttr = 38 function.getAttrOfType<UnitAttr>(getKernelFuncAttrName()); 39 return static_cast<bool>(isKernelAttr); 40 } 41 42 GPUDialect::GPUDialect(MLIRContext *context) 43 : Dialect(getDialectName(), context) { 44 addOperations<LaunchOp, LaunchFuncOp, 45 #define GET_OP_LIST 46 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 47 >(); 48 } 49 50 #define GET_OP_CLASSES 51 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 52 53 //===----------------------------------------------------------------------===// 54 // LaunchOp 55 //===----------------------------------------------------------------------===// 56 57 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) { 58 SmallVector<Type, 4> types; 59 types.reserve(values.size()); 60 for (Value *v : values) 61 types.push_back(v->getType()); 62 return types; 63 } 64 65 void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX, 66 Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, 67 Value *blockSizeY, Value *blockSizeZ, 68 ArrayRef<Value *> operands) { 69 // Add grid and block sizes as op operands, followed by the data operands. 70 result->addOperands( 71 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 72 result->addOperands(operands); 73 74 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 75 // where the first kNumConfigRegionAttributes arguments have `index` type and 76 // the rest have the same types as the data operands. 77 Region *kernelRegion = result->addRegion(); 78 Block *body = new Block(); 79 body->addArguments( 80 std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType())); 81 body->addArguments(getValueTypes(operands)); 82 kernelRegion->push_back(body); 83 } 84 85 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); } 86 87 KernelDim3 LaunchOp::getBlockIds() { 88 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 89 auto args = getBody().getBlocks().front().getArguments(); 90 return KernelDim3{args[0], args[1], args[2]}; 91 } 92 93 KernelDim3 LaunchOp::getThreadIds() { 94 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 95 auto args = getBody().getBlocks().front().getArguments(); 96 return KernelDim3{args[3], args[4], args[5]}; 97 } 98 99 KernelDim3 LaunchOp::getGridSize() { 100 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 101 auto args = getBody().getBlocks().front().getArguments(); 102 return KernelDim3{args[6], args[7], args[8]}; 103 } 104 105 KernelDim3 LaunchOp::getBlockSize() { 106 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 107 auto args = getBody().getBlocks().front().getArguments(); 108 return KernelDim3{args[9], args[10], args[11]}; 109 } 110 111 LaunchOp::operand_range LaunchOp::getKernelOperandValues() { 112 return llvm::drop_begin(getOperands(), kNumConfigOperands); 113 } 114 115 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() { 116 return llvm::drop_begin(getOperandTypes(), kNumConfigOperands); 117 } 118 119 KernelDim3 LaunchOp::getGridSizeOperandValues() { 120 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 121 } 122 123 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 124 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 125 } 126 127 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { 128 auto args = getBody().getBlocks().front().getArguments(); 129 return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes); 130 } 131 132 LogicalResult LaunchOp::verify() { 133 // Kernel launch takes kNumConfigOperands leading operands for grid/block 134 // sizes and transforms them into kNumConfigRegionAttributes region arguments 135 // for block/thread identifiers and grid/block sizes. 136 if (!getBody().empty()) { 137 Block &entryBlock = getBody().front(); 138 if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands()) 139 return emitError("unexpected number of region arguments"); 140 } 141 142 // Block terminators without successors are expected to exit the kernel region 143 // and must be `gpu.launch`. 144 for (Block &block : getBody()) { 145 if (block.empty()) 146 continue; 147 if (block.back().getNumSuccessors() != 0) 148 continue; 149 if (!isa<gpu::Return>(&block.back())) { 150 return block.back() 151 .emitError("expected 'gpu.terminator' or a terminator with " 152 "successors") 153 .attachNote(getLoc()) 154 << "in '" << getOperationName() << "' body region"; 155 } 156 } 157 158 return success(); 159 } 160 161 // Pretty-print the kernel grid/block size assignment as 162 // (%iter-x, %iter-y, %iter-z) in 163 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 164 // where %size-* and %iter-* will correspond to the body region arguments. 165 static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size, 166 ArrayRef<Value *> operands, KernelDim3 ids) { 167 *p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in ("; 168 *p << *size.x << " = " << *operands[0] << ", "; 169 *p << *size.y << " = " << *operands[1] << ", "; 170 *p << *size.z << " = " << *operands[2] << ')'; 171 } 172 173 void LaunchOp::print(OpAsmPrinter *p) { 174 SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end()); 175 ArrayRef<Value *> operands(operandContainer); 176 177 // Print the launch configuration. 178 *p << getOperationName() << ' ' << getBlocksKeyword(); 179 printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds()); 180 *p << ' ' << getThreadsKeyword(); 181 printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds()); 182 183 // From now on, the first kNumConfigOperands operands corresponding to grid 184 // and block sizes are irrelevant, so we can drop them. 185 operands = operands.drop_front(kNumConfigOperands); 186 187 // Print the data argument remapping. 188 if (!getBody().empty() && !operands.empty()) { 189 *p << ' ' << getArgsKeyword() << '('; 190 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 191 if (i != 0) 192 *p << ", "; 193 *p << *getBody().front().getArgument(kNumConfigRegionAttributes + i) 194 << " = " << *operands[i]; 195 } 196 *p << ") "; 197 } 198 199 // Print the types of data arguments. 200 if (!operands.empty()) { 201 *p << ": "; 202 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 203 if (i != 0) 204 *p << ", "; 205 *p << operands[i]->getType(); 206 } 207 } 208 209 p->printRegion(getBody(), /*printEntryBlockArgs=*/false); 210 p->printOptionalAttrDict(getAttrs()); 211 } 212 213 // Parse the size assignment blocks for blocks and threads. These have the form 214 // (%region_arg, %region_arg, %region_arg) in 215 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 216 // where %region_arg are percent-identifiers for the region arguments to be 217 // introduced futher (SSA defs), and %operand are percent-identifiers for the 218 // SSA value uses. 219 static ParseResult 220 parseSizeAssignment(OpAsmParser *parser, 221 MutableArrayRef<OpAsmParser::OperandType> sizes, 222 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 223 MutableArrayRef<OpAsmParser::OperandType> indices) { 224 assert(indices.size() == 3 && "space for three indices expected"); 225 SmallVector<OpAsmParser::OperandType, 3> args; 226 if (parser->parseRegionArgumentList(args, /*requiredOperandCount=*/3, 227 OpAsmParser::Delimiter::Paren) || 228 parser->parseKeyword("in") || parser->parseLParen()) 229 return failure(); 230 std::move(args.begin(), args.end(), indices.begin()); 231 232 for (int i = 0; i < 3; ++i) { 233 if (i != 0 && parser->parseComma()) 234 return failure(); 235 if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() || 236 parser->parseOperand(sizes[i])) 237 return failure(); 238 } 239 240 return parser->parseRParen(); 241 } 242 243 // Parses a Launch operation. 244 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 245 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 246 // (`args` ssa-reassignment `:` type-list)? 247 // region attr-dict? 248 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 249 ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { 250 // Sizes of the grid and block. 251 SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes( 252 kNumConfigOperands); 253 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 254 255 // Actual (data) operands passed to the kernel. 256 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 257 258 // Region arguments to be created. 259 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 260 kNumConfigRegionAttributes); 261 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 262 263 // Parse the size assignment segments: the first segment assigns grid siezs 264 // and defines values for block identifiers; the second segment assigns block 265 // sies and defines values for thread identifiers. In the region argument 266 // list, identifiers preceed sizes, and block-related values preceed 267 // thread-related values. 268 if (parser->parseKeyword(getBlocksKeyword().data()) || 269 parseSizeAssignment(parser, sizesRef.take_front(3), 270 regionArgsRef.slice(6, 3), 271 regionArgsRef.slice(0, 3)) || 272 parser->parseKeyword(getThreadsKeyword().data()) || 273 parseSizeAssignment(parser, sizesRef.drop_front(3), 274 regionArgsRef.slice(9, 3), 275 regionArgsRef.slice(3, 3)) || 276 parser->resolveOperands(sizes, parser->getBuilder().getIndexType(), 277 result->operands)) 278 return failure(); 279 280 // If kernel argument renaming segment is present, parse it. When present, 281 // the segment should have at least one element. If this segment is present, 282 // so is the trailing type list. Parse it as well and use the parsed types 283 // to resolve the operands passed to the kernel arguments. 284 SmallVector<Type, 4> dataTypes; 285 if (!parser->parseOptionalKeyword(getArgsKeyword().data())) { 286 llvm::SMLoc argsLoc = parser->getCurrentLocation(); 287 288 regionArgs.push_back({}); 289 dataOperands.push_back({}); 290 if (parser->parseLParen() || 291 parser->parseRegionArgument(regionArgs.back()) || 292 parser->parseEqual() || parser->parseOperand(dataOperands.back())) 293 return failure(); 294 295 while (!parser->parseOptionalComma()) { 296 regionArgs.push_back({}); 297 dataOperands.push_back({}); 298 if (parser->parseRegionArgument(regionArgs.back()) || 299 parser->parseEqual() || parser->parseOperand(dataOperands.back())) 300 return failure(); 301 } 302 303 if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) || 304 parser->resolveOperands(dataOperands, dataTypes, argsLoc, 305 result->operands)) 306 return failure(); 307 } 308 309 // Introduce the body region and parse it. The region has 310 // kNumConfigRegionAttributes leading arguments that correspond to 311 // block/thread identifiers and grid/block sizes, all of the `index` type. 312 // Follow the actual kernel arguments. 313 Type index = parser->getBuilder().getIndexType(); 314 dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index); 315 Region *body = result->addRegion(); 316 return failure(parser->parseRegion(*body, regionArgs, dataTypes) || 317 parser->parseOptionalAttributeDict(result->attributes)); 318 } 319 320 void LaunchOp::eraseKernelArgument(unsigned index) { 321 Block &entryBlock = getBody().front(); 322 assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes && 323 "kernel argument index overflow"); 324 entryBlock.eraseArgument(kNumConfigRegionAttributes + index); 325 getOperation()->eraseOperand(kNumConfigOperands + index); 326 } 327 328 namespace { 329 // Clone any known constants passed as operands to the kernel into its body. 330 class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { 331 using OpRewritePattern<LaunchOp>::OpRewritePattern; 332 333 PatternMatchResult matchAndRewrite(LaunchOp launchOp, 334 PatternRewriter &rewriter) const override { 335 auto oringInsertionPoint = rewriter.saveInsertionPoint(); 336 rewriter.setInsertionPointToStart(&launchOp.getBody().front()); 337 338 // Traverse operands passed to kernel and check if some of them are known 339 // constants. If so, clone the constant operation inside the kernel region 340 // and use it instead of passing the value from the parent region. Perform 341 // the traversal in the inverse order to simplify index arithmetics when 342 // dropping arguments. 343 SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(), 344 launchOp.getKernelOperandValues().end()); 345 SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(), 346 launchOp.getKernelArguments().end()); 347 bool found = false; 348 for (unsigned i = operands.size(); i > 0; --i) { 349 unsigned index = i - 1; 350 Value *operand = operands[index]; 351 if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { 352 continue; 353 } 354 355 found = true; 356 Value *internalConstant = 357 rewriter.clone(*operand->getDefiningOp())->getResult(0); 358 Value *kernelArg = kernelArgs[index]; 359 kernelArg->replaceAllUsesWith(internalConstant); 360 launchOp.eraseKernelArgument(index); 361 } 362 rewriter.restoreInsertionPoint(oringInsertionPoint); 363 364 if (!found) 365 return matchFailure(); 366 367 rewriter.updatedRootInPlace(launchOp); 368 return matchSuccess(); 369 } 370 }; 371 } // end namespace 372 373 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 374 MLIRContext *context) { 375 results.insert<PropagateConstantBounds>(context); 376 } 377 378 //===----------------------------------------------------------------------===// 379 // LaunchFuncOp 380 //===----------------------------------------------------------------------===// 381 382 void LaunchFuncOp::build(Builder *builder, OperationState *result, 383 FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, 384 Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, 385 Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { 386 // Add grid and block sizes as op operands, followed by the data operands. 387 result->addOperands( 388 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 389 result->addOperands(kernelOperands); 390 result->addAttribute(getKernelAttrName(), 391 builder->getSymbolRefAttr(kernelFunc)); 392 } 393 394 void LaunchFuncOp::build(Builder *builder, OperationState *result, 395 FuncOp kernelFunc, KernelDim3 gridSize, 396 KernelDim3 blockSize, 397 ArrayRef<Value *> kernelOperands) { 398 build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, 399 blockSize.x, blockSize.y, blockSize.z, kernelOperands); 400 } 401 402 StringRef LaunchFuncOp::kernel() { 403 return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue(); 404 } 405 406 unsigned LaunchFuncOp::getNumKernelOperands() { 407 return getNumOperands() - kNumConfigOperands; 408 } 409 410 Value *LaunchFuncOp::getKernelOperand(unsigned i) { 411 return getOperation()->getOperand(i + kNumConfigOperands); 412 } 413 414 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 415 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 416 } 417 418 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 419 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 420 } 421 422 LogicalResult LaunchFuncOp::verify() { 423 auto kernelAttr = this->getAttr(getKernelAttrName()); 424 if (!kernelAttr) { 425 return emitOpError("attribute 'kernel' must be specified"); 426 } else if (!kernelAttr.isa<SymbolRefAttr>()) { 427 return emitOpError("attribute 'kernel' must be a function"); 428 } 429 430 auto module = getParentOfType<ModuleOp>(); 431 FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel()); 432 if (!kernelFunc) 433 return emitError() << "kernel function '" << kernelAttr << "' is undefined"; 434 435 if (!kernelFunc.getAttrOfType<mlir::UnitAttr>( 436 GPUDialect::getKernelFuncAttrName())) { 437 return emitError("kernel function is missing the '") 438 << GPUDialect::getKernelFuncAttrName() << "' attribute"; 439 } 440 unsigned numKernelFuncArgs = kernelFunc.getNumArguments(); 441 if (getNumKernelOperands() != numKernelFuncArgs) { 442 return emitOpError("got ") 443 << getNumKernelOperands() << " kernel operands but expected " 444 << numKernelFuncArgs; 445 } 446 auto functionType = kernelFunc.getType(); 447 for (unsigned i = 0; i < numKernelFuncArgs; ++i) { 448 if (getKernelOperand(i)->getType() != functionType.getInput(i)) { 449 return emitOpError("type of function argument ") 450 << i << " does not match"; 451 } 452 } 453 return success(); 454 } 455