//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32();
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError() << "MMAMatrixType elements must be F16 or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// GPU memory space identifiers.
enum GPUMemorySpace {
  /// Generic memory space identifier.
  kGenericMemorySpace = 0,

  /// Global memory space identifier.
  kGlobalMemorySpace = 1,

  /// Shared memory space identifier.
  kSharedMemorySpace = 3
};

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!attr.getValue().isa<UnitAttr>() ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
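    // That is, only launches whose immediate parent op (typically a function)
    // is directly contained in `module` are verified here; anything nested
    // deeper is skipped.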
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringAttr kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.kernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.kernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with its leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
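/// For illustration (the SSA names are placeholders, not taken from this
/// file), a printed form could look like `async [%dep0, %dep1]`, or just
/// `[%dep0]` when the op is not marked async.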
/// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (body().empty() != op().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!body().empty()) {
    if (body().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : body().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *op();
    if ((opName == gpu::AllReduceOperation::AND ||
         opName == gpu::AllReduceOperation::OR ||
         opName == gpu::AllReduceOperation::XOR) &&
        !getType().isa<IntegerType>()) {
      return emitError()
             << '`' << gpu::stringifyAllReduceOperation(opName)
             << "` accumulator is only compatible with Integer type";
    }
  }
  return success();
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    Optional<AllReduceOperation> op = gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);

  // Async dependencies is the only variadic operand.
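  // If the op carries no explicit segment sizes, prepending the token above is
  // already enough; otherwise the leading (async dependency) segment is grown
  // below.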
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value blockSizeX, Value blockSizeY, Value blockSizeZ,
                     Value dynamicSharedMemorySize, Type asyncTokenType,
                     ValueRange asyncDependencies) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  kernelRegion->push_back(body);
  SmallVector<int32_t, 8> segmentSizes(8, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().empty() && "LaunchOp body must not be empty.");
  auto args = body().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!body().empty()) {
    if (body().getNumArguments() !=
        LaunchOp::kNumConfigOperands + getNumOperands() -
            (dynamicSharedMemorySize() ?
                 1 : 0) - asyncDependencies().size())
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && asyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (asyncToken()) {
    p << " async";
    if (!asyncDependencies().empty())
      p << " [" << asyncDependencies() << ']';
  }
  // Print the launch configuration.
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (dynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << dynamicSharedMemorySize();

  p << ' ';
  p.printRegion(body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
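// For illustration (the SSA names below are placeholders, not taken from a
// real test case), a `blocks` clause would read:
//   (%bx, %by, %bz) in (%grid_x = %0, %grid_y = %1, %grid_z = %2)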
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
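  // Concretely, `regionArgs` holds the kNumConfigRegionAttributes (12) entries
  // in the order [0-2] block ids, [3-5] thread ids, [6-8] grid sizes and
  // [9-11] block sizes, matching getBlockIds(), getThreadIds(), getGridSize()
  // and getBlockSize() above.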
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 8> segmentSizes(8, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (!simplified) {
        // Create a zero value the first time.
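        // Insert it at the start of the body so that it dominates every use
        // of the ids replaced below.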
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.body().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      id.replaceAllUsesWith(zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.gridSizeX());
    constPropIdUses(op.getBlockIds().y, op.gridSizeY());
    constPropIdUses(op.getBlockIds().z, op.gridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.blockSizeX());
    constPropIdUses(op.getThreadIds().y, op.blockSizeY());
    constPropIdUses(op.getThreadIds().z, op.blockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
                      blockSize.y, blockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  result.addAttribute(getKernelAttrName(), kernelSymbol);
  SmallVector<int32_t, 9> segmentSizes(9, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ?
      1 : 0;
  segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr(segmentSizes));
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return kernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }

unsigned LaunchFuncOp::getNumKernelOperands() { return operands().size(); }

Value LaunchFuncOp::getKernelOperand(unsigned i) { return operands()[i]; }

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(asyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  auto kernelAttr = (*this)->getAttrOfType<SymbolRefAttr>(getKernelAttrName());
  if (!kernelAttr)
    return emitOpError("symbol reference attribute '" + getKernelAttrName() +
                       "' must be specified");

  return success();
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  SmallVector<OpAsmParser::Argument> args;
  if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true))
    return failure();
  for (auto &arg : args) {
    argNames.push_back(arg.ssaName);
    argTypes.push_back(arg.type);
  }
  return success();
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
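/// The argument is inserted after the function arguments and the existing
/// workgroup attributions, i.e. before any private attributions, which always
/// come last in the entry block.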
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type.
  // More types will be added to the region, but not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
                                                resultAttrs);

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(
          parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), entryArgs)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this, type.getNumInputs(), type.getNumResults(),
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

LogicalResult GPUFuncOp::verifyType() {
  Type type = getFunctionTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpaceAsInt() != memorySpace) {
      return op->emitOpError()
"expected memory space " << memorySpace << " in attribution"; 993 } 994 } 995 return success(); 996 } 997 998 /// Verifies the body of the function. 999 LogicalResult GPUFuncOp::verifyBody() { 1000 unsigned numFuncArguments = getNumArguments(); 1001 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); 1002 unsigned numBlockArguments = front().getNumArguments(); 1003 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) 1004 return emitOpError() << "expected at least " 1005 << numFuncArguments + numWorkgroupAttributions 1006 << " arguments to body region"; 1007 1008 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs(); 1009 for (unsigned i = 0; i < numFuncArguments; ++i) { 1010 Type blockArgType = front().getArgument(i).getType(); 1011 if (funcArgTypes[i] != blockArgType) 1012 return emitOpError() << "expected body region argument #" << i 1013 << " to be of type " << funcArgTypes[i] << ", got " 1014 << blockArgType; 1015 } 1016 1017 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(), 1018 GPUDialect::getWorkgroupAddressSpace())) || 1019 failed(verifyAttributions(getOperation(), getPrivateAttributions(), 1020 GPUDialect::getPrivateAddressSpace()))) 1021 return failure(); 1022 1023 return success(); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // ReturnOp 1028 //===----------------------------------------------------------------------===// 1029 1030 LogicalResult gpu::ReturnOp::verify() { 1031 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>(); 1032 1033 FunctionType funType = function.getFunctionType(); 1034 1035 if (funType.getNumResults() != operands().size()) 1036 return emitOpError() 1037 .append("expected ", funType.getNumResults(), " result operands") 1038 .attachNote(function.getLoc()) 1039 .append("return type declared here"); 1040 1041 for (const auto &pair : llvm::enumerate( 1042 llvm::zip(function.getFunctionType().getResults(), operands()))) { 1043 Type type; 1044 Value operand; 1045 std::tie(type, operand) = pair.value(); 1046 if (type != operand.getType()) 1047 return emitOpError() << "unexpected type `" << operand.getType() 1048 << "' for operand #" << pair.index(); 1049 } 1050 return success(); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // GPUModuleOp 1055 //===----------------------------------------------------------------------===// 1056 1057 void GPUModuleOp::build(OpBuilder &builder, OperationState &result, 1058 StringRef name) { 1059 ensureTerminator(*result.addRegion(), builder, result.location); 1060 result.attributes.push_back(builder.getNamedAttr( 1061 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1062 } 1063 1064 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) { 1065 StringAttr nameAttr; 1066 if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), 1067 result.attributes) || 1068 // If module attributes are present, parse them. 1069 parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1070 return failure(); 1071 1072 // Parse the module body. 1073 auto *body = result.addRegion(); 1074 if (parser.parseRegion(*body, {})) 1075 return failure(); 1076 1077 // Ensure that this module has a valid terminator. 
  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  return success();
}

void GPUModuleOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
                                     {mlir::SymbolTable::getSymbolAttrName()});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

//===----------------------------------------------------------------------===//
// GPUMemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = src().getType();
  auto dstType = dst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}

namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.dst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order to
    // perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.asyncDependencies().size() > 1 ||
        ((op.asyncDependencies().empty() && op.asyncToken()) ||
         (!op.asyncDependencies().empty() && !op.asyncToken())))
      return failure();
    rewriter.replaceOp(op, op.asyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  // Per the comment above, memrefs with no strides (e.g. 0-d memrefs) are
  // accepted as well; this also avoids calling back() on an empty vector.
  return strides.empty() || strides.back() == 1;
}

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = srcMemref().getType();
  auto resType = res().getType();
  auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = srcType.cast<MemRefType>();
  auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();

  if (!isLastMemrefDimUnitStride(srcMemrefType))
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
      srcMemSpace != kGlobalMemorySpace)
    return emitError(
        "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
        "kGlobalMemorySpace only allowed");

  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = src().getType();
  auto dstType = dstMemref().getType();
  auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
  auto dstMemrefType = dstType.cast<MemRefType>();
  auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();

  if (!isLastMemrefDimUnitStride(dstMemrefType))
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
      dstMemSpace != kGlobalMemorySpace)
    return emitError("destination memorySpace of kGenericMemorySpace, "
                     "kGlobalMemorySpace or kSharedMemorySpace only allowed");

  if (!srcMatrixType.getOperand().equals("COp"))
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(opA().getType().cast<MMAMatrixType>());
  opTypes.push_back(opB().getType().cast<MMAMatrixType>());
  opTypes.push_back(opC().getType().cast<MMAMatrixType>());

  if (!opTypes[A].getOperand().equals("AOp") ||
      !opTypes[B].getOperand().equals("BOp") ||
      !opTypes[C].getOperand().equals("COp"))
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}

/// This is a common utility used for patterns of the form
"someop(memrefcast) -> someop". It folds the source of any memref.cast 1248 /// into the root operation directly. 1249 static LogicalResult foldMemRefCast(Operation *op) { 1250 bool folded = false; 1251 for (OpOperand &operand : op->getOpOperands()) { 1252 auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>(); 1253 if (cast) { 1254 operand.set(cast.getOperand()); 1255 folded = true; 1256 } 1257 } 1258 return success(folded); 1259 } 1260 1261 LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands, 1262 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1263 return foldMemRefCast(*this); 1264 } 1265 1266 LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands, 1267 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1268 return foldMemRefCast(*this); 1269 } 1270 1271 //===----------------------------------------------------------------------===// 1272 // GPU_WaitOp 1273 //===----------------------------------------------------------------------===// 1274 1275 namespace { 1276 1277 /// Remove gpu.wait op use of gpu.wait op def without async dependencies. 1278 /// %t = gpu.wait async [] // No async dependencies. 1279 /// ... gpu.wait ... [%t, ...] // %t can be removed. 1280 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> { 1281 public: 1282 using OpRewritePattern::OpRewritePattern; 1283 1284 LogicalResult matchAndRewrite(WaitOp op, 1285 PatternRewriter &rewriter) const final { 1286 auto predicate = [](Value value) { 1287 auto waitOp = value.getDefiningOp<WaitOp>(); 1288 return waitOp && waitOp->getNumOperands() == 0; 1289 }; 1290 if (llvm::none_of(op.asyncDependencies(), predicate)) 1291 return failure(); 1292 SmallVector<Value> validOperands; 1293 for (Value operand : op->getOperands()) { 1294 if (predicate(operand)) 1295 continue; 1296 validOperands.push_back(operand); 1297 } 1298 op->setOperands(validOperands); 1299 return success(); 1300 } 1301 }; 1302 1303 /// Simplify trivial gpu.wait ops for the following patterns. 1304 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async 1305 /// dependencies). 1306 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with 1307 /// %t0. 1308 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async 1309 /// dependencies nor return any token. 1310 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> { 1311 public: 1312 using OpRewritePattern::OpRewritePattern; 1313 1314 LogicalResult matchAndRewrite(WaitOp op, 1315 PatternRewriter &rewriter) const final { 1316 // Erase gpu.wait ops that neither have any async dependencies nor return 1317 // any async token. 1318 if (op.asyncDependencies().empty() && !op.asyncToken()) { 1319 rewriter.eraseOp(op); 1320 return success(); 1321 } 1322 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op. 1323 if (llvm::hasSingleElement(op.asyncDependencies()) && op.asyncToken()) { 1324 rewriter.replaceOp(op, op.asyncDependencies()); 1325 return success(); 1326 } 1327 // Erase %t = gpu.wait async ... ops, where %t has no uses. 
    if (op.asyncToken() && op.asyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = memref().getType().cast<MemRefType>();

  if (static_cast<int64_t>(dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (symbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}

namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
/// `memref::AllocOp`.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto index = dimOp.getIndex().getDefiningOp<arith::ConstantIndexOp>();
    if (!index)
      return failure();

    auto memrefType = dimOp.getSource().getType().dyn_cast<MemRefType>();
    if (!memrefType || !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.dynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}

#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"