//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::gpu;

StringRef GPUDialect::getDialectName() { return "gpu"; }

bool GPUDialect::isKernel(FuncOp function) {
  UnitAttr isKernelAttr =
      function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

GPUDialect::GPUDialect(MLIRContext *context)
    : Dialect(getDialectName(), context) {
  addOperations<LaunchOp, LaunchFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
                >();
}

template <typename T> static LogicalResult verifyIndexOp(T op) {
  auto dimension = op.dimension();
  if (dimension != "x" && dimension != "y" && dimension != "z")
    return op.emitError("dimension \"") << dimension << "\" is invalid";
  return success();
}
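
// For illustration, the GPU index operations defined in GPUOps.td (e.g.
// gpu.thread_id) carry the `dimension` attribute checked above.  In generic
// textual form such an op looks roughly like (SSA names hypothetical):
//
//   %tid = "gpu.thread_id"() {dimension = "x"} : () -> index
//
// and verifyIndexOp rejects any dimension other than "x", "y" or "z".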

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
  SmallVector<Type, 4> types;
  types.reserve(values.size());
  for (Value *v : values)
    types.push_back(v->getType());
  return types;
}

void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX,
                     Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
                     Value *blockSizeY, Value *blockSizeZ,
                     ArrayRef<Value *> operands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(operands);

  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
  // where the first kNumConfigRegionAttributes arguments have `index` type and
  // the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  body->addArguments(
      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
  body->addArguments(getValueTypes(operands));
  kernelRegion->push_back(body);
}
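
// A minimal usage sketch for the builder above (illustrative only; `builder`,
// `loc`, and the SSA value names are hypothetical and not defined in this
// file):
//
//   auto launch = builder.create<gpu::LaunchOp>(
//       loc, gridX, gridY, gridZ, blockX, blockY, blockZ,
//       /*operands=*/ArrayRef<Value *>{buffer});
//
// OpBuilder::create forwards these arguments to LaunchOp::build, so the body
// region starts out with the kNumConfigRegionAttributes `index` arguments
// followed by one argument per data operand.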

Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }

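// The body region arguments encode, in order: block identifiers (args 0-2),
// thread identifiers (args 3-5), grid sizes (args 6-8), and block sizes
// (args 9-11); any data operands are remapped to the remaining arguments.
// The accessors below rely on this layout.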
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().getBlocks().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getBlocks().front().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
  return llvm::drop_begin(getOperands(), kNumConfigOperands);
}

LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
  auto args = getBody().getBlocks().front().getArguments();
  return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}

LogicalResult LaunchOp::verify() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    Block &entryBlock = getBody().front();
    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
      return emitOpError("unexpected number of region arguments");
  }

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.return`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::Return>(&block.back())) {
      return block.back()
                 .emitError("expected 'gpu.return' or a terminator with "
                            "successors")
                 .attachNote(getLoc())
             << "in '" << getOperationName() << "' body region";
    }
  }

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                ArrayRef<Value *> operands, KernelDim3 ids) {
  p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
  p << *size.x << " = " << *operands[0] << ", ";
  p << *size.y << " = " << *operands[1] << ", ";
  p << *size.z << " = " << *operands[2] << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
  ArrayRef<Value *> operands(operandContainer);

  // Print the launch configuration.
  p << getOperationName() << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());

  // From now on, the first kNumConfigOperands operands corresponding to grid
  // and block sizes are irrelevant, so we can drop them.
  operands = operands.drop_front(kNumConfigOperands);

  // Print the data argument remapping.
  if (!getBody().empty() && !operands.empty()) {
    p << ' ' << getArgsKeyword() << '(';
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
        << " = " << *operands[i];
    }
    p << ") ";
  }

  // Print the types of data arguments.
  if (!operands.empty()) {
    p << ": ";
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[i]->getType();
    }
  }

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(getAttrs());
}

// Parse the size assignment segments for blocks and threads.  These have the
// form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced later (SSA definitions), and %operand are percent-identifiers
// for the SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::OperandType> sizes,
                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
                    MutableArrayRef<OpAsmParser::OperandType> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::OperandType, 3> args;
  if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
                                     OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
        parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
//                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
//                             (`args` ssa-reassignment `:` type-list)?
//                             region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
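// For illustration, the custom form parsed by this grammar looks roughly as
// follows (all SSA names are hypothetical):
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gsx = %0, %gsy = %1, %gsz = %2)
//              threads(%tx, %ty, %tz) in (%bsx = %3, %bsy = %4, %bsz = %5)
//              args(%arg0 = %6) : f32 {
//     ...
//     gpu.return
//   }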
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
      kNumConfigOperands);
  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::OperandType, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
      kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers.  In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // If the kernel argument renaming segment is present, parse it.  When
  // present, the segment must have at least one element.  If this segment is
  // present, so is the trailing type list.  Parse it as well and use the
  // parsed types to resolve the operands passed to the kernel arguments.
  SmallVector<Type, 4> dataTypes;
  if (!parser.parseOptionalKeyword(getArgsKeyword())) {
    llvm::SMLoc argsLoc = parser.getCurrentLocation();

    regionArgs.push_back({});
    dataOperands.push_back({});
    if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
        parser.parseEqual() || parser.parseOperand(dataOperands.back()))
      return failure();

    while (!parser.parseOptionalComma()) {
      regionArgs.push_back({});
      dataOperands.push_back({});
      if (parser.parseRegionArgument(regionArgs.back()) ||
          parser.parseEqual() || parser.parseOperand(dataOperands.back()))
        return failure();
    }

    if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
        parser.resolveOperands(dataOperands, dataTypes, argsLoc,
                               result.operands))
      return failure();
  }

  // Introduce the body region and parse it.  The region has
  // kNumConfigRegionAttributes leading arguments that correspond to
  // block/thread identifiers and grid/block sizes, all of `index` type; these
  // are followed by the actual kernel arguments.
  Type index = parser.getBuilder().getIndexType();
  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
                 parser.parseOptionalAttributeDict(result.attributes));
}

void LaunchOp::eraseKernelArgument(unsigned index) {
  Block &entryBlock = getBody().front();
  assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
         "kernel argument index overflow");
  entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
  getOperation()->eraseOperand(kNumConfigOperands + index);
}

namespace {
// Clone any known constants passed as operands to the kernel into its body.
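// For example (illustrative IR; SSA names are hypothetical), given
//
//   %c1 = constant 1 : index
//   gpu.launch ... args(%arg0 = %c1) : index { /* uses %arg0 */ }
//
// the pattern clones the constant into the kernel region, redirects the uses
// of %arg0 to the clone, and erases both the region argument and the
// corresponding launch operand.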
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LaunchOp launchOp,
                                     PatternRewriter &rewriter) const override {
    auto origInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointToStart(&launchOp.getBody().front());

    // Traverse the operands passed to the kernel and check if some of them
    // are known constants.  If so, clone the constant operation inside the
    // kernel region and use it instead of passing the value from the parent
    // region.  Perform the traversal in reverse order to simplify the index
    // arithmetic when dropping arguments.
    SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
                                     launchOp.getKernelOperandValues().end());
    SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
                                       launchOp.getKernelArguments().end());
    bool found = false;
    for (unsigned i = operands.size(); i > 0; --i) {
      unsigned index = i - 1;
      Value *operand = operands[index];
      if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
        continue;
      }

      found = true;
      Value *internalConstant =
          rewriter.clone(*operand->getDefiningOp())->getResult(0);
      Value *kernelArg = kernelArgs[index];
      kernelArg->replaceAllUsesWith(internalConstant);
      launchOp.eraseKernelArgument(index);
    }
    rewriter.restoreInsertionPoint(origInsertionPoint);

    if (!found)
      return matchFailure();

    rewriter.updatedRootInPlace(launchOp);
    return matchSuccess();
  }
};
} // end namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<PropagateConstantBounds>(context);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//
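// LaunchFuncOp has no custom assembly form in this file; in the generic form
// it looks roughly as follows (the kernel name, SSA names and operand types
// are hypothetical):
//
//   "gpu.launch_func"(%gx, %gy, %gz, %bx, %by, %bz, %buf)
//       {kernel = @my_kernel}
//       : (index, index, index, index, index, index, memref<?xf32>) -> ()
//
// The first kNumConfigOperands operands are the grid and block sizes; the
// remaining operands are forwarded to the kernel function referenced by the
// `kernel` symbol attribute.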

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
                         Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
                         Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(
      {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
  result.addOperands(kernelOperands);
  result.addAttribute(getKernelAttrName(),
                      builder->getSymbolRefAttr(kernelFunc));
}

void LaunchFuncOp::build(Builder *builder, OperationState &result,
                         FuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 blockSize,
                         ArrayRef<Value *> kernelOperands) {
  build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
        blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}

StringRef LaunchFuncOp::kernel() {
  return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getNumOperands() - kNumConfigOperands;
}

Value *LaunchFuncOp::getKernelOperand(unsigned i) {
  return getOperation()->getOperand(i + kNumConfigOperands);
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}

LogicalResult LaunchFuncOp::verify() {
  auto kernelAttr = this->getAttr(getKernelAttrName());
  if (!kernelAttr) {
    return emitOpError("attribute 'kernel' must be specified");
  } else if (!kernelAttr.isa<SymbolRefAttr>()) {
    return emitOpError("attribute 'kernel' must be a function");
  }

  auto module = getParentOfType<ModuleOp>();
  FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
  if (!kernelFunc)
    return emitOpError("kernel function '") << kernelAttr << "' is undefined";

  if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
          GPUDialect::getKernelFuncAttrName())) {
    return emitOpError("kernel function is missing the '")
           << GPUDialect::getKernelFuncAttrName() << "' attribute";
  }
  unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
  if (getNumKernelOperands() != numKernelFuncArgs) {
    return emitOpError("got ")
           << getNumKernelOperands() << " kernel operands but expected "
           << numKernelFuncArgs;
  }
  auto functionType = kernelFunc.getType();
  for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
    if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
      return emitOpError("type of function argument ")
             << i << " does not match";
    }
  }
  return success();
}