1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the GPU kernel-related dialect and its operations.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Dialect/GPU/GPUDialect.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/Dialect/StandardOps/Ops.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Function.h"
27 #include "mlir/IR/Module.h"
28 #include "mlir/IR/OpImplementation.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/StandardTypes.h"
31 
32 using namespace mlir;
33 using namespace mlir::gpu;
34 
35 //===----------------------------------------------------------------------===//
36 // GPUDialect
37 //===----------------------------------------------------------------------===//
38 
39 StringRef GPUDialect::getDialectName() { return "gpu"; }
40 
41 bool GPUDialect::isKernel(Operation *op) {
42   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
43   return static_cast<bool>(isKernelAttr);
44 }
45 
// Constructs the GPU dialect and registers its operations: the manually
// implemented LaunchOp and LaunchFuncOp, plus every ODS-generated op pulled
// in from GPUOps.cpp.inc via the GET_OP_LIST expansion.
GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<LaunchOp, LaunchFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
                >();
}
53 
// Verifier hook for GPU-dialect attributes attached to arbitrary operations.
// Only the container-module unit attribute is checked here; any other
// attribute is accepted unconditionally.  A module carrying that attribute
// must only contain `gpu.launch_func` ops whose kernel symbol resolves to a
// function marked as a kernel inside a module marked as a kernel module.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  // Ignore everything except the container-module marker attribute.
  if (!attr.second.isa<UnitAttr>() ||
      !attr.first.is(getContainerModuleAttrName()))
    return success();

  // The attribute is only meaningful on a ModuleOp.
  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  // The walk is interrupted as soon as any launch op below fails a check;
  // diagnostics are emitted at the offending launch op.
  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp.getParentOp() ||
        launchOp.getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp.getAttrOfType<StringAttr>(
            LaunchFuncOp::getKernelAttrName()) ||
        !launchOp.getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelModuleAttrName()))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel module.
    StringRef kernelModuleName = launchOp.getKernelModuleName();
    auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelModuleName << "' is undefined";
    if (!kernelModule.getAttrOfType<UnitAttr>(
            GPUDialect::getKernelModuleAttrName()))
      return launchOp.emitOpError("module '")
             << kernelModuleName << "' is missing the '"
             << GPUDialect::getKernelModuleAttrName() << "' attribute";

    // Check that `launch_func` refers to a well-formed kernel function.  The
    // kernel may be either a standard FuncOp or an LLVM-dialect function.
    StringRef kernelName = launchOp.kernel();
    Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
    auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
    if (!kernelStdFunction && !kernelLLVMFunction)
      return launchOp.emitOpError("kernel function '")
             << kernelName << "' is undefined";
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // The launch must pass exactly as many data operands as the kernel
    // function declares arguments.
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelLLVMFunction
                                        ? kernelLLVMFunction.getNumArguments()
                                        : kernelStdFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    // Due to the ordering of the current impl of lowering and LLVMLowering,
    // type checks need to be temporarily disabled.
    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
    // to encode target module" has landed.
    // auto functionType = kernelFunc.getType();
    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    //   if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
    //     return emitOpError("type of function argument ")
    //            << i << " does not match";
    //   }
    // }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
132 
133 template <typename T> static LogicalResult verifyIndexOp(T op) {
134   auto dimension = op.dimension();
135   if (dimension != "x" && dimension != "y" && dimension != "z")
136     return op.emitError("dimension \"") << dimension << "\" is invalid";
137   return success();
138 }
139 
140 #define GET_OP_CLASSES
141 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
142 
143 //===----------------------------------------------------------------------===//
144 // LaunchOp
145 //===----------------------------------------------------------------------===//
146 
147 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
148   SmallVector<Type, 4> types;
149   types.reserve(values.size());
150   for (Value *v : values)
151     types.push_back(v->getType());
152   return types;
153 }
154 
// Builds a `gpu.launch` op: six launch-configuration operands (grid sizes
// followed by block sizes), then the data operands, plus a single body region
// whose entry block is created here with matching arguments.
void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX,
                     Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
                     Value *blockSizeY, Value *blockSizeZ,
                     ArrayRef<Value *> operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  // The region's block list takes ownership of this block on push_back.
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(getValueTypes(operands));
  kernelRegion->push_back(body);
}
174 
175 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }
176 
177 KernelDim3 LaunchOp::getBlockIds() {
178   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
179   auto args = getBody().getBlocks().front().getArguments();
180   return KernelDim3{args[0], args[1], args[2]};
181 }
182 
183 KernelDim3 LaunchOp::getThreadIds() {
184   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
185   auto args = getBody().getBlocks().front().getArguments();
186   return KernelDim3{args[3], args[4], args[5]};
187 }
188 
189 KernelDim3 LaunchOp::getGridSize() {
190   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
191   auto args = getBody().getBlocks().front().getArguments();
192   return KernelDim3{args[6], args[7], args[8]};
193 }
194 
195 KernelDim3 LaunchOp::getBlockSize() {
196   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
197   auto args = getBody().getBlocks().front().getArguments();
198   return KernelDim3{args[9], args[10], args[11]};
199 }
200 
201 LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
202   return llvm::drop_begin(getOperands(), kNumConfigOperands);
203 }
204 
205 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
206   return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
207 }
208 
209 KernelDim3 LaunchOp::getGridSizeOperandValues() {
210   return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
211 }
212 
213 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
214   return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
215 }
216 
217 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
218   auto args = getBody().getBlocks().front().getArguments();
219   return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
220 }
221 
// Structural verification of `gpu.launch`: region argument count must match
// the config attributes plus data operands, and every exiting terminator must
// be the dialect's return-like op.
LogicalResult LaunchOp::verify() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    Block &entryBlock = getBody().front();
    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.launch`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::Return>(&block.back())) {
      // NOTE(review): the diagnostic mentions 'gpu.terminator' while the
      // check is for gpu::Return -- confirm the op mnemonic and this message
      // agree.
      return block.back()
                 .emitError("expected 'gpu.terminator' or a terminator with "
                            "successors")
                 .attachNote(getLoc())
             << "in '" << getOperationName() << "' body region";
    }
  }

  return success();
}
250 
251 // Pretty-print the kernel grid/block size assignment as
252 //   (%iter-x, %iter-y, %iter-z) in
253 //   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
254 // where %size-* and %iter-* will correspond to the body region arguments.
255 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
256                                 ArrayRef<Value *> operands, KernelDim3 ids) {
257   p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
258   p << *size.x << " = " << *operands[0] << ", ";
259   p << *size.y << " = " << *operands[1] << ", ";
260   p << *size.z << " = " << *operands[2] << ')';
261 }
262 
// Custom printer for `gpu.launch`.  The output must round-trip through
// LaunchOp::parse, so segment order matters: blocks clause, threads clause,
// optional args remapping, optional type list, region, attributes.
void LaunchOp::print(OpAsmPrinter &p) {
  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
  ArrayRef<Value *> operands(operandContainer);

  // Print the launch configuration.
  p << getOperationName() << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(kNumConfigOperands);

  // Print the data argument remapping: each region argument bound to the
  // operand that supplies its value.
  if (!getBody().empty() && !operands.empty()) {
    p << ' ' << getArgsKeyword() << '(';
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
        << " = " << *operands[i];
    }
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty()) {
    p << ": ";
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[i]->getType();
    }
  }

  // Entry block arguments were already printed in the clauses above, so the
  // region is printed without them.
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(getAttrs());
}
302 
303 // Parse the size assignment blocks for blocks and threads.  These have the form
304 //   (%region_arg, %region_arg, %region_arg) in
305 //   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
306 // where %region_arg are percent-identifiers for the region arguments to be
307 // introduced further (SSA defs), and %operand are percent-identifiers for the
308 // SSA value uses.
309 static ParseResult
310 parseSizeAssignment(OpAsmParser &parser,
311                     MutableArrayRef<OpAsmParser::OperandType> sizes,
312                     MutableArrayRef<OpAsmParser::OperandType> regionSizes,
313                     MutableArrayRef<OpAsmParser::OperandType> indices) {
314   assert(indices.size() == 3 && "space for three indices expected");
315   SmallVector<OpAsmParser::OperandType, 3> args;
316   if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
317                                      OpAsmParser::Delimiter::Paren) ||
318       parser.parseKeyword("in") || parser.parseLParen())
319     return failure();
320   std::move(args.begin(), args.end(), indices.begin());
321 
322   for (int i = 0; i < 3; ++i) {
323     if (i != 0 && parser.parseComma())
324       return failure();
325     if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
326         parser.parseOperand(sizes[i]))
327       return failure();
328   }
329 
330   return parser.parseRParen();
331 }
332 
333 // Parses a Launch operation.
334 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
335 //                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
336 //                             (`args` ssa-reassignment `:` type-list)?
337 //                             region attr-dict?
338 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
// Custom parser for `gpu.launch`; see the grammar comment above.  The order
// of parser calls defines the accepted syntax, so statements must not be
// reordered.
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
      kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers.  In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If kernel argument renaming segment is present, parse it.  When present,
  // the segment should have at least one element.  If this segment is present,
  // so is the trailing type list.  Parse it as well and use the parsed types
  // to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    // First `%region_arg = %operand` pair; at least one is required.
    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    // Any remaining comma-separated pairs.
    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it.  The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of the `index` type.
  // Follow the actual kernel arguments.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttributeDict(result.attributes));
}
408 
// Erases the kernel data argument at `index` (counted among the data
// arguments only, i.e. excluding the kNumConfigRegionAttributes leading
// region arguments) together with the corresponding op operand.
void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = getBody().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  // Region arguments are offset by the config attributes; op operands are
  // offset by the config operands.
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}
416 
namespace {
// Clone any known constants passed as operands to the kernel into its body.
// This removes the corresponding data operand and region argument, shrinking
// the launch's explicit capture list.
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    // Cloned constants must be created inside the kernel region; remember the
    // original insertion point so it can be restored afterwards.
    auto origInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(&launchOp.getBody().front());

    // Traverse operands passed to kernel and check if some of them are known
    // constants.  If so, clone the constant operation inside the kernel region
    // and use it instead of passing the value from the parent region.  Perform
    // the traversal in the inverse order to simplify index arithmetics when
    // dropping arguments.
    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
                                     launchOp.getKernelOperandValues().end());
    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
                                       launchOp.getKernelArguments().end());
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value *operand = operands[index];
      // Only operands defined by a ConstantOp can be safely cloned inside.
      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
        continue;
      }

      found = true;
      Value *internalConstant =
          rewriter.clone(*operand->getDefiningOp())->getResult(0);
      Value *kernelArg = kernelArgs[index];
      kernelArg->replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }
    rewriter.restoreInsertionPoint(origInsertionPoint);

    if (!found)
      return matchFailure();

    // The launch op itself was mutated in place rather than replaced.
    rewriter.updatedRootInPlace(launchOp);
    return matchSuccess();
  }
};
} // end namespace
461 
462 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
463                                            MLIRContext *context) {
464   results.insert<PropagateConstantBounds>(context);
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // LaunchFuncOp
469 //===----------------------------------------------------------------------===//
470 
// Builds a `gpu.launch_func`: six launch-configuration operands followed by
// the kernel data operands, the kernel-name string attribute, and a symbol
// reference to the kernel's parent module when that module is named.
void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
                         Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
                         Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getStringAttr(kernelFunc.getName()));
  auto kernelModule = kernelFunc.getParentOfType<ModuleOp>();
  // NOTE(review): if the surrounding module is unnamed, no kernel-module
  // attribute is attached and the op will later fail verification -- confirm
  // callers always place kernels inside named modules.
  if (Optional<StringRef> kernelModuleName = kernelModule.getName())
    result.addAttribute(getKernelModuleAttrName(),
                        builder->getSymbolRefAttr(*kernelModuleName));
}
486 
487 void LaunchFuncOp::build(Builder *builder, OperationState &result,
488                          FuncOp kernelFunc, KernelDim3 gridSize,
489                          KernelDim3 blockSize,
490                          ArrayRef<Value *> kernelOperands) {
491   build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
492         blockSize.x, blockSize.y, blockSize.z, kernelOperands);
493 }
494 
495 StringRef LaunchFuncOp::kernel() {
496   return getAttrOfType<StringAttr>(getKernelAttrName()).getValue();
497 }
498 
499 unsigned LaunchFuncOp::getNumKernelOperands() {
500   return getNumOperands() - kNumConfigOperands;
501 }
502 
503 StringRef LaunchFuncOp::getKernelModuleName() {
504   return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
505 }
506 
507 Value *LaunchFuncOp::getKernelOperand(unsigned i) {
508   return getOperation()->getOperand(i + kNumConfigOperands);
509 }
510 
511 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
512   return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
513 }
514 
515 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
516   return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
517 }
518 
// Local checks for `gpu.launch_func`: the op must sit inside a module tagged
// as a kernel container and must carry both the kernel-name and the
// kernel-module attributes.  Resolving the referenced symbols is done in
// GPUDialect::verifyOperationAttribute instead.
LogicalResult LaunchFuncOp::verify() {
  auto module = getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  auto kernelAttr = getAttrOfType<StringAttr>(getKernelAttrName());
  if (!kernelAttr)
    return emitOpError("string attribute '" + getKernelAttrName() +
                       "' must be specified");

  auto kernelModuleAttr =
      getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName());
  if (!kernelModuleAttr)
    return emitOpError("symbol reference attribute '" +
                       getKernelModuleAttrName() + "' must be specified");

  return success();
}
542