//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

StringRef GPUDialect::getDialectName() { return "gpu"; }

bool GPUDialect::isKernel(FuncOp function) {
  UnitAttr isKernelAttr =
      function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<LaunchOp, LaunchFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
                >();
}

template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
  SmallVector<Type, 4> types;
  types.reserve(values.size());
  for (Value *v : values)
    types.push_back(v->getType());
  return types;
}

void LaunchOp::build(Builder *builder, OperationState &result,
                     Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ,
                     Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ,
                     ArrayRef<Value *> operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
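  // Concretely, matching the getBlockIds/getThreadIds/getGridSize/getBlockSize
  // accessors below, the leading region arguments are laid out as:
  //   [0..2]  block ids,  [3..5]  thread ids,
  //   [6..8]  grid sizes, [9..11] block sizes,
  // followed by one argument per data operand.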
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(getValueTypes(operands));
  kernelRegion->push_back(body);
}

Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
  return llvm::drop_begin(getOperands(), kNumConfigOperands);
}

LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
  auto args = getBody().getBlocks().front().getArguments();
  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}

LogicalResult LaunchOp::verify() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    Block &entryBlock = getBody().front();
    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.return`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::Return>(&block.back())) {
      return block.back()
          .emitError("expected 'gpu.return' or a terminator with successors")
          .attachNote(getLoc())
          << "in '" << getOperationName() << "' body region";
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
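//
// For illustration (SSA names are arbitrary), a full launch printed by
// LaunchOp::print below may look like:
//   gpu.launch blocks(%bx, %by, %bz) in (%gsx = %0, %gsy = %1, %gsz = %2)
//              threads(%tx, %ty, %tz) in (%bsx = %3, %bsy = %4, %bsz = %5)
//              args(%arg0 = %6) : f32 {
//     ...
//   }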
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ArrayRef<Value *> operands, KernelDim3 ids) {
  p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
  p << *size.x << " = " << *operands[0] << ", ";
  p << *size.y << " = " << *operands[1] << ", ";
  p << *size.z << " = " << *operands[2] << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
  ArrayRef<Value *> operands(operandContainer);

  // Print the launch configuration.
  p << getOperationName() << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(kNumConfigOperands);

  // Print the data argument remapping.
  if (!getBody().empty() && !operands.empty()) {
    p << ' ' << getArgsKeyword() << '(';
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
        << " = " << *operands[i];
    }
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty()) {
    p << ": ";
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[i]->getType();
    }
  }

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(getAttrs());
}

// Parse the size-assignment segments for blocks and threads. These have the
// form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                            `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                            (`args` ssa-reassignment `:` type-list)?
//                            region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
      kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size-assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If the kernel argument renaming segment is present, parse it. When
  // present, the segment should have at least one element. If this segment is
  // present, so is the trailing type list. Parse it as well and use the parsed
  // types to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // The actual kernel arguments follow.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttributeDict(result.attributes));
}

void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = getBody().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
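//
// For example (illustrative IR; names are arbitrary), given
//   %cst = constant 8 : index
//   gpu.launch ... args(%arg0 = %cst) : index { ... uses of %arg0 ... }
// the pattern clones the constant to the top of the kernel region, replaces
// all uses of %arg0 with the clone, and erases both the region argument and
// the corresponding launch operand.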
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    auto origInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(&launchOp.getBody().front());

    // Traverse the operands passed to the kernel and check if some of them are
    // known constants. If so, clone the constant operation inside the kernel
    // region and use it instead of passing the value from the parent region.
    // Perform the traversal in reverse order to simplify index arithmetic when
    // dropping arguments.
    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
                                     launchOp.getKernelOperandValues().end());
    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
                                       launchOp.getKernelArguments().end());
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value *operand = operands[index];
      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
        continue;
      }

      found = true;
      Value *internalConstant =
          rewriter.clone(*operand->getDefiningOp())->getResult(0);
      Value *kernelArg = kernelArgs[index];
      kernelArg->replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }
    rewriter.restoreInsertionPoint(origInsertionPoint);

    if (!found)
      return matchFailure();

    rewriter.updatedRootInPlace(launchOp);
    return matchSuccess();
  }
};
} // end namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
                         Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
                         Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
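  // The launched function itself is referenced by a symbol attribute rather
  // than an operand; getGridSizeOperandValues/getBlockSizeOperandValues below
  // rely on the grid sizes occupying operands 0-2 and the block sizes
  // occupying operands 3-5.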
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getSymbolRefAttr(kernelFunc));
}

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize,
                         ArrayRef<Value *> kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

StringRef LaunchFuncOp::kernel() {
  return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

Value *LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

LogicalResult LaunchFuncOp::verify() {
  auto kernelAttr = this->getAttr(getKernelAttrName());
  if (!kernelAttr) {
    return emitOpError("attribute 'kernel' must be specified");
  } else if (!kernelAttr.isa<SymbolRefAttr>()) {
    return emitOpError("attribute 'kernel' must be a function");
  }

  auto module = getParentOfType<ModuleOp>();
  FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
  if (!kernelFunc)
    return emitOpError("kernel function '") << kernelAttr << "' is undefined";

  if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
          GPUDialect::getKernelFuncAttrName())) {
    return emitOpError("kernel function is missing the '")
           << GPUDialect::getKernelFuncAttrName() << "' attribute";
  }
  unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
  if (getNumKernelOperands() != numKernelFuncArgs) {
    return emitOpError("got ")
           << getNumKernelOperands() << " kernel operands but expected "
           << numKernelFuncArgs;
  }
  auto functionType = kernelFunc.getType();
  for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
      return emitOpError("type of function argument ")
             << i << " does not match";
    }
  }
  return success();
}
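
// Usage sketch (illustrative only, not part of this file's API surface): given
// an OpBuilder `b`, a Location `loc`, a kernel FuncOp `kernelFunc` carrying the
// gpu.kernel attribute, `index`-typed SSA values for the launch configuration,
// and the kernel arguments, a launch_func can be created via the build methods
// defined above:
//   b.create<gpu::LaunchFuncOp>(loc, kernelFunc,
//                               gpu::KernelDim3{gx, gy, gz},
//                               gpu::KernelDim3{bx, by, bz},
//                               kernelArguments);
// The names `b`, `loc`, `gx`..`bz`, and `kernelArguments` are placeholders for
// values available in the surrounding code.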