//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

StringRef GPUDialect::getDialectName() { return "gpu"; }

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
      >();
}

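/// Verifies the module-level `gpu.container_module` attribute: every
/// `gpu.launch_func` nested directly below a function of the annotated module
/// must reference a GPU kernel module and, within it, a kernel function that
/// takes as many operands as the launch provides.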
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!attr.second.isa<UnitAttr>() ||
      !attr.first.is(getContainerModuleAttrName()))
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested at a different depth than functions in
    // the module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<StringAttr>(
            LaunchFuncOp::getKernelAttrName()) ||
        !launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelModuleAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";
    if (!kernelModule.getAttrOfType<UnitAttr>(
            GPUDialect::getKernelModuleAttrName()))
      return launchOp.emitOpError("module '")
             << kernelModuleName << "' is missing the '"
             << GPUDialect::getKernelModuleAttrName() << "' attribute";

    // Check that `launch_func` refers to a well-formed kernel function.
    StringRef kernelName = launchOp.kernel();
    Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelGPUFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << kernelName << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelLLVMFunction
                                        ? kernelLLVMFunction.getNumArguments()
                                        : kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Due to the ordering of the current implementation of lowering and
    // LLVMLowering, type checks need to be temporarily disabled.
    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
    // to encode target module" has landed.
    // auto functionType = kernelFunc.getType();
    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    //   if (getKernelOperand(i).getType() != functionType.getInput(i)) {
    //     return emitOpError("type of function argument ")
    //            << i << " does not match";
    //   }
    // }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

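/// Verifies an index operation (e.g. the ops behind `gpu.block_id` or
/// `gpu.thread_id`): the `dimension` attribute must be one of "x", "y" or "z".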
template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}

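// Verifies a `gpu.all_reduce` op: the reduction is given either by the `op`
// attribute or by a non-empty region, never both. As an illustrative sketch
// (generic op syntax; the exact form may vary), an attribute-based reduction
// could look like:
//   %sum = "gpu.all_reduce"(%val) ({}) {op = "add"} : (f32) -> f32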
static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    if (allReduce.body().front().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto argument : allReduce.body().front().getArguments()) {
      if (argument.getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  }
  return success();
}

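// Verifies a `gpu.shuffle` op: the shuffled value and the result must share
// the same 32-bit integer or float type. In the custom form produced by the
// printer below this reads, e.g. (names are illustrative):
//   %res, %valid = gpu.shuffle %val, %offset, %width xor : f32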
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
  auto type = shuffleOp.value().getType();
  if (shuffleOp.result().getType() != type) {
    return shuffleOp.emitOpError()
           << "requires the same type for value operand and result";
  }
  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
    return shuffleOp.emitOpError()
           << "requires value operand type to be f32 or i32";
  }
  return success();
}

static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
  p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
    << op.mode() << " : " << op.value().getType();
}

static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  StringRef mode;
  if (parser.parseKeyword(&mode))
    return failure();
  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));

  Type valueType;
  Type int32Type = parser.getBuilder().getIntegerType(32);
  Type int1Type = parser.getBuilder().getI1Type();
  if (parser.parseColonType(valueType) ||
      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
                             parser.getCurrentLocation(), state.operands) ||
      parser.addTypesToList({valueType, int1Type}, state.types))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
                     Value gridSizeY, Value gridSizeZ, Value blockSizeX,
                     Value blockSizeY, Value blockSizeZ, ValueRange operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(llvm::to_vector<4>(operands.getTypes()));
  kernelRegion->push_back(body);
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!body().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = body().getBlocks().front().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
  return llvm::drop_begin(getOperands(), kNumConfigOperands);
}

LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
  auto args = body().getBlocks().front().getArguments();
  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}

static LogicalResult verify(LaunchOp op) {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!op.body().empty()) {
    Block &entryBlock = op.body().front();
    if (entryBlock.getNumArguments() !=
        LaunchOp::kNumConfigOperands + op.getNumOperands())
      return op.emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : op.body()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
                 .emitError("expected 'gpu.terminator' or a terminator with "
                            "successors")
                 .attachNote(op.getLoc())
             << "in '" << LaunchOp::getOperationName() << "' body region";
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ValueRange operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands[0] << ", ";
  p << size.y << " = " << operands[1] << ", ";
  p << size.z << " = " << operands[2] << ')';
}

static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
  ValueRange operands = op.getOperands();

  // Print the launch configuration.
  p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
  printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
                      op.getBlockIds());
  p << ' ' << op.getThreadsKeyword();
  printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
                      op.getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(LaunchOp::kNumConfigOperands);

  // Print the data argument remapping.
  if (!op.body().empty() && !operands.empty()) {
    p << ' ' << op.getArgsKeyword() << '(';
    Block *entryBlock = &op.body().front();
    interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
      p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
        << " = " << operands[i];
    });
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty())
    p << ": " << operands.getTypes();

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs());
}

// Parse the size assignments for blocks and threads.  These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                             (`args` ssa-reassignment `:` type-list)?
//                             region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
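// An illustrative example (all value names are hypothetical):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c4, %sz = %c4)
//              args(%arg0 = %input) : memref<?xf32> {
//     ...
//     gpu.terminator
//   }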
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
      LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers.  In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If the kernel argument renaming segment is present, parse it.  When
  // present, the segment should have at least one element.  If this segment is
  // present, so is the trailing type list.  Parse it as well and use the
  // parsed types to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(LaunchOp::getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it.  The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // The actual kernel arguments follow.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), LaunchOp::kNumConfigRegionAttributes,
                   index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttrDict(result.attributes));
}

void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = body().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
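// For example (illustrative), given
//   %cst = constant 8 : index
//   gpu.launch ... args(%arg0 = %cst) : index { ... }
// the pattern clones the constant into the body, replaces uses of %arg0 with
// it, and erases the now-unused kernel argument and operand.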
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(launchOp);
    PatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&launchOp.body().front());

    // Traverse operands passed to the kernel and check if some of them are
    // known constants.  If so, clone the constant operation inside the kernel
    // region and use it instead of passing the value from the parent region.
    // Perform the traversal in the inverse order to simplify index arithmetic
    // when dropping arguments.
    auto operands = launchOp.getKernelOperandValues();
    auto kernelArgs = launchOp.getKernelArguments();
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value operand = operands[index];
      if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp()))
        continue;

      found = true;
      Value internalConstant =
          rewriter.clone(*operand.getDefiningOp())->getResult(0);
      Value kernelArg = *std::next(kernelArgs.begin(), index);
      kernelArg.replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }

    if (!found) {
      rewriter.cancelRootUpdate(launchOp);
      return matchFailure();
    }

    rewriter.finalizeRootUpdate(launchOp);
    return matchSuccess();
  }
};
} // end namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

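/// Builds a `gpu.launch_func` op that launches `kernelFunc`. Grid and block
/// sizes come first in the operand list, followed by the kernel operands; the
/// kernel symbol and its enclosing module name are recorded as attributes.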
void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY,
                         Value gridSizeZ, Value blockSizeX, Value blockSizeY,
                         Value blockSizeZ, ValueRange kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getStringAttr(kernelFunc.getName()));
  auto kernelModule = kernelFunc.getParentOfType<ModuleOp>();
  if (Optional<StringRef> kernelModuleName = kernelModule.getName())
    result.addAttribute(getKernelModuleAttrName(),
                        builder->getSymbolRefAttr(*kernelModuleName));
}

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize, ValueRange kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

StringRef LaunchFuncOp::kernel() {
  return getAttrOfType<StringAttr>(getKernelAttrName()).getValue();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

StringRef LaunchFuncOp::getKernelModuleName() {
  return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
      .getRootReference();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

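// A sketch of the generic form this op takes (illustrative; the value and
// symbol names are hypothetical):
//   "gpu.launch_func"(%gx, %gy, %gz, %bx, %by, %bz, %arg0)
//       {kernel = "my_kernel", kernel_module = @kernels}
//       : (index, index, index, index, index, index, f32) -> ()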
static LogicalResult verify(LaunchFuncOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitOpError("expected to belong to a module");

  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return op.emitOpError(
        "expected the closest surrounding module to have the '" +
        GPUDialect::getContainerModuleAttrName() + "' attribute");

  auto kernelAttr = op.getAttrOfType<StringAttr>(op.getKernelAttrName());
  if (!kernelAttr)
    return op.emitOpError("string attribute '" + op.getKernelAttrName() +
                          "' must be specified");

  auto kernelModuleAttr =
      op.getAttrOfType<SymbolRefAttr>(op.getKernelModuleAttrName());
  if (!kernelModuleAttr)
    return op.emitOpError("symbol reference attribute '" +
                          op.getKernelModuleAttrName() + "' must be specified");

  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a workgroup attribution to this op of the MemRef type with the given
/// shape and element type.
Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape,
                                         Type elementType) {
  unsigned pos = getNumFuncArguments() + getNumWorkgroupAttributions();
  Block &bodyBlock = body().front();
  Value attribution = bodyBlock.insertArgument(
      std::next(bodyBlock.args_begin(), pos),
      MemRefType::get(shape, elementType, /*affineMapComposition=*/{},
                      GPUDialect::getWorkgroupAddressSpace()));
  auto numWorkgroupBuffersAttr =
      getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName());
  setAttr(getNumWorkgroupAttributionsAttrName(),
          IntegerAttr::get(numWorkgroupBuffersAttr.getType(),
                           numWorkgroupBuffersAttr.getValue() + 1));
  return attribution;
}

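/// Builds a `gpu.func`. The entry block receives one argument per function
/// input, followed by one argument per workgroup attribution and then one per
/// private attribution.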
void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
                      FunctionType type, ArrayRef<Type> workgroupAttributions,
                      ArrayRef<Type> privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder->getStringAttr(name));
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder->getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;
  entryBlock->addArguments(type.getInputs());
  entryBlock->addArguments(workgroupAttributions);
  entryBlock->addArguments(privateAttributions);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::OperandType> &args,
                  SmallVectorImpl<Type> &argTypes) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (failed(parser.parseLParen()))
    return failure();

  // Early exit for an empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::OperandType arg;
    Type type;

    if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
      return failure();

    args.push_back(arg);
    argTypes.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
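/// For example (illustrative; the memory-space numbers follow the dialect's
/// workgroup/private address-space convention and may differ per target):
///   gpu.func @kernel(%arg0 : f32) workgroup(%buf : memref<32xf32, 3>) kernel {
///     ...
///     gpu.return
///   }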
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
  SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
  SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 4> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs)))
    return failure();

  if (entryArgs.empty() && !argTypes.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));

  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, argTypes)))
    return failure();

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs, argTypes);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  interleaveComma(values, p,
                  [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Prints a GPU Func op.
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
  p << GPUFuncOp::getOperationName() << ' ';
  p.printSymbolName(op.getName());

  FunctionType type = op.getType();
  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
                               /*isVariadic=*/false, type.getResults());

  printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
  printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
  if (op.isKernel())
    p << ' ' << op.getKernelKeyword();

  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
                                type.getNumResults(),
                                {op.getNumWorkgroupAttributionsAttrName(),
                                 GPUDialect::getKernelFuncAttrName()});
  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

void GPUFuncOp::setType(FunctionType newType) {
  auto oldType = getType();
  assert(newType.getNumResults() == oldType.getNumResults() &&
         "unimplemented: changes to the number of results");

  SmallVector<char, 16> nameBuf;
  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
    removeAttr(getArgAttrName(i, nameBuf));

  setAttr(getTypeAttrName(), TypeAttr::get(newType));
}

/// Hook for FunctionLike verifier.
LogicalResult GPUFuncOp::verifyType() {
  Type type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");
  return success();
}

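/// Verifies that every attribution value is a memref allocated in the expected
/// `memorySpace`.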
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        unsigned memorySpace) {
  for (Value v : attributions) {
    auto type = v.getType().dyn_cast<MemRefType>();
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    if (type.getMemorySpace() != memorySpace) {
      return op->emitOpError()
             << "expected memory space " << memorySpace << " in attribution";
    }
  }
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

// Namespace avoids ambiguous ReturnOpOperandAdaptor.
namespace mlir {
namespace gpu {
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
} // namespace gpu
} // namespace mlir