1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the GPU kernel-related dialect and its operations.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Dialect/GPU/GPUDialect.h"
23 #include "mlir/Dialect/StandardOps/Ops.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/OpImplementation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/StandardTypes.h"
30 
31 using namespace mlir;
32 using namespace mlir::gpu;
33 
34 //===----------------------------------------------------------------------===//
35 // GPUDialect
36 //===----------------------------------------------------------------------===//
37 
38 StringRef GPUDialect::getDialectName() { return "gpu"; }
39 
40 bool GPUDialect::isKernel(FuncOp function) {
41   UnitAttr isKernelAttr =
42       function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
43   return static_cast<bool>(isKernelAttr);
44 }
45 
46 GPUDialect::GPUDialect(MLIRContext *context)
47     : Dialect(getDialectName(), context) {
48   addOperations<LaunchOp, LaunchFuncOp,
49 #define GET_OP_LIST
50 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
51                 >();
52 }
53 
54 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
55                                                    NamedAttribute attr) {
56   if (!attr.second.isa<UnitAttr>() ||
57       !attr.first.is(getContainerModuleAttrName()))
58     return success();
59 
60   auto module = dyn_cast<ModuleOp>(op);
61   if (!module)
62     return op->emitError("expected '")
63            << getContainerModuleAttrName() << "' attribute to be attached to '"
64            << ModuleOp::getOperationName() << '\'';
65 
66   auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
67     // Ignore launches that are nested more or less deep than functions in the
68     // module we are currently checking.
69     if (!launchOp.getParentOp() ||
70         launchOp.getParentOp()->getParentOp() != module)
71       return success();
72 
73     // Ignore launch ops with missing attributes here. The errors will be
74     // reported by the verifiers of those ops.
75     if (!launchOp.getAttrOfType<StringAttr>(
76             LaunchFuncOp::getKernelAttrName()) ||
77         !launchOp.getAttrOfType<SymbolRefAttr>(
78             LaunchFuncOp::getKernelModuleAttrName()))
79       return success();
80 
81     // Check that `launch_func` refers to a well-formed GPU kernel module.
82     StringRef kernelModuleName = launchOp.getKernelModuleName();
83     auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
84     if (!kernelModule)
85       return launchOp.emitOpError()
86              << "kernel module '" << kernelModuleName << "' is undefined";
87     if (!kernelModule.getAttrOfType<UnitAttr>(
88             GPUDialect::getKernelModuleAttrName()))
89       return launchOp.emitOpError("module '")
90              << kernelModuleName << "' is missing the '"
91              << GPUDialect::getKernelModuleAttrName() << "' attribute";
92 
93     // Check that `launch_func` refers to a well-formed kernel function.
94     StringRef kernelName = launchOp.kernel();
95     auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(kernelName);
96     if (!kernelFunction)
97       return launchOp.emitOpError("kernel function '")
98              << kernelName << "' is undefined";
99     if (!kernelFunction.getAttrOfType<mlir::UnitAttr>(
100             GPUDialect::getKernelFuncAttrName()))
101       return launchOp.emitOpError("kernel function is missing the '")
102              << GPUDialect::getKernelFuncAttrName() << "' attribute";
103     if (launchOp.getNumKernelOperands() != kernelFunction.getNumArguments())
104       return launchOp.emitOpError("got ") << launchOp.getNumKernelOperands()
105                                           << " kernel operands but expected "
106                                           << kernelFunction.getNumArguments();
107 
108     // Due to the ordering of the current impl of lowering and LLVMLowering,
109     // type checks need to be temporarily disabled.
110     // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
111     // to encode target module" has landed.
112     // auto functionType = kernelFunc.getType();
113     // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
114     //   if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
115     //     return emitOpError("type of function argument ")
116     //            << i << " does not match";
117     //   }
118     // }
119 
120     return success();
121   });
122 
123   return walkResult.wasInterrupted() ? failure() : success();
124 }
125 
126 template <typename T> static LogicalResult verifyIndexOp(T op) {
127   auto dimension = op.dimension();
128   if (dimension != "x" && dimension != "y" && dimension != "z")
129     return op.emitError("dimension \"") << dimension << "\" is invalid";
130   return success();
131 }
132 
133 #define GET_OP_CLASSES
134 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
135 
136 //===----------------------------------------------------------------------===//
137 // LaunchOp
138 //===----------------------------------------------------------------------===//
139 
140 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
141   SmallVector<Type, 4> types;
142   types.reserve(values.size());
143   for (Value *v : values)
144     types.push_back(v->getType());
145   return types;
146 }
147 
148 void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX,
149                      Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
150                      Value *blockSizeY, Value *blockSizeZ,
151                      ArrayRef<Value *> operands) {
152   // Add grid and block sizes as op operands, followed by the data operands.
153   result.addOperands(
154       {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
155   result.addOperands(operands);
156 
157   // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
158   // where the first kNumConfigRegionAttributes arguments have `index` type and
159   // the rest have the same types as the data operands.
160   Region *kernelRegion = result.addRegion();
161   Block *body = new Block();
162   body->addArguments(
163       std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
164   body->addArguments(getValueTypes(operands));
165   kernelRegion->push_back(body);
166 }
167 
168 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }
169 
170 KernelDim3 LaunchOp::getBlockIds() {
171   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
172   auto args = getBody().getBlocks().front().getArguments();
173   return KernelDim3{args[0], args[1], args[2]};
174 }
175 
176 KernelDim3 LaunchOp::getThreadIds() {
177   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
178   auto args = getBody().getBlocks().front().getArguments();
179   return KernelDim3{args[3], args[4], args[5]};
180 }
181 
182 KernelDim3 LaunchOp::getGridSize() {
183   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
184   auto args = getBody().getBlocks().front().getArguments();
185   return KernelDim3{args[6], args[7], args[8]};
186 }
187 
188 KernelDim3 LaunchOp::getBlockSize() {
189   assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
190   auto args = getBody().getBlocks().front().getArguments();
191   return KernelDim3{args[9], args[10], args[11]};
192 }
193 
194 LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
195   return llvm::drop_begin(getOperands(), kNumConfigOperands);
196 }
197 
198 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
199   return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
200 }
201 
202 KernelDim3 LaunchOp::getGridSizeOperandValues() {
203   return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
204 }
205 
206 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
207   return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
208 }
209 
210 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
211   auto args = getBody().getBlocks().front().getArguments();
212   return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
213 }
214 
215 LogicalResult LaunchOp::verify() {
216   // Kernel launch takes kNumConfigOperands leading operands for grid/block
217   // sizes and transforms them into kNumConfigRegionAttributes region arguments
218   // for block/thread identifiers and grid/block sizes.
219   if (!getBody().empty()) {
220     Block &entryBlock = getBody().front();
221     if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
222       return emitOpError("unexpected number of region arguments");
223   }
224 
225   // Block terminators without successors are expected to exit the kernel region
226   // and must be `gpu.launch`.
227   for (Block &block : getBody()) {
228     if (block.empty())
229       continue;
230     if (block.back().getNumSuccessors() != 0)
231       continue;
232     if (!isa<gpu::Return>(&block.back())) {
233       return block.back()
234                  .emitError("expected 'gpu.terminator' or a terminator with "
235                             "successors")
236                  .attachNote(getLoc())
237              << "in '" << getOperationName() << "' body region";
238     }
239   }
240 
241   return success();
242 }
243 
244 // Pretty-print the kernel grid/block size assignment as
245 //   (%iter-x, %iter-y, %iter-z) in
246 //   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
247 // where %size-* and %iter-* will correspond to the body region arguments.
248 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
249                                 ArrayRef<Value *> operands, KernelDim3 ids) {
250   p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
251   p << *size.x << " = " << *operands[0] << ", ";
252   p << *size.y << " = " << *operands[1] << ", ";
253   p << *size.z << " = " << *operands[2] << ')';
254 }
255 
256 void LaunchOp::print(OpAsmPrinter &p) {
257   SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
258   ArrayRef<Value *> operands(operandContainer);
259 
260   // Print the launch configuration.
261   p << getOperationName() << ' ' << getBlocksKeyword();
262   printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
263   p << ' ' << getThreadsKeyword();
264   printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());
265 
266   // From now on, the first kNumConfigOperands operands corresponding to grid
267   // and block sizes are irrelevant, so we can drop them.
268   operands = operands.drop_front(kNumConfigOperands);
269 
270   // Print the data argument remapping.
271   if (!getBody().empty() && !operands.empty()) {
272     p << ' ' << getArgsKeyword() << '(';
273     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
274       if (i != 0)
275         p << ", ";
276       p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
277         << " = " << *operands[i];
278     }
279     p << ") ";
280   }
281 
282   // Print the types of data arguments.
283   if (!operands.empty()) {
284     p << ": ";
285     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
286       if (i != 0)
287         p << ", ";
288       p << operands[i]->getType();
289     }
290   }
291 
292   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
293   p.printOptionalAttrDict(getAttrs());
294 }
295 
296 // Parse the size assignment blocks for blocks and threads.  These have the form
297 //   (%region_arg, %region_arg, %region_arg) in
298 //   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
299 // where %region_arg are percent-identifiers for the region arguments to be
300 // introduced further (SSA defs), and %operand are percent-identifiers for the
301 // SSA value uses.
302 static ParseResult
303 parseSizeAssignment(OpAsmParser &parser,
304                     MutableArrayRef<OpAsmParser::OperandType> sizes,
305                     MutableArrayRef<OpAsmParser::OperandType> regionSizes,
306                     MutableArrayRef<OpAsmParser::OperandType> indices) {
307   assert(indices.size() == 3 && "space for three indices expected");
308   SmallVector<OpAsmParser::OperandType, 3> args;
309   if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
310                                      OpAsmParser::Delimiter::Paren) ||
311       parser.parseKeyword("in") || parser.parseLParen())
312     return failure();
313   std::move(args.begin(), args.end(), indices.begin());
314 
315   for (int i = 0; i < 3; ++i) {
316     if (i != 0 && parser.parseComma())
317       return failure();
318     if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
319         parser.parseOperand(sizes[i]))
320       return failure();
321   }
322 
323   return parser.parseRParen();
324 }
325 
326 // Parses a Launch operation.
327 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
328 //                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
329 //                             (`args` ssa-reassignment `:` type-list)?
330 //                             region attr-dict?
331 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
332 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
333   // Sizes of the grid and block.
334   SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
335       kNumConfigOperands);
336   MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);
337 
338   // Actual (data) operands passed to the kernel.
339   SmallVector<OpAsmParser::OperandType, 4> dataOperands;
340 
341   // Region arguments to be created.
342   SmallVector<OpAsmParser::OperandType, 16> regionArgs(
343       kNumConfigRegionAttributes);
344   MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);
345 
346   // Parse the size assignment segments: the first segment assigns grid sizes
347   // and defines values for block identifiers; the second segment assigns block
348   // sizes and defines values for thread identifiers.  In the region argument
349   // list, identifiers precede sizes, and block-related values precede
350   // thread-related values.
351   if (parser.parseKeyword(getBlocksKeyword().data()) ||
352       parseSizeAssignment(parser, sizesRef.take_front(3),
353                           regionArgsRef.slice(6, 3),
354                           regionArgsRef.slice(0, 3)) ||
355       parser.parseKeyword(getThreadsKeyword().data()) ||
356       parseSizeAssignment(parser, sizesRef.drop_front(3),
357                           regionArgsRef.slice(9, 3),
358                           regionArgsRef.slice(3, 3)) ||
359       parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
360                              result.operands))
361     return failure();
362 
363   // If kernel argument renaming segment is present, parse it.  When present,
364   // the segment should have at least one element.  If this segment is present,
365   // so is the trailing type list.  Parse it as well and use the parsed types
366   // to resolve the operands passed to the kernel arguments.
367   SmallVector<Type, 4> dataTypes;
368   if (!parser.parseOptionalKeyword(getArgsKeyword())) {
369     llvm::SMLoc argsLoc = parser.getCurrentLocation();
370 
371     regionArgs.push_back({});
372     dataOperands.push_back({});
373     if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
374         parser.parseEqual() || parser.parseOperand(dataOperands.back()))
375       return failure();
376 
377     while (!parser.parseOptionalComma()) {
378       regionArgs.push_back({});
379       dataOperands.push_back({});
380       if (parser.parseRegionArgument(regionArgs.back()) ||
381           parser.parseEqual() || parser.parseOperand(dataOperands.back()))
382         return failure();
383     }
384 
385     if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
386         parser.resolveOperands(dataOperands, dataTypes, argsLoc,
387                                result.operands))
388       return failure();
389   }
390 
391   // Introduce the body region and parse it.  The region has
392   // kNumConfigRegionAttributes leading arguments that correspond to
393   // block/thread identifiers and grid/block sizes, all of the `index` type.
394   // Follow the actual kernel arguments.
395   Type index = parser.getBuilder().getIndexType();
396   dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
397   Region *body = result.addRegion();
398   return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
399                  parser.parseOptionalAttributeDict(result.attributes));
400 }
401 
402 void LaunchOp::eraseKernelArgument(unsigned index) {
403   Block &entryBlock = getBody().front();
404   assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
405          "kernel argument index overflow");
406   entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
407   getOperation()->eraseOperand(kNumConfigOperands + index);
408 }
409 
410 namespace {
411 // Clone any known constants passed as operands to the kernel into its body.
412 class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
413   using OpRewritePattern<LaunchOp>::OpRewritePattern;
414 
415   PatternMatchResult matchAndRewrite(LaunchOp launchOp,
416                                      PatternRewriter &rewriter) const override {
417     auto origInsertionPoint = rewriter.saveInsertionPoint();
418     rewriter.setInsertionPointToStart(&launchOp.getBody().front());
419 
420     // Traverse operands passed to kernel and check if some of them are known
421     // constants.  If so, clone the constant operation inside the kernel region
422     // and use it instead of passing the value from the parent region.  Perform
423     // the traversal in the inverse order to simplify index arithmetics when
424     // dropping arguments.
425     SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
426                                      launchOp.getKernelOperandValues().end());
427     SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
428                                        launchOp.getKernelArguments().end());
429     bool found = false;
430     for (unsigned i = operands.size(); i > 0; --i) {
431       unsigned index = i - 1;
432       Value *operand = operands[index];
433       if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
434         continue;
435       }
436 
437       found = true;
438       Value *internalConstant =
439           rewriter.clone(*operand->getDefiningOp())->getResult(0);
440       Value *kernelArg = kernelArgs[index];
441       kernelArg->replaceAllUsesWith(internalConstant);
442       launchOp.eraseKernelArgument(index);
443     }
444     rewriter.restoreInsertionPoint(origInsertionPoint);
445 
446     if (!found)
447       return matchFailure();
448 
449     rewriter.updatedRootInPlace(launchOp);
450     return matchSuccess();
451   }
452 };
453 } // end namespace
454 
455 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
456                                            MLIRContext *context) {
457   results.insert<PropagateConstantBounds>(context);
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // LaunchFuncOp
462 //===----------------------------------------------------------------------===//
463 
464 void LaunchFuncOp::build(Builder *builder, OperationState &result,
465                          FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
466                          Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
467                          Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
468   // Add grid and block sizes as op operands, followed by the data operands.
469   result.addOperands(
470       {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
471   result.addOperands(kernelOperands);
472   result.addAttribute(getKernelAttrName(),
473                       builder->getStringAttr(kernelFunc.getName()));
474   auto kernelModule = kernelFunc.getParentOfType<ModuleOp>();
475   if (Optional<StringRef> kernelModuleName = kernelModule.getName())
476     result.addAttribute(getKernelModuleAttrName(),
477                         builder->getSymbolRefAttr(*kernelModuleName));
478 }
479 
480 void LaunchFuncOp::build(Builder *builder, OperationState &result,
481                          FuncOp kernelFunc, KernelDim3 gridSize,
482                          KernelDim3 blockSize,
483                          ArrayRef<Value *> kernelOperands) {
484   build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
485         blockSize.x, blockSize.y, blockSize.z, kernelOperands);
486 }
487 
488 StringRef LaunchFuncOp::kernel() {
489   return getAttrOfType<StringAttr>(getKernelAttrName()).getValue();
490 }
491 
492 unsigned LaunchFuncOp::getNumKernelOperands() {
493   return getNumOperands() - kNumConfigOperands;
494 }
495 
496 StringRef LaunchFuncOp::getKernelModuleName() {
497   return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
498 }
499 
500 Value *LaunchFuncOp::getKernelOperand(unsigned i) {
501   return getOperation()->getOperand(i + kNumConfigOperands);
502 }
503 
504 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
505   return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
506 }
507 
508 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
509   return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
510 }
511 
512 LogicalResult LaunchFuncOp::verify() {
513   auto module = getParentOfType<ModuleOp>();
514   if (!module)
515     return emitOpError("expected to belong to a module");
516 
517   if (!module.getAttrOfType<UnitAttr>(GPUDialect::getContainerModuleAttrName()))
518     return emitOpError("expected the closest surrounding module to have the '" +
519                        GPUDialect::getContainerModuleAttrName() +
520                        "' attribute");
521 
522   auto kernelAttr = getAttrOfType<StringAttr>(getKernelAttrName());
523   if (!kernelAttr)
524     return emitOpError("string attribute '" + getKernelAttrName() +
525                        "' must be specified");
526 
527   auto kernelModuleAttr =
528       getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName());
529   if (!kernelModuleAttr)
530     return emitOpError("symbol reference attribute '" +
531                        getKernelModuleAttrName() + "' must be specified");
532 
533   return success();
534 }
535