1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the GPU kernel-related dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/GPU/GPUDialect.h" 14 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/BuiltinOps.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/FunctionImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/IR/TypeUtilities.h" 28 #include "mlir/Transforms/InliningUtils.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 31 using namespace mlir; 32 using namespace mlir::gpu; 33 34 #include "mlir/Dialect/GPU/GPUOpsDialect.cpp.inc" 35 36 //===----------------------------------------------------------------------===// 37 // MMAMatrixType 38 //===----------------------------------------------------------------------===// 39 40 MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType, 41 StringRef operand) { 42 return Base::get(elementType.getContext(), shape, elementType, operand); 43 } 44 45 MMAMatrixType 46 MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError, 47 ArrayRef<int64_t> shape, Type elementType, 48 StringRef operand) { 49 return Base::getChecked(emitError, elementType.getContext(), shape, 50 elementType, operand); 51 } 52 53 unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; } 54 55 ArrayRef<int64_t> MMAMatrixType::getShape() const { 56 return getImpl()->getShape(); 57 } 58 59 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } 60 61 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } 62 63 bool MMAMatrixType::isValidElementType(Type elementType) { 64 return elementType.isF16() || elementType.isF32(); 65 } 66 67 LogicalResult 68 MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError, 69 ArrayRef<int64_t> shape, Type elementType, 70 StringRef operand) { 71 if (!operand.equals("AOp") && !operand.equals("BOp") && 72 !operand.equals("COp")) 73 return emitError() << "operand expected to be one of AOp, BOp or COp"; 74 75 if (shape.size() != 2) 76 return emitError() << "MMAMatrixType must have exactly two dimensions"; 77 78 if (!MMAMatrixType::isValidElementType(elementType)) 79 return emitError() << "MMAMatrixType elements must be F16 or F32"; 80 81 return success(); 82 } 83 84 //===----------------------------------------------------------------------===// 85 // GPUDialect 86 //===----------------------------------------------------------------------===// 87 88 /// GPU memory space identifiers. 89 enum GPUMemorySpace { 90 /// Generic memory space identifier. 91 kGenericMemorySpace = 0, 92 93 /// Global memory space identifier. 94 kGlobalMemorySpace = 1, 95 96 /// Shared memory space identifier. 97 kSharedMemorySpace = 3 98 }; 99 100 bool GPUDialect::isKernel(Operation *op) { 101 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()); 102 return static_cast<bool>(isKernelAttr); 103 } 104 105 namespace { 106 /// This class defines the interface for handling inlining with gpu 107 /// operations. 108 struct GPUInlinerInterface : public DialectInlinerInterface { 109 using DialectInlinerInterface::DialectInlinerInterface; 110 111 /// All gpu dialect ops can be inlined. 112 bool isLegalToInline(Operation *, Region *, bool, 113 BlockAndValueMapping &) const final { 114 return true; 115 } 116 }; 117 } // namespace 118 119 void GPUDialect::initialize() { 120 addTypes<AsyncTokenType>(); 121 addTypes<MMAMatrixType>(); 122 addOperations< 123 #define GET_OP_LIST 124 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 125 >(); 126 addAttributes< 127 #define GET_ATTRDEF_LIST 128 #include "mlir/Dialect/GPU/GPUOpsAttributes.cpp.inc" 129 >(); 130 addInterfaces<GPUInlinerInterface>(); 131 } 132 133 Type GPUDialect::parseType(DialectAsmParser &parser) const { 134 // Parse the main keyword for the type. 135 StringRef keyword; 136 if (parser.parseKeyword(&keyword)) 137 return Type(); 138 MLIRContext *context = getContext(); 139 140 // Handle 'async token' types. 141 if (keyword == "async.token") 142 return AsyncTokenType::get(context); 143 144 if (keyword == "mma_matrix") { 145 llvm::SMLoc beginLoc = parser.getNameLoc(); 146 147 // Parse '<'. 148 if (parser.parseLess()) 149 return nullptr; 150 151 // Parse the size and elementType. 152 SmallVector<int64_t> shape; 153 Type elementType; 154 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) || 155 parser.parseType(elementType)) 156 return nullptr; 157 158 // Parse ',' 159 if (parser.parseComma()) 160 return nullptr; 161 162 // Parse operand. 163 std::string operand; 164 if (failed(parser.parseOptionalString(&operand))) 165 return nullptr; 166 167 // Parse '>'. 168 if (parser.parseGreater()) 169 return nullptr; 170 171 return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn( 172 parser.getEncodedSourceLoc(beginLoc)), 173 shape, elementType, operand); 174 } 175 176 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword); 177 return Type(); 178 } 179 180 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { 181 TypeSwitch<Type>(type) 182 .Case<AsyncTokenType>([&](Type) { os << "async.token"; }) 183 .Case<MMAMatrixType>([&](MMAMatrixType fragTy) { 184 os << "mma_matrix<"; 185 auto shape = fragTy.getShape(); 186 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim) 187 os << *dim << 'x'; 188 os << shape.back() << 'x' << fragTy.getElementType(); 189 os << ", \"" << fragTy.getOperand() << "\"" << '>'; 190 }) 191 .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); }); 192 } 193 194 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, 195 NamedAttribute attr) { 196 if (!attr.getValue().isa<UnitAttr>() || 197 attr.getName() != getContainerModuleAttrName()) 198 return success(); 199 200 auto module = dyn_cast<ModuleOp>(op); 201 if (!module) 202 return op->emitError("expected '") 203 << getContainerModuleAttrName() << "' attribute to be attached to '" 204 << ModuleOp::getOperationName() << '\''; 205 206 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult { 207 // Ignore launches that are nested more or less deep than functions in the 208 // module we are currently checking. 209 if (!launchOp->getParentOp() || 210 launchOp->getParentOp()->getParentOp() != module) 211 return success(); 212 213 // Ignore launch ops with missing attributes here. The errors will be 214 // reported by the verifiers of those ops. 215 if (!launchOp->getAttrOfType<SymbolRefAttr>( 216 LaunchFuncOp::getKernelAttrName())) 217 return success(); 218 219 // Check that `launch_func` refers to a well-formed GPU kernel module. 220 StringAttr kernelModuleName = launchOp.getKernelModuleName(); 221 auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName); 222 if (!kernelModule) 223 return launchOp.emitOpError() 224 << "kernel module '" << kernelModuleName.getValue() 225 << "' is undefined"; 226 227 // Check that `launch_func` refers to a well-formed kernel function. 228 Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr()); 229 auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc); 230 auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc); 231 if (!kernelGPUFunction && !kernelLLVMFunction) 232 return launchOp.emitOpError("kernel function '") 233 << launchOp.kernel() << "' is undefined"; 234 if (!kernelFunc->getAttrOfType<mlir::UnitAttr>( 235 GPUDialect::getKernelFuncAttrName())) 236 return launchOp.emitOpError("kernel function is missing the '") 237 << GPUDialect::getKernelFuncAttrName() << "' attribute"; 238 239 // TODO: if the kernel function has been converted to 240 // the LLVM dialect but the caller hasn't (which happens during the 241 // separate compilation), do not check type correspondence as it would 242 // require the verifier to be aware of the LLVM type conversion. 243 if (kernelLLVMFunction) 244 return success(); 245 246 unsigned actualNumArguments = launchOp.getNumKernelOperands(); 247 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments(); 248 if (expectedNumArguments != actualNumArguments) 249 return launchOp.emitOpError("got ") 250 << actualNumArguments << " kernel operands but expected " 251 << expectedNumArguments; 252 253 auto functionType = kernelGPUFunction.getType(); 254 for (unsigned i = 0; i < expectedNumArguments; ++i) { 255 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) { 256 return launchOp.emitOpError("type of function argument ") 257 << i << " does not match"; 258 } 259 } 260 261 return success(); 262 }); 263 264 return walkResult.wasInterrupted() ? failure() : success(); 265 } 266 267 template <typename T> 268 static LogicalResult verifyIndexOp(T op) { 269 auto dimension = op.dimension(); 270 if (dimension != "x" && dimension != "y" && dimension != "z") 271 return op.emitError("dimension \"") << dimension << "\" is invalid"; 272 return success(); 273 } 274 275 static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { 276 if (allReduce.body().empty() != allReduce.op().hasValue()) 277 return allReduce.emitError( 278 "expected either an op attribute or a non-empty body"); 279 if (!allReduce.body().empty()) { 280 if (allReduce.body().getNumArguments() != 2) 281 return allReduce.emitError("expected two region arguments"); 282 for (auto argument : allReduce.body().getArguments()) { 283 if (argument.getType() != allReduce.getType()) 284 return allReduce.emitError("incorrect region argument type"); 285 } 286 unsigned yieldCount = 0; 287 for (Block &block : allReduce.body()) { 288 if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) { 289 if (yield.getNumOperands() != 1) 290 return allReduce.emitError("expected one gpu.yield operand"); 291 if (yield.getOperand(0).getType() != allReduce.getType()) 292 return allReduce.emitError("incorrect gpu.yield type"); 293 ++yieldCount; 294 } 295 } 296 if (yieldCount == 0) 297 return allReduce.emitError("expected gpu.yield op in region"); 298 } else { 299 gpu::AllReduceOperation opName = *allReduce.op(); 300 if ((opName == gpu::AllReduceOperation::AND || 301 opName == gpu::AllReduceOperation::OR || 302 opName == gpu::AllReduceOperation::XOR) && 303 !allReduce.getType().isa<IntegerType>()) { 304 return allReduce.emitError() 305 << '`' << gpu::stringifyAllReduceOperation(opName) << '`' 306 << " accumulator is only compatible with Integer type"; 307 } 308 } 309 return success(); 310 } 311 312 // TODO: Support optional custom attributes (without dialect prefix). 313 static ParseResult parseAllReduceOperation(AsmParser &parser, 314 AllReduceOperationAttr &attr) { 315 StringRef enumStr; 316 if (!parser.parseOptionalKeyword(&enumStr)) { 317 Optional<AllReduceOperation> op = gpu::symbolizeAllReduceOperation(enumStr); 318 if (!op) 319 return parser.emitError(parser.getCurrentLocation(), "invalid op kind"); 320 attr = AllReduceOperationAttr::get(parser.getContext(), *op); 321 } 322 return success(); 323 } 324 325 static void printAllReduceOperation(AsmPrinter &printer, Operation *op, 326 AllReduceOperationAttr attr) { 327 if (attr) 328 attr.print(printer); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // AsyncOpInterface 333 //===----------------------------------------------------------------------===// 334 335 void gpu::addAsyncDependency(Operation *op, Value token) { 336 op->insertOperands(0, {token}); 337 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>()) 338 return; 339 auto attrName = 340 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(); 341 auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName); 342 343 // Async dependencies is the only variadic operand. 344 if (!sizeAttr) 345 return; 346 347 SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>()); 348 ++sizes.front(); 349 op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes)); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // LaunchOp 354 //===----------------------------------------------------------------------===// 355 356 void LaunchOp::build(OpBuilder &builder, OperationState &result, 357 Value gridSizeX, Value gridSizeY, Value gridSizeZ, 358 Value blockSizeX, Value blockSizeY, Value blockSizeZ, 359 Value dynamicSharedMemorySize) { 360 // Add grid and block sizes as op operands, followed by the data operands. 361 result.addOperands( 362 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 363 if (dynamicSharedMemorySize) 364 result.addOperands(dynamicSharedMemorySize); 365 366 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 367 // where the first kNumConfigRegionAttributes arguments have `index` type and 368 // the rest have the same types as the data operands. 369 Region *kernelRegion = result.addRegion(); 370 Block *body = new Block(); 371 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i) 372 body->addArgument(builder.getIndexType(), result.location); 373 kernelRegion->push_back(body); 374 } 375 376 KernelDim3 LaunchOp::getBlockIds() { 377 assert(!body().empty() && "LaunchOp body must not be empty."); 378 auto args = body().getArguments(); 379 return KernelDim3{args[0], args[1], args[2]}; 380 } 381 382 KernelDim3 LaunchOp::getThreadIds() { 383 assert(!body().empty() && "LaunchOp body must not be empty."); 384 auto args = body().getArguments(); 385 return KernelDim3{args[3], args[4], args[5]}; 386 } 387 388 KernelDim3 LaunchOp::getGridSize() { 389 assert(!body().empty() && "LaunchOp body must not be empty."); 390 auto args = body().getArguments(); 391 return KernelDim3{args[6], args[7], args[8]}; 392 } 393 394 KernelDim3 LaunchOp::getBlockSize() { 395 assert(!body().empty() && "LaunchOp body must not be empty."); 396 auto args = body().getArguments(); 397 return KernelDim3{args[9], args[10], args[11]}; 398 } 399 400 KernelDim3 LaunchOp::getGridSizeOperandValues() { 401 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 402 } 403 404 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 405 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 406 } 407 408 static LogicalResult verify(LaunchOp op) { 409 // Kernel launch takes kNumConfigOperands leading operands for grid/block 410 // sizes and transforms them into kNumConfigRegionAttributes region arguments 411 // for block/thread identifiers and grid/block sizes. 412 if (!op.body().empty()) { 413 if (op.body().getNumArguments() != 414 LaunchOp::kNumConfigOperands + op.getNumOperands() - 415 (op.dynamicSharedMemorySize() ? 1 : 0)) 416 return op.emitOpError("unexpected number of region arguments"); 417 } 418 419 // Block terminators without successors are expected to exit the kernel region 420 // and must be `gpu.terminator`. 421 for (Block &block : op.body()) { 422 if (block.empty()) 423 continue; 424 if (block.back().getNumSuccessors() != 0) 425 continue; 426 if (!isa<gpu::TerminatorOp>(&block.back())) { 427 return block.back() 428 .emitError() 429 .append("expected '", gpu::TerminatorOp::getOperationName(), 430 "' or a terminator with successors") 431 .attachNote(op.getLoc()) 432 .append("in '", LaunchOp::getOperationName(), "' body region"); 433 } 434 } 435 436 return success(); 437 } 438 439 // Pretty-print the kernel grid/block size assignment as 440 // (%iter-x, %iter-y, %iter-z) in 441 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 442 // where %size-* and %iter-* will correspond to the body region arguments. 443 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, 444 KernelDim3 operands, KernelDim3 ids) { 445 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in ("; 446 p << size.x << " = " << operands.x << ", "; 447 p << size.y << " = " << operands.y << ", "; 448 p << size.z << " = " << operands.z << ')'; 449 } 450 451 static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { 452 // Print the launch configuration. 453 p << ' ' << op.getBlocksKeyword(); 454 printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(), 455 op.getBlockIds()); 456 p << ' ' << op.getThreadsKeyword(); 457 printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(), 458 op.getThreadIds()); 459 if (op.dynamicSharedMemorySize()) 460 p << ' ' << op.getDynamicSharedMemorySizeKeyword() << ' ' 461 << op.dynamicSharedMemorySize(); 462 463 p << ' '; 464 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 465 p.printOptionalAttrDict(op->getAttrs()); 466 } 467 468 // Parse the size assignment blocks for blocks and threads. These have the form 469 // (%region_arg, %region_arg, %region_arg) in 470 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 471 // where %region_arg are percent-identifiers for the region arguments to be 472 // introduced further (SSA defs), and %operand are percent-identifiers for the 473 // SSA value uses. 474 static ParseResult 475 parseSizeAssignment(OpAsmParser &parser, 476 MutableArrayRef<OpAsmParser::OperandType> sizes, 477 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 478 MutableArrayRef<OpAsmParser::OperandType> indices) { 479 assert(indices.size() == 3 && "space for three indices expected"); 480 SmallVector<OpAsmParser::OperandType, 3> args; 481 if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3, 482 OpAsmParser::Delimiter::Paren) || 483 parser.parseKeyword("in") || parser.parseLParen()) 484 return failure(); 485 std::move(args.begin(), args.end(), indices.begin()); 486 487 for (int i = 0; i < 3; ++i) { 488 if (i != 0 && parser.parseComma()) 489 return failure(); 490 if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() || 491 parser.parseOperand(sizes[i])) 492 return failure(); 493 } 494 495 return parser.parseRParen(); 496 } 497 498 // Parses a Launch operation. 499 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 500 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 501 // region attr-dict? 502 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 503 static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) { 504 // Sizes of the grid and block. 505 SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes( 506 LaunchOp::kNumConfigOperands); 507 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 508 509 // Actual (data) operands passed to the kernel. 510 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 511 512 // Region arguments to be created. 513 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 514 LaunchOp::kNumConfigRegionAttributes); 515 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 516 517 // Parse the size assignment segments: the first segment assigns grid sizes 518 // and defines values for block identifiers; the second segment assigns block 519 // sizes and defines values for thread identifiers. In the region argument 520 // list, identifiers precede sizes, and block-related values precede 521 // thread-related values. 522 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) || 523 parseSizeAssignment(parser, sizesRef.take_front(3), 524 regionArgsRef.slice(6, 3), 525 regionArgsRef.slice(0, 3)) || 526 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) || 527 parseSizeAssignment(parser, sizesRef.drop_front(3), 528 regionArgsRef.slice(9, 3), 529 regionArgsRef.slice(3, 3)) || 530 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(), 531 result.operands)) 532 return failure(); 533 534 OpAsmParser::OperandType dynamicSharedMemorySize; 535 if (!parser.parseOptionalKeyword( 536 LaunchOp::getDynamicSharedMemorySizeKeyword())) 537 if (parser.parseOperand(dynamicSharedMemorySize) || 538 parser.resolveOperand(dynamicSharedMemorySize, 539 parser.getBuilder().getI32Type(), 540 result.operands)) 541 return failure(); 542 543 // Introduce the body region and parse it. The region has 544 // kNumConfigRegionAttributes arguments that correspond to 545 // block/thread identifiers and grid/block sizes, all of the `index` type. 546 Type index = parser.getBuilder().getIndexType(); 547 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes( 548 LaunchOp::kNumConfigRegionAttributes, index); 549 Region *body = result.addRegion(); 550 return failure(parser.parseRegion(*body, regionArgs, dataTypes) || 551 parser.parseOptionalAttrDict(result.attributes)); 552 } 553 554 /// Simplify the gpu.launch when the range of a thread or block ID is 555 /// trivially known to be one. 556 struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> { 557 using OpRewritePattern<LaunchOp>::OpRewritePattern; 558 LogicalResult matchAndRewrite(LaunchOp op, 559 PatternRewriter &rewriter) const override { 560 // If the range implies a single value for `id`, replace `id`'s uses by 561 // zero. 562 Value zero; 563 bool simplified = false; 564 auto constPropIdUses = [&](Value id, Value size) { 565 // Check if size is trivially one. 566 if (!matchPattern(size, m_One())) 567 return; 568 if (!simplified) { 569 // Create a zero value the first time. 570 OpBuilder::InsertionGuard guard(rewriter); 571 rewriter.setInsertionPointToStart(&op.body().front()); 572 zero = 573 rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0); 574 } 575 id.replaceAllUsesWith(zero); 576 simplified = true; 577 }; 578 constPropIdUses(op.getBlockIds().x, op.gridSizeX()); 579 constPropIdUses(op.getBlockIds().y, op.gridSizeY()); 580 constPropIdUses(op.getBlockIds().z, op.gridSizeZ()); 581 constPropIdUses(op.getThreadIds().x, op.blockSizeX()); 582 constPropIdUses(op.getThreadIds().y, op.blockSizeY()); 583 constPropIdUses(op.getThreadIds().z, op.blockSizeZ()); 584 585 return success(simplified); 586 } 587 }; 588 589 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites, 590 MLIRContext *context) { 591 rewrites.add<FoldLaunchArguments>(context); 592 } 593 594 //===----------------------------------------------------------------------===// 595 // LaunchFuncOp 596 //===----------------------------------------------------------------------===// 597 598 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, 599 GPUFuncOp kernelFunc, KernelDim3 gridSize, 600 KernelDim3 blockSize, Value dynamicSharedMemorySize, 601 ValueRange kernelOperands) { 602 // Add grid and block sizes as op operands, followed by the data operands. 603 result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x, 604 blockSize.y, blockSize.z}); 605 if (dynamicSharedMemorySize) 606 result.addOperands(dynamicSharedMemorySize); 607 result.addOperands(kernelOperands); 608 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>(); 609 auto kernelSymbol = 610 SymbolRefAttr::get(kernelModule.getNameAttr(), 611 {SymbolRefAttr::get(kernelFunc.getNameAttr())}); 612 result.addAttribute(getKernelAttrName(), kernelSymbol); 613 SmallVector<int32_t, 9> segmentSizes(9, 1); 614 segmentSizes.front() = 0; // Initially no async dependencies. 615 segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0; 616 segmentSizes.back() = static_cast<int32_t>(kernelOperands.size()); 617 result.addAttribute(getOperandSegmentSizeAttr(), 618 builder.getI32VectorAttr(segmentSizes)); 619 } 620 621 unsigned LaunchFuncOp::getNumKernelOperands() { 622 return getNumOperands() - asyncDependencies().size() - kNumConfigOperands - 623 (dynamicSharedMemorySize() ? 1 : 0); 624 } 625 626 StringAttr LaunchFuncOp::getKernelModuleName() { 627 return kernel().getRootReference(); 628 } 629 630 StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); } 631 632 Value LaunchFuncOp::getKernelOperand(unsigned i) { 633 return getOperand(asyncDependencies().size() + kNumConfigOperands + 634 (dynamicSharedMemorySize() ? 1 : 0) + i); 635 } 636 637 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 638 auto operands = getOperands().drop_front(asyncDependencies().size()); 639 return KernelDim3{operands[0], operands[1], operands[2]}; 640 } 641 642 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 643 auto operands = getOperands().drop_front(asyncDependencies().size()); 644 return KernelDim3{operands[3], operands[4], operands[5]}; 645 } 646 647 static LogicalResult verify(LaunchFuncOp op) { 648 auto module = op->getParentOfType<ModuleOp>(); 649 if (!module) 650 return op.emitOpError("expected to belong to a module"); 651 652 if (!module->getAttrOfType<UnitAttr>( 653 GPUDialect::getContainerModuleAttrName())) 654 return op.emitOpError( 655 "expected the closest surrounding module to have the '" + 656 GPUDialect::getContainerModuleAttrName() + "' attribute"); 657 658 auto kernelAttr = op->getAttrOfType<SymbolRefAttr>(op.getKernelAttrName()); 659 if (!kernelAttr) 660 return op.emitOpError("symbol reference attribute '" + 661 op.getKernelAttrName() + "' must be specified"); 662 663 return success(); 664 } 665 666 static ParseResult 667 parseLaunchFuncOperands(OpAsmParser &parser, 668 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 669 SmallVectorImpl<Type> &argTypes) { 670 if (parser.parseOptionalKeyword("args")) 671 return success(); 672 SmallVector<NamedAttrList> argAttrs; 673 SmallVector<Location> argLocations; 674 bool isVariadic = false; 675 return function_interface_impl::parseFunctionArgumentList( 676 parser, /*allowAttributes=*/false, 677 /*allowVariadic=*/false, argNames, argTypes, argAttrs, argLocations, 678 isVariadic); 679 } 680 681 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, 682 OperandRange operands, TypeRange types) { 683 if (operands.empty()) 684 return; 685 printer << "args("; 686 llvm::interleaveComma(llvm::zip(operands, types), printer, 687 [&](const auto &pair) { 688 printer.printOperand(std::get<0>(pair)); 689 printer << " : "; 690 printer.printType(std::get<1>(pair)); 691 }); 692 printer << ")"; 693 } 694 695 //===----------------------------------------------------------------------===// 696 // GPUFuncOp 697 //===----------------------------------------------------------------------===// 698 699 /// Adds a new block argument that corresponds to buffers located in 700 /// workgroup memory. 701 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) { 702 auto attrName = getNumWorkgroupAttributionsAttrName(); 703 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName); 704 (*this)->setAttr(attrName, 705 IntegerAttr::get(attr.getType(), attr.getValue() + 1)); 706 return getBody().insertArgument(getType().getNumInputs() + attr.getInt(), 707 type, loc); 708 } 709 710 /// Adds a new block argument that corresponds to buffers located in 711 /// private memory. 712 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) { 713 // Buffers on the private memory always come after buffers on the workgroup 714 // memory. 715 return getBody().addArgument(type, loc); 716 } 717 718 void GPUFuncOp::build(OpBuilder &builder, OperationState &result, 719 StringRef name, FunctionType type, 720 TypeRange workgroupAttributions, 721 TypeRange privateAttributions, 722 ArrayRef<NamedAttribute> attrs) { 723 result.addAttribute(SymbolTable::getSymbolAttrName(), 724 builder.getStringAttr(name)); 725 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 726 result.addAttribute(getNumWorkgroupAttributionsAttrName(), 727 builder.getI64IntegerAttr(workgroupAttributions.size())); 728 result.addAttributes(attrs); 729 Region *body = result.addRegion(); 730 Block *entryBlock = new Block; 731 732 // TODO: Allow passing in proper locations here. 733 for (Type argTy : type.getInputs()) 734 entryBlock->addArgument(argTy, result.location); 735 for (Type argTy : workgroupAttributions) 736 entryBlock->addArgument(argTy, result.location); 737 for (Type argTy : privateAttributions) 738 entryBlock->addArgument(argTy, result.location); 739 740 body->getBlocks().push_back(entryBlock); 741 } 742 743 /// Parses a GPU function memory attribution. 744 /// 745 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? 746 /// (`private` `(` ssa-id-and-type-list `)`)? 747 /// 748 /// Note that this function parses only one of the two similar parts, with the 749 /// keyword provided as argument. 750 static ParseResult 751 parseAttributions(OpAsmParser &parser, StringRef keyword, 752 SmallVectorImpl<OpAsmParser::OperandType> &args, 753 SmallVectorImpl<Type> &argTypes) { 754 // If we could not parse the keyword, just assume empty list and succeed. 755 if (failed(parser.parseOptionalKeyword(keyword))) 756 return success(); 757 758 if (failed(parser.parseLParen())) 759 return failure(); 760 761 // Early exit for an empty list. 762 if (succeeded(parser.parseOptionalRParen())) 763 return success(); 764 765 do { 766 OpAsmParser::OperandType arg; 767 Type type; 768 769 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 770 return failure(); 771 772 args.push_back(arg); 773 argTypes.push_back(type); 774 } while (succeeded(parser.parseOptionalComma())); 775 776 return parser.parseRParen(); 777 } 778 779 /// Parses a GPU function. 780 /// 781 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)` 782 /// (`->` function-result-list)? memory-attribution `kernel`? 783 /// function-attributes? region 784 static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { 785 SmallVector<OpAsmParser::OperandType> entryArgs; 786 SmallVector<NamedAttrList> argAttrs; 787 SmallVector<NamedAttrList> resultAttrs; 788 SmallVector<Type> argTypes; 789 SmallVector<Type> resultTypes; 790 SmallVector<Location> argLocations; 791 bool isVariadic; 792 793 // Parse the function name. 794 StringAttr nameAttr; 795 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 796 result.attributes)) 797 return failure(); 798 799 auto signatureLocation = parser.getCurrentLocation(); 800 if (failed(function_interface_impl::parseFunctionSignature( 801 parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, 802 argLocations, isVariadic, resultTypes, resultAttrs))) 803 return failure(); 804 805 if (entryArgs.empty() && !argTypes.empty()) 806 return parser.emitError(signatureLocation) 807 << "gpu.func requires named arguments"; 808 809 // Construct the function type. More types will be added to the region, but 810 // not to the function type. 811 Builder &builder = parser.getBuilder(); 812 auto type = builder.getFunctionType(argTypes, resultTypes); 813 result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); 814 815 // Parse workgroup memory attributions. 816 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), 817 entryArgs, argTypes))) 818 return failure(); 819 820 // Store the number of operands we just parsed as the number of workgroup 821 // memory attributions. 822 unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs(); 823 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(), 824 builder.getI64IntegerAttr(numWorkgroupAttrs)); 825 826 // Parse private memory attributions. 827 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), 828 entryArgs, argTypes))) 829 return failure(); 830 831 // Parse the kernel attribute if present. 832 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword()))) 833 result.addAttribute(GPUDialect::getKernelFuncAttrName(), 834 builder.getUnitAttr()); 835 836 // Parse attributes. 837 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 838 return failure(); 839 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 840 resultAttrs); 841 842 // Parse the region. If no argument names were provided, take all names 843 // (including those of attributions) from the entry block. 844 auto *body = result.addRegion(); 845 return parser.parseRegion(*body, entryArgs, argTypes); 846 } 847 848 static void printAttributions(OpAsmPrinter &p, StringRef keyword, 849 ArrayRef<BlockArgument> values) { 850 if (values.empty()) 851 return; 852 853 p << ' ' << keyword << '('; 854 llvm::interleaveComma( 855 values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); }); 856 p << ')'; 857 } 858 859 /// Prints a GPU Func op. 860 static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { 861 p << ' '; 862 p.printSymbolName(op.getName()); 863 864 FunctionType type = op.getType(); 865 function_interface_impl::printFunctionSignature( 866 p, op.getOperation(), type.getInputs(), 867 /*isVariadic=*/false, type.getResults()); 868 869 printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); 870 printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); 871 if (op.isKernel()) 872 p << ' ' << op.getKernelKeyword(); 873 874 function_interface_impl::printFunctionAttributes( 875 p, op.getOperation(), type.getNumInputs(), type.getNumResults(), 876 {op.getNumWorkgroupAttributionsAttrName(), 877 GPUDialect::getKernelFuncAttrName()}); 878 p << ' '; 879 p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); 880 } 881 882 LogicalResult GPUFuncOp::verifyType() { 883 Type type = getTypeAttr().getValue(); 884 if (!type.isa<FunctionType>()) 885 return emitOpError("requires '" + getTypeAttrName() + 886 "' attribute of function type"); 887 888 if (isKernel() && getType().getNumResults() != 0) 889 return emitOpError() << "expected void return type for kernel function"; 890 891 return success(); 892 } 893 894 static LogicalResult verifyAttributions(Operation *op, 895 ArrayRef<BlockArgument> attributions, 896 unsigned memorySpace) { 897 for (Value v : attributions) { 898 auto type = v.getType().dyn_cast<MemRefType>(); 899 if (!type) 900 return op->emitOpError() << "expected memref type in attribution"; 901 902 if (type.getMemorySpaceAsInt() != memorySpace) { 903 return op->emitOpError() 904 << "expected memory space " << memorySpace << " in attribution"; 905 } 906 } 907 return success(); 908 } 909 910 /// Verifies the body of the function. 911 LogicalResult GPUFuncOp::verifyBody() { 912 unsigned numFuncArguments = getNumArguments(); 913 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); 914 unsigned numBlockArguments = front().getNumArguments(); 915 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) 916 return emitOpError() << "expected at least " 917 << numFuncArguments + numWorkgroupAttributions 918 << " arguments to body region"; 919 920 ArrayRef<Type> funcArgTypes = getType().getInputs(); 921 for (unsigned i = 0; i < numFuncArguments; ++i) { 922 Type blockArgType = front().getArgument(i).getType(); 923 if (funcArgTypes[i] != blockArgType) 924 return emitOpError() << "expected body region argument #" << i 925 << " to be of type " << funcArgTypes[i] << ", got " 926 << blockArgType; 927 } 928 929 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(), 930 GPUDialect::getWorkgroupAddressSpace())) || 931 failed(verifyAttributions(getOperation(), getPrivateAttributions(), 932 GPUDialect::getPrivateAddressSpace()))) 933 return failure(); 934 935 return success(); 936 } 937 938 //===----------------------------------------------------------------------===// 939 // ReturnOp 940 //===----------------------------------------------------------------------===// 941 942 static LogicalResult verify(gpu::ReturnOp returnOp) { 943 GPUFuncOp function = returnOp->getParentOfType<GPUFuncOp>(); 944 945 FunctionType funType = function.getType(); 946 947 if (funType.getNumResults() != returnOp.operands().size()) 948 return returnOp.emitOpError() 949 .append("expected ", funType.getNumResults(), " result operands") 950 .attachNote(function.getLoc()) 951 .append("return type declared here"); 952 953 for (const auto &pair : llvm::enumerate( 954 llvm::zip(function.getType().getResults(), returnOp.operands()))) { 955 Type type; 956 Value operand; 957 std::tie(type, operand) = pair.value(); 958 if (type != operand.getType()) 959 return returnOp.emitOpError() << "unexpected type `" << operand.getType() 960 << "' for operand #" << pair.index(); 961 } 962 return success(); 963 } 964 965 //===----------------------------------------------------------------------===// 966 // GPUModuleOp 967 //===----------------------------------------------------------------------===// 968 969 void GPUModuleOp::build(OpBuilder &builder, OperationState &result, 970 StringRef name) { 971 ensureTerminator(*result.addRegion(), builder, result.location); 972 result.attributes.push_back(builder.getNamedAttr( 973 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 974 } 975 976 static ParseResult parseGPUModuleOp(OpAsmParser &parser, 977 OperationState &result) { 978 StringAttr nameAttr; 979 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 980 result.attributes)) 981 return failure(); 982 983 // If module attributes are present, parse them. 984 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 985 return failure(); 986 987 // Parse the module body. 988 auto *body = result.addRegion(); 989 if (parser.parseRegion(*body, None, None)) 990 return failure(); 991 992 // Ensure that this module has a valid terminator. 993 GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); 994 return success(); 995 } 996 997 static void print(OpAsmPrinter &p, GPUModuleOp op) { 998 p << ' '; 999 p.printSymbolName(op.getName()); 1000 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 1001 {SymbolTable::getSymbolAttrName()}); 1002 p << ' '; 1003 p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, 1004 /*printBlockTerminators=*/false); 1005 } 1006 1007 //===----------------------------------------------------------------------===// 1008 // GPUMemcpyOp 1009 //===----------------------------------------------------------------------===// 1010 1011 static LogicalResult verify(MemcpyOp op) { 1012 auto srcType = op.src().getType(); 1013 auto dstType = op.dst().getType(); 1014 1015 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType)) 1016 return op.emitOpError("arguments have incompatible element type"); 1017 1018 if (failed(verifyCompatibleShape(srcType, dstType))) 1019 return op.emitOpError("arguments have incompatible shape"); 1020 1021 return success(); 1022 } 1023 1024 static ParseResult parseAsyncDependencies( 1025 OpAsmParser &parser, Type &asyncTokenType, 1026 SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) { 1027 auto loc = parser.getCurrentLocation(); 1028 if (succeeded(parser.parseOptionalKeyword("async"))) { 1029 if (parser.getNumResults() == 0) 1030 return parser.emitError(loc, "needs to be named when marked 'async'"); 1031 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>(); 1032 } 1033 return parser.parseOperandList(asyncDependencies, 1034 OpAsmParser::Delimiter::OptionalSquare); 1035 } 1036 1037 static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, 1038 Type asyncTokenType, 1039 OperandRange asyncDependencies) { 1040 if (asyncTokenType) 1041 printer << "async "; 1042 if (asyncDependencies.empty()) 1043 return; 1044 printer << "["; 1045 llvm::interleaveComma(asyncDependencies, printer); 1046 printer << "]"; 1047 } 1048 1049 //===----------------------------------------------------------------------===// 1050 // GPU_SubgroupMmaLoadMatrixOp 1051 //===----------------------------------------------------------------------===// 1052 1053 static LogicalResult verify(SubgroupMmaLoadMatrixOp op) { 1054 auto srcType = op.srcMemref().getType(); 1055 auto resType = op.res().getType(); 1056 auto resMatrixType = resType.cast<gpu::MMAMatrixType>(); 1057 auto operand = resMatrixType.getOperand(); 1058 auto srcMemrefType = srcType.cast<MemRefType>(); 1059 auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt(); 1060 1061 if (!srcMemrefType.getLayout().isIdentity()) 1062 return op.emitError("expected identity layout map for source memref"); 1063 1064 if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && 1065 srcMemSpace != kGlobalMemorySpace) 1066 return op.emitError( 1067 "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " 1068 "kGlobalMemorySpace only allowed"); 1069 1070 if (!operand.equals("AOp") && !operand.equals("BOp") && 1071 !operand.equals("COp")) 1072 return op.emitError("only AOp, BOp and COp can be loaded"); 1073 1074 return success(); 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // GPU_SubgroupMmaStoreMatrixOp 1079 //===----------------------------------------------------------------------===// 1080 1081 static LogicalResult verify(SubgroupMmaStoreMatrixOp op) { 1082 auto srcType = op.src().getType(); 1083 auto dstType = op.dstMemref().getType(); 1084 auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>(); 1085 auto dstMemrefType = dstType.cast<MemRefType>(); 1086 auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); 1087 if (!dstMemrefType.getLayout().isIdentity()) 1088 return op.emitError("expected identity layout map for destination memref"); 1089 1090 if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && 1091 dstMemSpace != kGlobalMemorySpace) 1092 return op.emitError( 1093 "destination memorySpace of kGenericMemorySpace, " 1094 "kGlobalMemorySpace or kSharedMemorySpace only allowed"); 1095 1096 if (!srcMatrixType.getOperand().equals("COp")) 1097 return op.emitError( 1098 "expected the operand matrix being stored to have 'COp' operand type"); 1099 1100 return success(); 1101 } 1102 1103 //===----------------------------------------------------------------------===// 1104 // GPU_SubgroupMmaComputeOp 1105 //===----------------------------------------------------------------------===// 1106 1107 static LogicalResult verify(SubgroupMmaComputeOp op) { 1108 enum OperandMap { A, B, C }; 1109 SmallVector<MMAMatrixType, 3> opTypes; 1110 1111 auto populateOpInfo = [&opTypes, &op]() { 1112 opTypes.push_back(op.opA().getType().cast<MMAMatrixType>()); 1113 opTypes.push_back(op.opB().getType().cast<MMAMatrixType>()); 1114 opTypes.push_back(op.opC().getType().cast<MMAMatrixType>()); 1115 }; 1116 populateOpInfo(); 1117 1118 if (!opTypes[A].getOperand().equals("AOp") || 1119 !opTypes[B].getOperand().equals("BOp") || 1120 !opTypes[C].getOperand().equals("COp")) 1121 return op.emitError("operands must be in the order AOp, BOp, COp"); 1122 1123 ArrayRef<int64_t> aShape, bShape, cShape; 1124 aShape = opTypes[A].getShape(); 1125 bShape = opTypes[B].getShape(); 1126 cShape = opTypes[C].getShape(); 1127 1128 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || 1129 bShape[1] != cShape[1]) 1130 return op.emitError("operand shapes do not satisfy matmul constraints"); 1131 1132 return success(); 1133 } 1134 1135 /// This is a common class used for patterns of the form 1136 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 1137 /// into the root operation directly. 1138 static LogicalResult foldMemRefCast(Operation *op) { 1139 bool folded = false; 1140 for (OpOperand &operand : op->getOpOperands()) { 1141 auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>(); 1142 if (cast) { 1143 operand.set(cast.getOperand()); 1144 folded = true; 1145 } 1146 } 1147 return success(folded); 1148 } 1149 1150 LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands, 1151 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1152 return foldMemRefCast(*this); 1153 } 1154 1155 LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands, 1156 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1157 return foldMemRefCast(*this); 1158 } 1159 1160 //===----------------------------------------------------------------------===// 1161 // GPU_AllocOp 1162 //===----------------------------------------------------------------------===// 1163 namespace { 1164 1165 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to 1166 /// `memref::AllocOp`. 1167 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> { 1168 using OpRewritePattern<memref::DimOp>::OpRewritePattern; 1169 1170 LogicalResult matchAndRewrite(memref::DimOp dimOp, 1171 PatternRewriter &rewriter) const override { 1172 auto index = dimOp.index().getDefiningOp<arith::ConstantIndexOp>(); 1173 if (!index) 1174 return failure(); 1175 1176 auto memrefType = dimOp.source().getType().dyn_cast<MemRefType>(); 1177 if (!memrefType || !memrefType.isDynamicDim(index.value())) 1178 return failure(); 1179 1180 auto alloc = dimOp.source().getDefiningOp<AllocOp>(); 1181 if (!alloc) 1182 return failure(); 1183 1184 Value substituteOp = *(alloc.dynamicSizes().begin() + 1185 memrefType.getDynamicDimIndex(index.value())); 1186 rewriter.replaceOp(dimOp, substituteOp); 1187 return success(); 1188 } 1189 }; 1190 1191 } // namespace 1192 1193 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 1194 MLIRContext *context) { 1195 results.add<SimplifyDimOfAllocOp>(context); 1196 } 1197 1198 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc" 1199 #include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc" 1200 1201 #define GET_ATTRDEF_CLASSES 1202 #include "mlir/Dialect/GPU/GPUOpsAttributes.cpp.inc" 1203 1204 #define GET_OP_CLASSES 1205 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 1206