//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32();
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError() << "MMAMatrixType elements must be F16 or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// GPU memory space identifiers.
enum GPUMemorySpace {
  /// Generic memory space identifier.
  kGenericMemorySpace = 0,

  /// Global memory space identifier.
  kGlobalMemorySpace = 1,

  /// Shared memory space identifier.
  kSharedMemorySpace = 3
};

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<DeviceAsyncTokenType>();
  addTypes<MMAMatrixType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);
  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!attr.getValue().isa<UnitAttr>() ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringAttr kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.kernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (body().empty() != op().hasValue())
    return emitError("expected either an op attribute or a non-empty body");
  if (!body().empty()) {
    if (body().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : body().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *op();
    if ((opName == gpu::AllReduceOperation::AND ||
         opName == gpu::AllReduceOperation::OR ||
         opName == gpu::AllReduceOperation::XOR) &&
        !getType().isa<IntegerType>()) {
      return emitError()
             << '`' << gpu::stringifyAllReduceOperation(opName)
             << "` accumulator is only compatible with Integer type";
    }
  }
  return success();
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    Optional<AllReduceOperation> op = gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);

  // Async dependencies are the only variadic operands.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ,
                     Value dynamicSharedMemorySize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  kernelRegion->push_back(body);
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!body().empty()) {
    if (body().getNumArguments() != LaunchOp::kNumConfigOperands +
                                        getNumOperands() -
                                        (dynamicSharedMemorySize() ? 1 : 0))
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  // Print the launch configuration.
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (dynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << dynamicSharedMemorySize();

  p << ' ';
  p.printRegion(body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs());
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in`
///               ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in`
///               ssa-reassignment
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword()))
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.body().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      id.replaceAllUsesWith(zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.gridSizeX());
    constPropIdUses(op.getBlockIds().y, op.gridSizeY());
    constPropIdUses(op.getBlockIds().z, op.gridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.blockSizeX());
    constPropIdUses(op.getThreadIds().y, op.blockSizeY());
    constPropIdUses(op.getThreadIds().z, op.blockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
                      blockSize.y, blockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  result.addAttribute(getKernelAttrName(), kernelSymbol);
  SmallVector<int32_t, 9> segmentSizes(9, 1);
  segmentSizes.front() = 0; // Initially no async dependencies.
  segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr(segmentSizes));
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - asyncDependencies().size() - kNumConfigOperands -
         (dynamicSharedMemorySize() ? 1 : 0);
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return kernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperand(asyncDependencies().size() + kNumConfigOperands +
                    (dynamicSharedMemorySize() ? 1 : 0) + i);
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  auto kernelAttr = (*this)->getAttrOfType<SymbolRefAttr>(getKernelAttrName());
  if (!kernelAttr)
    return emitOpError("symbol reference attribute '" + getKernelAttrName() +
                       "' must be specified");

  return success();
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();
  SmallVector<NamedAttrList> argAttrs;
  SmallVector<Location> argLocations;
  bool isVariadic = false;
  return function_interface_impl::parseFunctionArgumentList(
      parser, /*allowAttributes=*/false,
      /*allowVariadic=*/false, argNames, argTypes, argAttrs, argLocations,
      isVariadic);
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::UnresolvedOperand arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
  SmallVector<NamedAttrList> argAttrs;
  SmallVector<NamedAttrList> resultAttrs;
  SmallVector<Type> argTypes;
  SmallVector<Type> resultTypes;
  SmallVector<Location> argLocations;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          argLocations, isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
                                                resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this, type.getNumInputs(), type.getNumResults(),
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

LogicalResult GPUFuncOp::verifyType() {
  Type type = getFunctionTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpaceAsInt() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != operands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), operands()))) {
    Type type;
    Value operand;
    std::tie(type, operand) = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // If module attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Parse the module body.
  auto *body = result.addRegion();
  if (parser.parseRegion(*body, None, None))
    return failure();

  // Ensure that this module has a valid terminator.
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

void GPUModuleOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
                                     {mlir::SymbolTable::getSymbolAttrName()});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

//===----------------------------------------------------------------------===//
// GPUMemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = src().getType();
  auto dstType = dst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}

static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.dst();
    // If `dest` is a block argument, we cannot remove `op`.
    if (dest.isa<BlockArgument>())
      return failure();
    auto isDeallocLikeOpActingOnVal = [](Operation *op, Value val) {
      auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
      if (!memOp)
        return false;
      llvm::SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
          memOpEffects;
      memOp.getEffects(memOpEffects);
      return llvm::none_of(memOpEffects, [val](auto &effect) {
        return effect.getValue() == val &&
               !isa<MemoryEffects::Free>(effect.getEffect());
      });
    };
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [isDeallocLikeOpActingOnVal, op,
                                       dest](Operation *user) {
          return user != op && !isDeallocLikeOpActingOnVal(user, dest);
        }))
      return failure();

    if (op.asyncDependencies().size() > 1 ||
        ((op.asyncDependencies().empty() && op.asyncToken()) ||
         (!op.asyncDependencies().empty() && !op.asyncToken())))
      return failure();
    rewriter.replaceOp(op, op.asyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  return strides.back() == 1;
}

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = srcMemref().getType();
  auto resType = res().getType();
  auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = srcType.cast<MemRefType>();
  auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();

  if (!isLastMemrefDimUnitStride(srcMemrefType))
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
      srcMemSpace != kGlobalMemorySpace)
    return emitError(
        "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
        "kGlobalMemorySpace only allowed");

  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = src().getType();
  auto dstType = dstMemref().getType();
  auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
  auto dstMemrefType = dstType.cast<MemRefType>();
  auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();

  if (!isLastMemrefDimUnitStride(dstMemrefType))
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
      dstMemSpace != kGlobalMemorySpace)
    return emitError("destination memorySpace of kGenericMemorySpace, "
                     "kGlobalMemorySpace or kSharedMemorySpace only allowed");

  if (!srcMatrixType.getOperand().equals("COp"))
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(opA().getType().cast<MMAMatrixType>());
  opTypes.push_back(opB().getType().cast<MMAMatrixType>());
  opTypes.push_back(opC().getType().cast<MMAMatrixType>());

  if (!opTypes[A].getOperand().equals("AOp") ||
      !opTypes[B].getOperand().equals("BOp") ||
      !opTypes[C].getOperand().equals("COp"))
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
    if (cast) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GPU_WaitOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
///   %t = gpu.wait async []       // No async dependencies.
///   ... gpu.wait ... [%t, ...]   // %t can be removed.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.asyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    // Notify the rewriter of the in-place update so the pattern driver can
    // track the modification.
    rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops for the following patterns.
/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
///    dependencies).
/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1
///    with %t0.
/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
///    dependencies nor return any token.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.asyncDependencies().empty() && !op.asyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the
    // op.
    if (llvm::hasSingleElement(op.asyncDependencies()) && op.asyncToken()) {
      rewriter.replaceOp(op, op.asyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.asyncToken() && op.asyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = memref().getType().cast<MemRefType>();

  if (static_cast<int64_t>(dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (symbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}

namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
/// `memref::AllocOp`.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto index = dimOp.index().getDefiningOp<arith::ConstantIndexOp>();
    if (!index)
      return failure();

    auto memrefType = dimOp.source().getType().dyn_cast<MemRefType>();
    if (!memrefType || !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.source().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.dynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = src().getType().cast<MemRefType>();
  auto dstMemref = dst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != srcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << srcIndices().size();
  if (size_t(dstMemref.getRank()) != dstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got " << dstIndices().size();
  return success();
}

#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"