160965b46SAlex Zinenko //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
260965b46SAlex Zinenko //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
660965b46SAlex Zinenko //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
860965b46SAlex Zinenko //
960965b46SAlex Zinenko // This file implements the GPU kernel-related dialect and its operations.
1060965b46SAlex Zinenko //
1160965b46SAlex Zinenko //===----------------------------------------------------------------------===//
1260965b46SAlex Zinenko
13d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
140372db05SFrederik Gossen
15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1600b6463bSWilliam S. Moses #include "mlir/Dialect/MemRef/IR/MemRef.h"
170372db05SFrederik Gossen #include "mlir/IR/Attributes.h"
1860965b46SAlex Zinenko #include "mlir/IR/Builders.h"
1965fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
2009f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
21473b364aSChristian Sigg #include "mlir/IR/DialectImplementation.h"
222f16bf7aSAlex Zinenko #include "mlir/IR/FunctionImplementation.h"
2357eda9beSUday Bondhugula #include "mlir/IR/Matchers.h"
2460965b46SAlex Zinenko #include "mlir/IR/OpImplementation.h"
2560965b46SAlex Zinenko #include "mlir/IR/PatternMatch.h"
260955d8dfSChristian Sigg #include "mlir/IR/TypeUtilities.h"
2716219f8cSArnab Dutta #include "mlir/Interfaces/SideEffectInterfaces.h"
28fc61d07dSUday Bondhugula #include "mlir/Transforms/InliningUtils.h"
29473b364aSChristian Sigg #include "llvm/ADT/TypeSwitch.h"
3060965b46SAlex Zinenko
3160965b46SAlex Zinenko using namespace mlir;
3260965b46SAlex Zinenko using namespace mlir::gpu;
3360965b46SAlex Zinenko
34d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
35485cc55eSStella Laurenzo
3690d65d32SAlex Zinenko //===----------------------------------------------------------------------===//
37875eb523SNavdeep Kumar // MMAMatrixType
38875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
39875eb523SNavdeep Kumar
get(ArrayRef<int64_t> shape,Type elementType,StringRef operand)40875eb523SNavdeep Kumar MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
41875eb523SNavdeep Kumar StringRef operand) {
42875eb523SNavdeep Kumar return Base::get(elementType.getContext(), shape, elementType, operand);
43875eb523SNavdeep Kumar }
44875eb523SNavdeep Kumar
45875eb523SNavdeep Kumar MMAMatrixType
getChecked(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,StringRef operand)46875eb523SNavdeep Kumar MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
47875eb523SNavdeep Kumar ArrayRef<int64_t> shape, Type elementType,
48875eb523SNavdeep Kumar StringRef operand) {
49875eb523SNavdeep Kumar return Base::getChecked(emitError, elementType.getContext(), shape,
50875eb523SNavdeep Kumar elementType, operand);
51875eb523SNavdeep Kumar }
52875eb523SNavdeep Kumar
getNumDims() const53875eb523SNavdeep Kumar unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
54875eb523SNavdeep Kumar
getShape() const55875eb523SNavdeep Kumar ArrayRef<int64_t> MMAMatrixType::getShape() const {
56875eb523SNavdeep Kumar return getImpl()->getShape();
57875eb523SNavdeep Kumar }
58875eb523SNavdeep Kumar
getElementType() const59875eb523SNavdeep Kumar Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
60875eb523SNavdeep Kumar
getOperand() const61875eb523SNavdeep Kumar StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
62875eb523SNavdeep Kumar
isValidElementType(Type elementType)63875eb523SNavdeep Kumar bool MMAMatrixType::isValidElementType(Type elementType) {
64875eb523SNavdeep Kumar return elementType.isF16() || elementType.isF32();
65875eb523SNavdeep Kumar }
66875eb523SNavdeep Kumar
67875eb523SNavdeep Kumar LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,StringRef operand)68875eb523SNavdeep Kumar MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
69875eb523SNavdeep Kumar ArrayRef<int64_t> shape, Type elementType,
70875eb523SNavdeep Kumar StringRef operand) {
71875eb523SNavdeep Kumar if (!operand.equals("AOp") && !operand.equals("BOp") &&
72b44007beSthomasraoux !operand.equals("COp"))
73b44007beSthomasraoux return emitError() << "operand expected to be one of AOp, BOp or COp";
74875eb523SNavdeep Kumar
75875eb523SNavdeep Kumar if (shape.size() != 2)
76875eb523SNavdeep Kumar return emitError() << "MMAMatrixType must have exactly two dimensions";
77875eb523SNavdeep Kumar
78875eb523SNavdeep Kumar if (!MMAMatrixType::isValidElementType(elementType))
79875eb523SNavdeep Kumar return emitError() << "MMAMatrixType elements must be F16 or F32";
80875eb523SNavdeep Kumar
81875eb523SNavdeep Kumar return success();
82875eb523SNavdeep Kumar }
83875eb523SNavdeep Kumar
84875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
8590d65d32SAlex Zinenko // GPUDialect
8690d65d32SAlex Zinenko //===----------------------------------------------------------------------===//
8790d65d32SAlex Zinenko
/// Numeric identifiers for GPU memory spaces.
enum GPUMemorySpace {
  /// Identifier of the generic memory space.
  kGenericMemorySpace = 0,

  /// Identifier of the global memory space.
  kGlobalMemorySpace = 1,

  /// Identifier of the shared memory space.
  kSharedMemorySpace = 3
};
99875eb523SNavdeep Kumar
isKernel(Operation * op)1005e7959a3SAlex Zinenko bool GPUDialect::isKernel(Operation *op) {
1015e7959a3SAlex Zinenko UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
10260965b46SAlex Zinenko return static_cast<bool>(isKernelAttr);
10360965b46SAlex Zinenko }
10460965b46SAlex Zinenko
105fc61d07dSUday Bondhugula namespace {
106fc61d07dSUday Bondhugula /// This class defines the interface for handling inlining with gpu
107fc61d07dSUday Bondhugula /// operations.
108fc61d07dSUday Bondhugula struct GPUInlinerInterface : public DialectInlinerInterface {
109fc61d07dSUday Bondhugula using DialectInlinerInterface::DialectInlinerInterface;
110fc61d07dSUday Bondhugula
111fc61d07dSUday Bondhugula /// All gpu dialect ops can be inlined.
isLegalToInline__anonf220e5a50111::GPUInlinerInterface112fc61d07dSUday Bondhugula bool isLegalToInline(Operation *, Region *, bool,
113fc61d07dSUday Bondhugula BlockAndValueMapping &) const final {
114fc61d07dSUday Bondhugula return true;
115fc61d07dSUday Bondhugula }
116fc61d07dSUday Bondhugula };
117fc61d07dSUday Bondhugula } // namespace
118fc61d07dSUday Bondhugula
initialize()119575b22b5SMehdi Amini void GPUDialect::initialize() {
120473b364aSChristian Sigg addTypes<AsyncTokenType>();
121875eb523SNavdeep Kumar addTypes<MMAMatrixType>();
122d1213ae5SAlex Zinenko addOperations<
12360965b46SAlex Zinenko #define GET_OP_LIST
124d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
12560965b46SAlex Zinenko >();
126aae51255SMogball addAttributes<
127aae51255SMogball #define GET_ATTRDEF_LIST
128d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
129aae51255SMogball >();
130fc61d07dSUday Bondhugula addInterfaces<GPUInlinerInterface>();
13160965b46SAlex Zinenko }
13260965b46SAlex Zinenko
parseType(DialectAsmParser & parser) const133473b364aSChristian Sigg Type GPUDialect::parseType(DialectAsmParser &parser) const {
134473b364aSChristian Sigg // Parse the main keyword for the type.
135473b364aSChristian Sigg StringRef keyword;
136473b364aSChristian Sigg if (parser.parseKeyword(&keyword))
137473b364aSChristian Sigg return Type();
138473b364aSChristian Sigg MLIRContext *context = getContext();
139473b364aSChristian Sigg
140473b364aSChristian Sigg // Handle 'async token' types.
141473b364aSChristian Sigg if (keyword == "async.token")
142473b364aSChristian Sigg return AsyncTokenType::get(context);
143473b364aSChristian Sigg
144875eb523SNavdeep Kumar if (keyword == "mma_matrix") {
1456842ec42SRiver Riddle SMLoc beginLoc = parser.getNameLoc();
146875eb523SNavdeep Kumar
147875eb523SNavdeep Kumar // Parse '<'.
148875eb523SNavdeep Kumar if (parser.parseLess())
149875eb523SNavdeep Kumar return nullptr;
150875eb523SNavdeep Kumar
151875eb523SNavdeep Kumar // Parse the size and elementType.
152875eb523SNavdeep Kumar SmallVector<int64_t> shape;
153875eb523SNavdeep Kumar Type elementType;
154875eb523SNavdeep Kumar if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
155875eb523SNavdeep Kumar parser.parseType(elementType))
156875eb523SNavdeep Kumar return nullptr;
157875eb523SNavdeep Kumar
158875eb523SNavdeep Kumar // Parse ','
159875eb523SNavdeep Kumar if (parser.parseComma())
160875eb523SNavdeep Kumar return nullptr;
161875eb523SNavdeep Kumar
162875eb523SNavdeep Kumar // Parse operand.
1639658b061SRiver Riddle std::string operand;
164875eb523SNavdeep Kumar if (failed(parser.parseOptionalString(&operand)))
165875eb523SNavdeep Kumar return nullptr;
166875eb523SNavdeep Kumar
167875eb523SNavdeep Kumar // Parse '>'.
168875eb523SNavdeep Kumar if (parser.parseGreater())
169875eb523SNavdeep Kumar return nullptr;
170875eb523SNavdeep Kumar
171875eb523SNavdeep Kumar return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
172875eb523SNavdeep Kumar parser.getEncodedSourceLoc(beginLoc)),
173875eb523SNavdeep Kumar shape, elementType, operand);
174875eb523SNavdeep Kumar }
175875eb523SNavdeep Kumar
176473b364aSChristian Sigg parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
177473b364aSChristian Sigg return Type();
178473b364aSChristian Sigg }
179473b364aSChristian Sigg
printType(Type type,DialectAsmPrinter & os) const180473b364aSChristian Sigg void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
181473b364aSChristian Sigg TypeSwitch<Type>(type)
182473b364aSChristian Sigg .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
183875eb523SNavdeep Kumar .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
184875eb523SNavdeep Kumar os << "mma_matrix<";
185875eb523SNavdeep Kumar auto shape = fragTy.getShape();
186875eb523SNavdeep Kumar for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
187875eb523SNavdeep Kumar os << *dim << 'x';
188875eb523SNavdeep Kumar os << shape.back() << 'x' << fragTy.getElementType();
189875eb523SNavdeep Kumar os << ", \"" << fragTy.getOperand() << "\"" << '>';
190875eb523SNavdeep Kumar })
191473b364aSChristian Sigg .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
192473b364aSChristian Sigg }
193473b364aSChristian Sigg
verifyOperationAttribute(Operation * op,NamedAttribute attr)19490d65d32SAlex Zinenko LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
19590d65d32SAlex Zinenko NamedAttribute attr) {
1960c7890c8SRiver Riddle if (!attr.getValue().isa<UnitAttr>() ||
1970c7890c8SRiver Riddle attr.getName() != getContainerModuleAttrName())
19890d65d32SAlex Zinenko return success();
19990d65d32SAlex Zinenko
20090d65d32SAlex Zinenko auto module = dyn_cast<ModuleOp>(op);
20190d65d32SAlex Zinenko if (!module)
20290d65d32SAlex Zinenko return op->emitError("expected '")
20390d65d32SAlex Zinenko << getContainerModuleAttrName() << "' attribute to be attached to '"
20490d65d32SAlex Zinenko << ModuleOp::getOperationName() << '\'';
20590d65d32SAlex Zinenko
20690d65d32SAlex Zinenko auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
20790d65d32SAlex Zinenko // Ignore launches that are nested more or less deep than functions in the
20890d65d32SAlex Zinenko // module we are currently checking.
2090bf4a82aSChristian Sigg if (!launchOp->getParentOp() ||
2100bf4a82aSChristian Sigg launchOp->getParentOp()->getParentOp() != module)
21190d65d32SAlex Zinenko return success();
21290d65d32SAlex Zinenko
21390d65d32SAlex Zinenko // Ignore launch ops with missing attributes here. The errors will be
21490d65d32SAlex Zinenko // reported by the verifiers of those ops.
2150bf4a82aSChristian Sigg if (!launchOp->getAttrOfType<SymbolRefAttr>(
2160372db05SFrederik Gossen LaunchFuncOp::getKernelAttrName()))
21790d65d32SAlex Zinenko return success();
21890d65d32SAlex Zinenko
21990d65d32SAlex Zinenko // Check that `launch_func` refers to a well-formed GPU kernel module.
22041d4aa7dSChris Lattner StringAttr kernelModuleName = launchOp.getKernelModuleName();
2219a52ea5cSTres Popp auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
22290d65d32SAlex Zinenko if (!kernelModule)
22390d65d32SAlex Zinenko return launchOp.emitOpError()
22441d4aa7dSChris Lattner << "kernel module '" << kernelModuleName.getValue()
22541d4aa7dSChris Lattner << "' is undefined";
22690d65d32SAlex Zinenko
22790d65d32SAlex Zinenko // Check that `launch_func` refers to a well-formed kernel function.
22841d4aa7dSChris Lattner Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
229b88a4d72SRiver Riddle if (!kernelFunc)
23090d65d32SAlex Zinenko return launchOp.emitOpError("kernel function '")
2310372db05SFrederik Gossen << launchOp.kernel() << "' is undefined";
232b88a4d72SRiver Riddle auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
233b88a4d72SRiver Riddle if (!kernelConvertedFunction) {
234b88a4d72SRiver Riddle InFlightDiagnostic diag = launchOp.emitOpError()
235b88a4d72SRiver Riddle << "referenced kernel '" << launchOp.kernel()
236b88a4d72SRiver Riddle << "' is not a function";
237b88a4d72SRiver Riddle diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
238b88a4d72SRiver Riddle return diag;
239b88a4d72SRiver Riddle }
240b88a4d72SRiver Riddle
2415e7959a3SAlex Zinenko if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
24290d65d32SAlex Zinenko GPUDialect::getKernelFuncAttrName()))
24390d65d32SAlex Zinenko return launchOp.emitOpError("kernel function is missing the '")
24490d65d32SAlex Zinenko << GPUDialect::getKernelFuncAttrName() << "' attribute";
2455e7959a3SAlex Zinenko
246b88a4d72SRiver Riddle // TODO: If the kernel isn't a GPU function (which happens during separate
247b88a4d72SRiver Riddle // compilation), do not check type correspondence as it would require the
248b88a4d72SRiver Riddle // verifier to be aware of the type conversion.
249b88a4d72SRiver Riddle auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
250b88a4d72SRiver Riddle if (!kernelGPUFunction)
2515a177805SAlex Zinenko return success();
2525a177805SAlex Zinenko
2535e7959a3SAlex Zinenko unsigned actualNumArguments = launchOp.getNumKernelOperands();
2545a177805SAlex Zinenko unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
2555e7959a3SAlex Zinenko if (expectedNumArguments != actualNumArguments)
2565e7959a3SAlex Zinenko return launchOp.emitOpError("got ")
2575e7959a3SAlex Zinenko << actualNumArguments << " kernel operands but expected "
2585e7959a3SAlex Zinenko << expectedNumArguments;
25990d65d32SAlex Zinenko
2604a3460a7SRiver Riddle auto functionType = kernelGPUFunction.getFunctionType();
2615a177805SAlex Zinenko for (unsigned i = 0; i < expectedNumArguments; ++i) {
2625a177805SAlex Zinenko if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
2635a177805SAlex Zinenko return launchOp.emitOpError("type of function argument ")
2645a177805SAlex Zinenko << i << " does not match";
2655a177805SAlex Zinenko }
2665a177805SAlex Zinenko }
26790d65d32SAlex Zinenko
26890d65d32SAlex Zinenko return success();
26990d65d32SAlex Zinenko });
27090d65d32SAlex Zinenko
27190d65d32SAlex Zinenko return walkResult.wasInterrupted() ? failure() : success();
27290d65d32SAlex Zinenko }
27390d65d32SAlex Zinenko
274f47a38f5SUday Bondhugula /// Parses an optional list of async operands with an optional leading keyword.
275f47a38f5SUday Bondhugula /// (`async`)? (`[` ssa-id-list `]`)?
276f47a38f5SUday Bondhugula ///
277f47a38f5SUday Bondhugula /// This method is used by the tablegen assembly format for async ops as well.
parseAsyncDependencies(OpAsmParser & parser,Type & asyncTokenType,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & asyncDependencies)278f47a38f5SUday Bondhugula static ParseResult parseAsyncDependencies(
279f47a38f5SUday Bondhugula OpAsmParser &parser, Type &asyncTokenType,
280f47a38f5SUday Bondhugula SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
281f47a38f5SUday Bondhugula auto loc = parser.getCurrentLocation();
282f47a38f5SUday Bondhugula if (succeeded(parser.parseOptionalKeyword("async"))) {
283f47a38f5SUday Bondhugula if (parser.getNumResults() == 0)
284f47a38f5SUday Bondhugula return parser.emitError(loc, "needs to be named when marked 'async'");
285f47a38f5SUday Bondhugula asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
286f47a38f5SUday Bondhugula }
287f47a38f5SUday Bondhugula return parser.parseOperandList(asyncDependencies,
288f47a38f5SUday Bondhugula OpAsmParser::Delimiter::OptionalSquare);
289f47a38f5SUday Bondhugula }
290f47a38f5SUday Bondhugula
291f47a38f5SUday Bondhugula /// Prints optional async dependencies with its leading keyword.
292f47a38f5SUday Bondhugula /// (`async`)? (`[` ssa-id-list `]`)?
293f47a38f5SUday Bondhugula // Used by the tablegen assembly format for several async ops.
printAsyncDependencies(OpAsmPrinter & printer,Operation * op,Type asyncTokenType,OperandRange asyncDependencies)294f47a38f5SUday Bondhugula static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
295f47a38f5SUday Bondhugula Type asyncTokenType,
296f47a38f5SUday Bondhugula OperandRange asyncDependencies) {
297f47a38f5SUday Bondhugula if (asyncTokenType)
298f47a38f5SUday Bondhugula printer << "async";
299f47a38f5SUday Bondhugula if (asyncDependencies.empty())
300f47a38f5SUday Bondhugula return;
301f47a38f5SUday Bondhugula if (asyncTokenType)
302f47a38f5SUday Bondhugula printer << ' ';
303f47a38f5SUday Bondhugula printer << '[';
304f47a38f5SUday Bondhugula llvm::interleaveComma(asyncDependencies, printer);
305f47a38f5SUday Bondhugula printer << ']';
306f47a38f5SUday Bondhugula }
307f47a38f5SUday Bondhugula
308f47a38f5SUday Bondhugula //===----------------------------------------------------------------------===//
309f47a38f5SUday Bondhugula // AllReduceOp
310f47a38f5SUday Bondhugula //===----------------------------------------------------------------------===//
311f47a38f5SUday Bondhugula
verifyRegions()312ed645f63SChia-hung Duan LogicalResult gpu::AllReduceOp::verifyRegions() {
313037f0995SKazu Hirata if (body().empty() != op().has_value())
314094ede6dSRiver Riddle return emitError("expected either an op attribute or a non-empty body");
315094ede6dSRiver Riddle if (!body().empty()) {
316094ede6dSRiver Riddle if (body().getNumArguments() != 2)
317094ede6dSRiver Riddle return emitError("expected two region arguments");
318094ede6dSRiver Riddle for (auto argument : body().getArguments()) {
319094ede6dSRiver Riddle if (argument.getType() != getType())
320094ede6dSRiver Riddle return emitError("incorrect region argument type");
321d2f0f847SChristian Sigg }
322d2f0f847SChristian Sigg unsigned yieldCount = 0;
323094ede6dSRiver Riddle for (Block &block : body()) {
324b74af4aaSChristian Sigg if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
325d2f0f847SChristian Sigg if (yield.getNumOperands() != 1)
326094ede6dSRiver Riddle return emitError("expected one gpu.yield operand");
327094ede6dSRiver Riddle if (yield.getOperand(0).getType() != getType())
328094ede6dSRiver Riddle return emitError("incorrect gpu.yield type");
329d2f0f847SChristian Sigg ++yieldCount;
330d2f0f847SChristian Sigg }
331d2f0f847SChristian Sigg }
332d2f0f847SChristian Sigg if (yieldCount == 0)
333094ede6dSRiver Riddle return emitError("expected gpu.yield op in region");
334c7380995SValentin Clement } else {
335094ede6dSRiver Riddle gpu::AllReduceOperation opName = *op();
336aae51255SMogball if ((opName == gpu::AllReduceOperation::AND ||
337aae51255SMogball opName == gpu::AllReduceOperation::OR ||
338aae51255SMogball opName == gpu::AllReduceOperation::XOR) &&
339094ede6dSRiver Riddle !getType().isa<IntegerType>()) {
340094ede6dSRiver Riddle return emitError()
341094ede6dSRiver Riddle << '`' << gpu::stringifyAllReduceOperation(opName)
342094ede6dSRiver Riddle << "` accumulator is only compatible with Integer type";
343c7380995SValentin Clement }
344d2f0f847SChristian Sigg }
345d2f0f847SChristian Sigg return success();
346d2f0f847SChristian Sigg }
347d2f0f847SChristian Sigg
348aae51255SMogball // TODO: Support optional custom attributes (without dialect prefix).
parseAllReduceOperation(AsmParser & parser,AllReduceOperationAttr & attr)349aae51255SMogball static ParseResult parseAllReduceOperation(AsmParser &parser,
350aae51255SMogball AllReduceOperationAttr &attr) {
351aae51255SMogball StringRef enumStr;
352aae51255SMogball if (!parser.parseOptionalKeyword(&enumStr)) {
353aae51255SMogball Optional<AllReduceOperation> op = gpu::symbolizeAllReduceOperation(enumStr);
354aae51255SMogball if (!op)
355aae51255SMogball return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
356aae51255SMogball attr = AllReduceOperationAttr::get(parser.getContext(), *op);
35742d46b4eSChristian Sigg }
35842d46b4eSChristian Sigg return success();
35942d46b4eSChristian Sigg }
36042d46b4eSChristian Sigg
printAllReduceOperation(AsmPrinter & printer,Operation * op,AllReduceOperationAttr attr)361aae51255SMogball static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
362aae51255SMogball AllReduceOperationAttr attr) {
363aae51255SMogball if (attr)
364aae51255SMogball attr.print(printer);
36542d46b4eSChristian Sigg }
36642d46b4eSChristian Sigg
36760965b46SAlex Zinenko //===----------------------------------------------------------------------===//
368473b364aSChristian Sigg // AsyncOpInterface
369473b364aSChristian Sigg //===----------------------------------------------------------------------===//
370473b364aSChristian Sigg
addAsyncDependency(Operation * op,Value token)371473b364aSChristian Sigg void gpu::addAsyncDependency(Operation *op, Value token) {
372473b364aSChristian Sigg op->insertOperands(0, {token});
373473b364aSChristian Sigg if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
374473b364aSChristian Sigg return;
375473b364aSChristian Sigg auto attrName =
376473b364aSChristian Sigg OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
377473b364aSChristian Sigg auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
3780cb5d7fcSRiver Riddle
3790cb5d7fcSRiver Riddle // Async dependencies is the only variadic operand.
380473b364aSChristian Sigg if (!sizeAttr)
3810cb5d7fcSRiver Riddle return;
3820cb5d7fcSRiver Riddle
3830cb5d7fcSRiver Riddle SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
384473b364aSChristian Sigg ++sizes.front();
385473b364aSChristian Sigg op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
386473b364aSChristian Sigg }
387473b364aSChristian Sigg
388473b364aSChristian Sigg //===----------------------------------------------------------------------===//
38960965b46SAlex Zinenko // LaunchOp
39060965b46SAlex Zinenko //===----------------------------------------------------------------------===//
39160965b46SAlex Zinenko
build(OpBuilder & builder,OperationState & result,Value gridSizeX,Value gridSizeY,Value gridSizeZ,Value blockSizeX,Value blockSizeY,Value blockSizeZ,Value dynamicSharedMemorySize,Type asyncTokenType,ValueRange asyncDependencies)392bb1d976fSAlex Zinenko void LaunchOp::build(OpBuilder &builder, OperationState &result,
393bb1d976fSAlex Zinenko Value gridSizeX, Value gridSizeY, Value gridSizeZ,
39408b63db8SUday Bondhugula Value blockSizeX, Value blockSizeY, Value blockSizeZ,
395f47a38f5SUday Bondhugula Value dynamicSharedMemorySize, Type asyncTokenType,
396f47a38f5SUday Bondhugula ValueRange asyncDependencies) {
397f47a38f5SUday Bondhugula result.addOperands(asyncDependencies);
398f47a38f5SUday Bondhugula if (asyncTokenType)
399f47a38f5SUday Bondhugula result.types.push_back(builder.getType<AsyncTokenType>());
400f47a38f5SUday Bondhugula
40160965b46SAlex Zinenko // Add grid and block sizes as op operands, followed by the data operands.
402729727ebSRiver Riddle result.addOperands(
40360965b46SAlex Zinenko {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
40408b63db8SUday Bondhugula if (dynamicSharedMemorySize)
40508b63db8SUday Bondhugula result.addOperands(dynamicSharedMemorySize);
40660965b46SAlex Zinenko
40760965b46SAlex Zinenko // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
40860965b46SAlex Zinenko // where the first kNumConfigRegionAttributes arguments have `index` type and
40960965b46SAlex Zinenko // the rest have the same types as the data operands.
410729727ebSRiver Riddle Region *kernelRegion = result.addRegion();
41160965b46SAlex Zinenko Block *body = new Block();
412e084679fSRiver Riddle for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
413e084679fSRiver Riddle body->addArgument(builder.getIndexType(), result.location);
41460965b46SAlex Zinenko kernelRegion->push_back(body);
415f47a38f5SUday Bondhugula SmallVector<int32_t, 8> segmentSizes(8, 1);
416f47a38f5SUday Bondhugula segmentSizes.front() = asyncDependencies.size();
417f47a38f5SUday Bondhugula segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
418f47a38f5SUday Bondhugula result.addAttribute(getOperandSegmentSizeAttr(),
419f47a38f5SUday Bondhugula builder.getI32VectorAttr(segmentSizes));
42060965b46SAlex Zinenko }
42160965b46SAlex Zinenko
getBlockIds()42260965b46SAlex Zinenko KernelDim3 LaunchOp::getBlockIds() {
4232eaadfc4SRahul Joshi assert(!body().empty() && "LaunchOp body must not be empty.");
424e2b71610SRahul Joshi auto args = body().getArguments();
42560965b46SAlex Zinenko return KernelDim3{args[0], args[1], args[2]};
42660965b46SAlex Zinenko }
42760965b46SAlex Zinenko
getThreadIds()42860965b46SAlex Zinenko KernelDim3 LaunchOp::getThreadIds() {
4292eaadfc4SRahul Joshi assert(!body().empty() && "LaunchOp body must not be empty.");
430e2b71610SRahul Joshi auto args = body().getArguments();
43160965b46SAlex Zinenko return KernelDim3{args[3], args[4], args[5]};
43260965b46SAlex Zinenko }
43360965b46SAlex Zinenko
getGridSize()43460965b46SAlex Zinenko KernelDim3 LaunchOp::getGridSize() {
4352eaadfc4SRahul Joshi assert(!body().empty() && "LaunchOp body must not be empty.");
436e2b71610SRahul Joshi auto args = body().getArguments();
43760965b46SAlex Zinenko return KernelDim3{args[6], args[7], args[8]};
43860965b46SAlex Zinenko }
43960965b46SAlex Zinenko
getBlockSize()44060965b46SAlex Zinenko KernelDim3 LaunchOp::getBlockSize() {
4412eaadfc4SRahul Joshi assert(!body().empty() && "LaunchOp body must not be empty.");
442e2b71610SRahul Joshi auto args = body().getArguments();
44360965b46SAlex Zinenko return KernelDim3{args[9], args[10], args[11]};
44460965b46SAlex Zinenko }
44560965b46SAlex Zinenko
getGridSizeOperandValues()44660965b46SAlex Zinenko KernelDim3 LaunchOp::getGridSizeOperandValues() {
447f47a38f5SUday Bondhugula auto operands = getOperands().drop_front(asyncDependencies().size());
448f47a38f5SUday Bondhugula return KernelDim3{operands[0], operands[1], operands[2]};
44960965b46SAlex Zinenko }
45060965b46SAlex Zinenko
getBlockSizeOperandValues()45160965b46SAlex Zinenko KernelDim3 LaunchOp::getBlockSizeOperandValues() {
452f47a38f5SUday Bondhugula auto operands = getOperands().drop_front(asyncDependencies().size());
453f47a38f5SUday Bondhugula return KernelDim3{operands[3], operands[4], operands[5]};
45460965b46SAlex Zinenko }
45560965b46SAlex Zinenko
verifyRegions()456ed645f63SChia-hung Duan LogicalResult LaunchOp::verifyRegions() {
45760965b46SAlex Zinenko // Kernel launch takes kNumConfigOperands leading operands for grid/block
45860965b46SAlex Zinenko // sizes and transforms them into kNumConfigRegionAttributes region arguments
45960965b46SAlex Zinenko // for block/thread identifiers and grid/block sizes.
460094ede6dSRiver Riddle if (!body().empty()) {
461f47a38f5SUday Bondhugula if (body().getNumArguments() !=
462f47a38f5SUday Bondhugula LaunchOp::kNumConfigOperands + getNumOperands() -
463f47a38f5SUday Bondhugula (dynamicSharedMemorySize() ? 1 : 0) - asyncDependencies().size())
464094ede6dSRiver Riddle return emitOpError("unexpected number of region arguments");
46560965b46SAlex Zinenko }
46660965b46SAlex Zinenko
46760965b46SAlex Zinenko // Block terminators without successors are expected to exit the kernel region
46826927518SStephan Herhut // and must be `gpu.terminator`.
469094ede6dSRiver Riddle for (Block &block : body()) {
47060965b46SAlex Zinenko if (block.empty())
47160965b46SAlex Zinenko continue;
47260965b46SAlex Zinenko if (block.back().getNumSuccessors() != 0)
47360965b46SAlex Zinenko continue;
47426927518SStephan Herhut if (!isa<gpu::TerminatorOp>(&block.back())) {
47560965b46SAlex Zinenko return block.back()
47626927518SStephan Herhut .emitError()
47726927518SStephan Herhut .append("expected '", gpu::TerminatorOp::getOperationName(),
47826927518SStephan Herhut "' or a terminator with successors")
479094ede6dSRiver Riddle .attachNote(getLoc())
48026927518SStephan Herhut .append("in '", LaunchOp::getOperationName(), "' body region");
48160965b46SAlex Zinenko }
48260965b46SAlex Zinenko }
48360965b46SAlex Zinenko
484f47a38f5SUday Bondhugula if (getNumResults() == 0 && asyncToken())
485f47a38f5SUday Bondhugula return emitOpError("needs to be named when async keyword is specified");
486f47a38f5SUday Bondhugula
48760965b46SAlex Zinenko return success();
48860965b46SAlex Zinenko }
48960965b46SAlex Zinenko
49060965b46SAlex Zinenko // Pretty-print the kernel grid/block size assignment as
49160965b46SAlex Zinenko // (%iter-x, %iter-y, %iter-z) in
49260965b46SAlex Zinenko // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
49360965b46SAlex Zinenko // where %size-* and %iter-* will correspond to the body region arguments.
printSizeAssignment(OpAsmPrinter & p,KernelDim3 size,KernelDim3 operands,KernelDim3 ids)4943a643de9SRiver Riddle static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
495701fbe87SChristian Sigg KernelDim3 operands, KernelDim3 ids) {
4962bdf33ccSRiver Riddle p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
497701fbe87SChristian Sigg p << size.x << " = " << operands.x << ", ";
498701fbe87SChristian Sigg p << size.y << " = " << operands.y << ", ";
499701fbe87SChristian Sigg p << size.z << " = " << operands.z << ')';
50060965b46SAlex Zinenko }
50160965b46SAlex Zinenko
print(OpAsmPrinter & p)5022418cd92SRiver Riddle void LaunchOp::print(OpAsmPrinter &p) {
503f47a38f5SUday Bondhugula if (asyncToken()) {
504f47a38f5SUday Bondhugula p << " async";
505f47a38f5SUday Bondhugula if (!asyncDependencies().empty())
506f47a38f5SUday Bondhugula p << " [" << asyncDependencies() << ']';
507f47a38f5SUday Bondhugula }
50860965b46SAlex Zinenko // Print the launch configuration.
5092418cd92SRiver Riddle p << ' ' << getBlocksKeyword();
5102418cd92SRiver Riddle printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
5112418cd92SRiver Riddle getBlockIds());
5122418cd92SRiver Riddle p << ' ' << getThreadsKeyword();
5132418cd92SRiver Riddle printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
5142418cd92SRiver Riddle getThreadIds());
5152418cd92SRiver Riddle if (dynamicSharedMemorySize())
5162418cd92SRiver Riddle p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
5172418cd92SRiver Riddle << dynamicSharedMemorySize();
51860965b46SAlex Zinenko
5195c36ee8dSMogball p << ' ';
5202418cd92SRiver Riddle p.printRegion(body(), /*printEntryBlockArgs=*/false);
521f47a38f5SUday Bondhugula p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
522f47a38f5SUday Bondhugula LaunchOp::getOperandSegmentSizeAttr()});
52360965b46SAlex Zinenko }
52460965b46SAlex Zinenko
52560965b46SAlex Zinenko // Parse the size assignment blocks for blocks and threads. These have the form
52660965b46SAlex Zinenko // (%region_arg, %region_arg, %region_arg) in
52760965b46SAlex Zinenko // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
52860965b46SAlex Zinenko // where %region_arg are percent-identifiers for the region arguments to be
52985dcaf19SChristian Sigg // introduced further (SSA defs), and %operand are percent-identifiers for the
53060965b46SAlex Zinenko // SSA value uses.
53160965b46SAlex Zinenko static ParseResult
parseSizeAssignment(OpAsmParser & parser,MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,MutableArrayRef<OpAsmParser::UnresolvedOperand> indices)5322797517eSRiver Riddle parseSizeAssignment(OpAsmParser &parser,
533e13d23bcSMarkus Böck MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
534e13d23bcSMarkus Böck MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
535e13d23bcSMarkus Böck MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
53660965b46SAlex Zinenko assert(indices.size() == 3 && "space for three indices expected");
537e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
5385dedf911SChris Lattner if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
5395dedf911SChris Lattner /*allowResultNumber=*/false) ||
5402797517eSRiver Riddle parser.parseKeyword("in") || parser.parseLParen())
54160965b46SAlex Zinenko return failure();
54260965b46SAlex Zinenko std::move(args.begin(), args.end(), indices.begin());
54360965b46SAlex Zinenko
54460965b46SAlex Zinenko for (int i = 0; i < 3; ++i) {
5452797517eSRiver Riddle if (i != 0 && parser.parseComma())
54660965b46SAlex Zinenko return failure();
5475dedf911SChris Lattner if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
5485dedf911SChris Lattner parser.parseEqual() || parser.parseOperand(sizes[i]))
54960965b46SAlex Zinenko return failure();
55060965b46SAlex Zinenko }
55160965b46SAlex Zinenko
5522797517eSRiver Riddle return parser.parseRParen();
55360965b46SAlex Zinenko }
55460965b46SAlex Zinenko
5552418cd92SRiver Riddle /// Parses a Launch operation.
556f47a38f5SUday Bondhugula /// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
557f47a38f5SUday Bondhugula // `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
558f47a38f5SUday Bondhugula /// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
5592418cd92SRiver Riddle /// region attr-dict?
5602418cd92SRiver Riddle /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
parse(OpAsmParser & parser,OperationState & result)5612418cd92SRiver Riddle ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
56260965b46SAlex Zinenko // Sizes of the grid and block.
563e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
564e13d23bcSMarkus Böck sizes(LaunchOp::kNumConfigOperands);
565e13d23bcSMarkus Böck MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
56660965b46SAlex Zinenko
56760965b46SAlex Zinenko // Actual (data) operands passed to the kernel.
568e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;
56960965b46SAlex Zinenko
57060965b46SAlex Zinenko // Region arguments to be created.
571e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
5723230267dSAlex Zinenko LaunchOp::kNumConfigRegionAttributes);
573e13d23bcSMarkus Böck MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
57460965b46SAlex Zinenko
575f47a38f5SUday Bondhugula // Parse optional async dependencies.
576f47a38f5SUday Bondhugula SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
577f47a38f5SUday Bondhugula Type asyncTokenType;
578f47a38f5SUday Bondhugula if (failed(
579f47a38f5SUday Bondhugula parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
580f47a38f5SUday Bondhugula parser.resolveOperands(asyncDependencies, asyncTokenType,
581f47a38f5SUday Bondhugula result.operands))
582f47a38f5SUday Bondhugula return failure();
583f47a38f5SUday Bondhugula if (parser.getNumResults() > 0)
584f47a38f5SUday Bondhugula result.types.push_back(asyncTokenType);
585f47a38f5SUday Bondhugula
58685dcaf19SChristian Sigg // Parse the size assignment segments: the first segment assigns grid sizes
58760965b46SAlex Zinenko // and defines values for block identifiers; the second segment assigns block
58885dcaf19SChristian Sigg // sizes and defines values for thread identifiers. In the region argument
58985dcaf19SChristian Sigg // list, identifiers precede sizes, and block-related values precede
59060965b46SAlex Zinenko // thread-related values.
5913230267dSAlex Zinenko if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
59260965b46SAlex Zinenko parseSizeAssignment(parser, sizesRef.take_front(3),
59360965b46SAlex Zinenko regionArgsRef.slice(6, 3),
59460965b46SAlex Zinenko regionArgsRef.slice(0, 3)) ||
5953230267dSAlex Zinenko parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
59660965b46SAlex Zinenko parseSizeAssignment(parser, sizesRef.drop_front(3),
59760965b46SAlex Zinenko regionArgsRef.slice(9, 3),
59860965b46SAlex Zinenko regionArgsRef.slice(3, 3)) ||
5992797517eSRiver Riddle parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
600729727ebSRiver Riddle result.operands))
60160965b46SAlex Zinenko return failure();
60260965b46SAlex Zinenko
603e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
604f47a38f5SUday Bondhugula bool hasDynamicSharedMemorySize = false;
60508b63db8SUday Bondhugula if (!parser.parseOptionalKeyword(
606f47a38f5SUday Bondhugula LaunchOp::getDynamicSharedMemorySizeKeyword())) {
607f47a38f5SUday Bondhugula hasDynamicSharedMemorySize = true;
60808b63db8SUday Bondhugula if (parser.parseOperand(dynamicSharedMemorySize) ||
60908b63db8SUday Bondhugula parser.resolveOperand(dynamicSharedMemorySize,
61008b63db8SUday Bondhugula parser.getBuilder().getI32Type(),
61108b63db8SUday Bondhugula result.operands))
61208b63db8SUday Bondhugula return failure();
613f47a38f5SUday Bondhugula }
61408b63db8SUday Bondhugula
61560965b46SAlex Zinenko // Introduce the body region and parse it. The region has
616283b5e73SStephan Herhut // kNumConfigRegionAttributes arguments that correspond to
61760965b46SAlex Zinenko // block/thread identifiers and grid/block sizes, all of the `index` type.
6182797517eSRiver Riddle Type index = parser.getBuilder().getIndexType();
619283b5e73SStephan Herhut SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
620283b5e73SStephan Herhut LaunchOp::kNumConfigRegionAttributes, index);
621d85eb4e2SChris Lattner
622d85eb4e2SChris Lattner SmallVector<OpAsmParser::Argument> regionArguments;
623d85eb4e2SChris Lattner for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
624d85eb4e2SChris Lattner OpAsmParser::Argument arg;
625d85eb4e2SChris Lattner arg.ssaName = std::get<0>(ssaValueAndType);
626d85eb4e2SChris Lattner arg.type = std::get<1>(ssaValueAndType);
627d85eb4e2SChris Lattner regionArguments.push_back(arg);
628d85eb4e2SChris Lattner }
629d85eb4e2SChris Lattner
630729727ebSRiver Riddle Region *body = result.addRegion();
631d85eb4e2SChris Lattner if (parser.parseRegion(*body, regionArguments) ||
632f47a38f5SUday Bondhugula parser.parseOptionalAttrDict(result.attributes))
633f47a38f5SUday Bondhugula return failure();
634f47a38f5SUday Bondhugula
635f47a38f5SUday Bondhugula SmallVector<int32_t, 8> segmentSizes(8, 1);
636f47a38f5SUday Bondhugula segmentSizes.front() = asyncDependencies.size();
637f47a38f5SUday Bondhugula segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
638f47a38f5SUday Bondhugula result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
639f47a38f5SUday Bondhugula parser.getBuilder().getI32VectorAttr(segmentSizes));
640f47a38f5SUday Bondhugula return success();
64160965b46SAlex Zinenko }
64260965b46SAlex Zinenko
6435c77ed03SUday Bondhugula /// Simplify the gpu.launch when the range of a thread or block ID is
64457eda9beSUday Bondhugula /// trivially known to be one.
64557eda9beSUday Bondhugula struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
64657eda9beSUday Bondhugula using OpRewritePattern<LaunchOp>::OpRewritePattern;
matchAndRewriteFoldLaunchArguments64757eda9beSUday Bondhugula LogicalResult matchAndRewrite(LaunchOp op,
64857eda9beSUday Bondhugula PatternRewriter &rewriter) const override {
64957eda9beSUday Bondhugula // If the range implies a single value for `id`, replace `id`'s uses by
65057eda9beSUday Bondhugula // zero.
65157eda9beSUday Bondhugula Value zero;
65257eda9beSUday Bondhugula bool simplified = false;
65357eda9beSUday Bondhugula auto constPropIdUses = [&](Value id, Value size) {
6545c77ed03SUday Bondhugula // Check if size is trivially one.
6555c77ed03SUday Bondhugula if (!matchPattern(size, m_One()))
65657eda9beSUday Bondhugula return;
65757eda9beSUday Bondhugula if (!simplified) {
65857eda9beSUday Bondhugula // Create a zero value the first time.
65957eda9beSUday Bondhugula OpBuilder::InsertionGuard guard(rewriter);
66057eda9beSUday Bondhugula rewriter.setInsertionPointToStart(&op.body().front());
661a54f4eaeSMogball zero =
662a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
66357eda9beSUday Bondhugula }
66457eda9beSUday Bondhugula id.replaceAllUsesWith(zero);
66557eda9beSUday Bondhugula simplified = true;
66657eda9beSUday Bondhugula };
66757eda9beSUday Bondhugula constPropIdUses(op.getBlockIds().x, op.gridSizeX());
66857eda9beSUday Bondhugula constPropIdUses(op.getBlockIds().y, op.gridSizeY());
66957eda9beSUday Bondhugula constPropIdUses(op.getBlockIds().z, op.gridSizeZ());
67057eda9beSUday Bondhugula constPropIdUses(op.getThreadIds().x, op.blockSizeX());
67157eda9beSUday Bondhugula constPropIdUses(op.getThreadIds().y, op.blockSizeY());
67257eda9beSUday Bondhugula constPropIdUses(op.getThreadIds().z, op.blockSizeZ());
67357eda9beSUday Bondhugula
67457eda9beSUday Bondhugula return success(simplified);
67557eda9beSUday Bondhugula }
67657eda9beSUday Bondhugula };
67757eda9beSUday Bondhugula
getCanonicalizationPatterns(RewritePatternSet & rewrites,MLIRContext * context)67857eda9beSUday Bondhugula void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
67957eda9beSUday Bondhugula MLIRContext *context) {
68057eda9beSUday Bondhugula rewrites.add<FoldLaunchArguments>(context);
68157eda9beSUday Bondhugula }
68257eda9beSUday Bondhugula
//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//
68660965b46SAlex Zinenko
build(OpBuilder & builder,OperationState & result,GPUFuncOp kernelFunc,KernelDim3 gridSize,KernelDim3 blockSize,Value dynamicSharedMemorySize,ValueRange kernelOperands,Type asyncTokenType,ValueRange asyncDependencies)687bb1d976fSAlex Zinenko void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
6881c1803dbSChristian Sigg GPUFuncOp kernelFunc, KernelDim3 gridSize,
68908b63db8SUday Bondhugula KernelDim3 blockSize, Value dynamicSharedMemorySize,
690f47a38f5SUday Bondhugula ValueRange kernelOperands, Type asyncTokenType,
691f47a38f5SUday Bondhugula ValueRange asyncDependencies) {
692f47a38f5SUday Bondhugula result.addOperands(asyncDependencies);
693f47a38f5SUday Bondhugula if (asyncTokenType)
694f47a38f5SUday Bondhugula result.types.push_back(builder.getType<AsyncTokenType>());
695f47a38f5SUday Bondhugula
69660965b46SAlex Zinenko // Add grid and block sizes as op operands, followed by the data operands.
6971c1803dbSChristian Sigg result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
6981c1803dbSChristian Sigg blockSize.y, blockSize.z});
69908b63db8SUday Bondhugula if (dynamicSharedMemorySize)
70008b63db8SUday Bondhugula result.addOperands(dynamicSharedMemorySize);
701729727ebSRiver Riddle result.addOperands(kernelOperands);
7020bf4a82aSChristian Sigg auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
703faf1c224SChris Lattner auto kernelSymbol =
704faf1c224SChris Lattner SymbolRefAttr::get(kernelModule.getNameAttr(),
705faf1c224SChris Lattner {SymbolRefAttr::get(kernelFunc.getNameAttr())});
7060372db05SFrederik Gossen result.addAttribute(getKernelAttrName(), kernelSymbol);
70708b63db8SUday Bondhugula SmallVector<int32_t, 9> segmentSizes(9, 1);
708f47a38f5SUday Bondhugula segmentSizes.front() = asyncDependencies.size();
70908b63db8SUday Bondhugula segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
71035561140SChristian Sigg segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
71135561140SChristian Sigg result.addAttribute(getOperandSegmentSizeAttr(),
71235561140SChristian Sigg builder.getI32VectorAttr(segmentSizes));
71360965b46SAlex Zinenko }
71460965b46SAlex Zinenko
getKernelModuleName()71541d4aa7dSChris Lattner StringAttr LaunchFuncOp::getKernelModuleName() {
7160372db05SFrederik Gossen return kernel().getRootReference();
71790d65d32SAlex Zinenko }
71890d65d32SAlex Zinenko
getKernelName()71941d4aa7dSChris Lattner StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
7200372db05SFrederik Gossen
getNumKernelOperands()7210e3d1ca5SChristian Sigg unsigned LaunchFuncOp::getNumKernelOperands() { return operands().size(); }
7220e3d1ca5SChristian Sigg
getKernelOperand(unsigned i)7230e3d1ca5SChristian Sigg Value LaunchFuncOp::getKernelOperand(unsigned i) { return operands()[i]; }
72460965b46SAlex Zinenko
getGridSizeOperandValues()72560965b46SAlex Zinenko KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
72635561140SChristian Sigg auto operands = getOperands().drop_front(asyncDependencies().size());
72735561140SChristian Sigg return KernelDim3{operands[0], operands[1], operands[2]};
72860965b46SAlex Zinenko }
72960965b46SAlex Zinenko
getBlockSizeOperandValues()73060965b46SAlex Zinenko KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
73135561140SChristian Sigg auto operands = getOperands().drop_front(asyncDependencies().size());
73235561140SChristian Sigg return KernelDim3{operands[3], operands[4], operands[5]};
73360965b46SAlex Zinenko }
73460965b46SAlex Zinenko
verify()735094ede6dSRiver Riddle LogicalResult LaunchFuncOp::verify() {
736094ede6dSRiver Riddle auto module = (*this)->getParentOfType<ModuleOp>();
73790d65d32SAlex Zinenko if (!module)
738094ede6dSRiver Riddle return emitOpError("expected to belong to a module");
73960965b46SAlex Zinenko
7400bf4a82aSChristian Sigg if (!module->getAttrOfType<UnitAttr>(
7410bf4a82aSChristian Sigg GPUDialect::getContainerModuleAttrName()))
742094ede6dSRiver Riddle return emitOpError("expected the closest surrounding module to have the '" +
743094ede6dSRiver Riddle GPUDialect::getContainerModuleAttrName() +
744094ede6dSRiver Riddle "' attribute");
74590d65d32SAlex Zinenko
746094ede6dSRiver Riddle auto kernelAttr = (*this)->getAttrOfType<SymbolRefAttr>(getKernelAttrName());
74790d65d32SAlex Zinenko if (!kernelAttr)
748094ede6dSRiver Riddle return emitOpError("symbol reference attribute '" + getKernelAttrName() +
749094ede6dSRiver Riddle "' must be specified");
75090d65d32SAlex Zinenko
75160965b46SAlex Zinenko return success();
75260965b46SAlex Zinenko }
753bf4692dcSAlex Zinenko
parseLaunchFuncOperands(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & argNames,SmallVectorImpl<Type> & argTypes)754e13d23bcSMarkus Böck static ParseResult parseLaunchFuncOperands(
755e13d23bcSMarkus Böck OpAsmParser &parser,
756e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
7571c1803dbSChristian Sigg SmallVectorImpl<Type> &argTypes) {
7581c1803dbSChristian Sigg if (parser.parseOptionalKeyword("args"))
7591c1803dbSChristian Sigg return success();
760d85eb4e2SChris Lattner
761d85eb4e2SChris Lattner SmallVector<OpAsmParser::Argument> args;
762d85eb4e2SChris Lattner if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
763d85eb4e2SChris Lattner /*allowType=*/true))
764d85eb4e2SChris Lattner return failure();
765d85eb4e2SChris Lattner for (auto &arg : args) {
766d85eb4e2SChris Lattner argNames.push_back(arg.ssaName);
767d85eb4e2SChris Lattner argTypes.push_back(arg.type);
768d85eb4e2SChris Lattner }
769d85eb4e2SChris Lattner return success();
7701c1803dbSChristian Sigg }
7711c1803dbSChristian Sigg
printLaunchFuncOperands(OpAsmPrinter & printer,Operation *,OperandRange operands,TypeRange types)772035e12e6SJohn Demme static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
7731c1803dbSChristian Sigg OperandRange operands, TypeRange types) {
7741c1803dbSChristian Sigg if (operands.empty())
7751c1803dbSChristian Sigg return;
7761c1803dbSChristian Sigg printer << "args(";
7771c1803dbSChristian Sigg llvm::interleaveComma(llvm::zip(operands, types), printer,
7781c1803dbSChristian Sigg [&](const auto &pair) {
7791c1803dbSChristian Sigg printer.printOperand(std::get<0>(pair));
7801c1803dbSChristian Sigg printer << " : ";
7811c1803dbSChristian Sigg printer.printType(std::get<1>(pair));
7821c1803dbSChristian Sigg });
7831c1803dbSChristian Sigg printer << ")";
7841c1803dbSChristian Sigg }
7851c1803dbSChristian Sigg
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
78969f76471SMatthias Springer
build(OpBuilder & builder,OperationState & result,Value value,int32_t offset,int32_t width,ShuffleMode mode)79069f76471SMatthias Springer void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
79169f76471SMatthias Springer int32_t offset, int32_t width, ShuffleMode mode) {
79269f76471SMatthias Springer build(builder, result, value,
79369f76471SMatthias Springer builder.create<arith::ConstantOp>(result.location,
79469f76471SMatthias Springer builder.getI32IntegerAttr(offset)),
79569f76471SMatthias Springer builder.create<arith::ConstantOp>(result.location,
79669f76471SMatthias Springer builder.getI32IntegerAttr(width)),
79769f76471SMatthias Springer mode);
79869f76471SMatthias Springer }
79969f76471SMatthias Springer
//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//
803bf4692dcSAlex Zinenko
804ad398164SWen-Heng (Jack) Chung /// Adds a new block argument that corresponds to buffers located in
805ad398164SWen-Heng (Jack) Chung /// workgroup memory.
addWorkgroupAttribution(Type type,Location loc)806e084679fSRiver Riddle BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
807ad398164SWen-Heng (Jack) Chung auto attrName = getNumWorkgroupAttributionsAttrName();
8080bf4a82aSChristian Sigg auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
8091ffc1aaaSChristian Sigg (*this)->setAttr(attrName,
8101ffc1aaaSChristian Sigg IntegerAttr::get(attr.getType(), attr.getValue() + 1));
8114a3460a7SRiver Riddle return getBody().insertArgument(
8124a3460a7SRiver Riddle getFunctionType().getNumInputs() + attr.getInt(), type, loc);
813ad398164SWen-Heng (Jack) Chung }
814ad398164SWen-Heng (Jack) Chung
815ad398164SWen-Heng (Jack) Chung /// Adds a new block argument that corresponds to buffers located in
816ad398164SWen-Heng (Jack) Chung /// private memory.
addPrivateAttribution(Type type,Location loc)817e084679fSRiver Riddle BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
818ad398164SWen-Heng (Jack) Chung // Buffers on the private memory always come after buffers on the workgroup
819ad398164SWen-Heng (Jack) Chung // memory.
820e084679fSRiver Riddle return getBody().addArgument(type, loc);
82108778d8cSAlex Zinenko }
82208778d8cSAlex Zinenko
build(OpBuilder & builder,OperationState & result,StringRef name,FunctionType type,TypeRange workgroupAttributions,TypeRange privateAttributions,ArrayRef<NamedAttribute> attrs)823bb1d976fSAlex Zinenko void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
824bb1d976fSAlex Zinenko StringRef name, FunctionType type,
82508e4f078SRahul Joshi TypeRange workgroupAttributions,
82608e4f078SRahul Joshi TypeRange privateAttributions,
827bf4692dcSAlex Zinenko ArrayRef<NamedAttribute> attrs) {
828bf4692dcSAlex Zinenko result.addAttribute(SymbolTable::getSymbolAttrName(),
829bb1d976fSAlex Zinenko builder.getStringAttr(name));
830bf4692dcSAlex Zinenko result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
831bf4692dcSAlex Zinenko result.addAttribute(getNumWorkgroupAttributionsAttrName(),
832bb1d976fSAlex Zinenko builder.getI64IntegerAttr(workgroupAttributions.size()));
833bf4692dcSAlex Zinenko result.addAttributes(attrs);
834bf4692dcSAlex Zinenko Region *body = result.addRegion();
835bf4692dcSAlex Zinenko Block *entryBlock = new Block;
836e084679fSRiver Riddle
837e084679fSRiver Riddle // TODO: Allow passing in proper locations here.
838e39dae85SRiver Riddle for (Type argTy : type.getInputs())
839e39dae85SRiver Riddle entryBlock->addArgument(argTy, result.location);
840e39dae85SRiver Riddle for (Type argTy : workgroupAttributions)
841e39dae85SRiver Riddle entryBlock->addArgument(argTy, result.location);
842e39dae85SRiver Riddle for (Type argTy : privateAttributions)
843e39dae85SRiver Riddle entryBlock->addArgument(argTy, result.location);
844bf4692dcSAlex Zinenko
845bf4692dcSAlex Zinenko body->getBlocks().push_back(entryBlock);
846bf4692dcSAlex Zinenko }
847bf4692dcSAlex Zinenko
848bf4692dcSAlex Zinenko /// Parses a GPU function memory attribution.
849bf4692dcSAlex Zinenko ///
850bf4692dcSAlex Zinenko /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
851bf4692dcSAlex Zinenko /// (`private` `(` ssa-id-and-type-list `)`)?
852bf4692dcSAlex Zinenko ///
853bf4692dcSAlex Zinenko /// Note that this function parses only one of the two similar parts, with the
854bf4692dcSAlex Zinenko /// keyword provided as argument.
855bf4692dcSAlex Zinenko static ParseResult
parseAttributions(OpAsmParser & parser,StringRef keyword,SmallVectorImpl<OpAsmParser::Argument> & args)856bf4692dcSAlex Zinenko parseAttributions(OpAsmParser &parser, StringRef keyword,
857d85eb4e2SChris Lattner SmallVectorImpl<OpAsmParser::Argument> &args) {
858bf4692dcSAlex Zinenko // If we could not parse the keyword, just assume empty list and succeed.
859bf4692dcSAlex Zinenko if (failed(parser.parseOptionalKeyword(keyword)))
860bf4692dcSAlex Zinenko return success();
861bf4692dcSAlex Zinenko
862d85eb4e2SChris Lattner return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
863d85eb4e2SChris Lattner /*allowType=*/true);
864bf4692dcSAlex Zinenko }
865bf4692dcSAlex Zinenko
866bf4692dcSAlex Zinenko /// Parses a GPU function.
867bf4692dcSAlex Zinenko ///
868bf4692dcSAlex Zinenko /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
869bf4692dcSAlex Zinenko /// (`->` function-result-list)? memory-attribution `kernel`?
870bf4692dcSAlex Zinenko /// function-attributes? region
parse(OpAsmParser & parser,OperationState & result)8712418cd92SRiver Riddle ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
872d85eb4e2SChris Lattner SmallVector<OpAsmParser::Argument> entryArgs;
873d85eb4e2SChris Lattner SmallVector<DictionaryAttr> resultAttrs;
8741e09f0a9SDominik Grewe SmallVector<Type> resultTypes;
875bf4692dcSAlex Zinenko bool isVariadic;
876bf4692dcSAlex Zinenko
877bf4692dcSAlex Zinenko // Parse the function name.
878bf4692dcSAlex Zinenko StringAttr nameAttr;
879bf4692dcSAlex Zinenko if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
880bf4692dcSAlex Zinenko result.attributes))
881bf4692dcSAlex Zinenko return failure();
882bf4692dcSAlex Zinenko
883bf4692dcSAlex Zinenko auto signatureLocation = parser.getCurrentLocation();
8847ceffae1SRiver Riddle if (failed(function_interface_impl::parseFunctionSignature(
885d85eb4e2SChris Lattner parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
886d85eb4e2SChris Lattner resultAttrs)))
887bf4692dcSAlex Zinenko return failure();
888bf4692dcSAlex Zinenko
889d85eb4e2SChris Lattner if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
890bf4692dcSAlex Zinenko return parser.emitError(signatureLocation)
891bf4692dcSAlex Zinenko << "gpu.func requires named arguments";
892bf4692dcSAlex Zinenko
893bf4692dcSAlex Zinenko // Construct the function type. More types will be added to the region, but
89426927518SStephan Herhut // not to the function type.
895bf4692dcSAlex Zinenko Builder &builder = parser.getBuilder();
896d85eb4e2SChris Lattner
897d85eb4e2SChris Lattner SmallVector<Type> argTypes;
898d85eb4e2SChris Lattner for (auto &arg : entryArgs)
899d85eb4e2SChris Lattner argTypes.push_back(arg.type);
900bf4692dcSAlex Zinenko auto type = builder.getFunctionType(argTypes, resultTypes);
901ccc767d6SAlex Zinenko result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
902bf4692dcSAlex Zinenko
903d85eb4e2SChris Lattner function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
904d85eb4e2SChris Lattner resultAttrs);
905d85eb4e2SChris Lattner
906bf4692dcSAlex Zinenko // Parse workgroup memory attributions.
907ccc767d6SAlex Zinenko if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
908d85eb4e2SChris Lattner entryArgs)))
909bf4692dcSAlex Zinenko return failure();
910bf4692dcSAlex Zinenko
911bf4692dcSAlex Zinenko // Store the number of operands we just parsed as the number of workgroup
912bf4692dcSAlex Zinenko // memory attributions.
913d85eb4e2SChris Lattner unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
914ccc767d6SAlex Zinenko result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
915bf4692dcSAlex Zinenko builder.getI64IntegerAttr(numWorkgroupAttrs));
916bf4692dcSAlex Zinenko
917bf4692dcSAlex Zinenko // Parse private memory attributions.
918d85eb4e2SChris Lattner if (failed(
919d85eb4e2SChris Lattner parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), entryArgs)))
920bf4692dcSAlex Zinenko return failure();
921bf4692dcSAlex Zinenko
922bf4692dcSAlex Zinenko // Parse the kernel attribute if present.
923ccc767d6SAlex Zinenko if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
924bf4692dcSAlex Zinenko result.addAttribute(GPUDialect::getKernelFuncAttrName(),
925bf4692dcSAlex Zinenko builder.getUnitAttr());
926bf4692dcSAlex Zinenko
927bf4692dcSAlex Zinenko // Parse attributes.
928bf4692dcSAlex Zinenko if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
929bf4692dcSAlex Zinenko return failure();
930bf4692dcSAlex Zinenko
931bf4692dcSAlex Zinenko // Parse the region. If no argument names were provided, take all names
932bf4692dcSAlex Zinenko // (including those of attributions) from the entry block.
933bf4692dcSAlex Zinenko auto *body = result.addRegion();
934d85eb4e2SChris Lattner return parser.parseRegion(*body, entryArgs);
935bf4692dcSAlex Zinenko }
936bf4692dcSAlex Zinenko
printAttributions(OpAsmPrinter & p,StringRef keyword,ArrayRef<BlockArgument> values)937bf4692dcSAlex Zinenko static void printAttributions(OpAsmPrinter &p, StringRef keyword,
938e62a6956SRiver Riddle ArrayRef<BlockArgument> values) {
939bf4692dcSAlex Zinenko if (values.empty())
940bf4692dcSAlex Zinenko return;
941bf4692dcSAlex Zinenko
942bf4692dcSAlex Zinenko p << ' ' << keyword << '(';
9432f21a579SRiver Riddle llvm::interleaveComma(
9442f21a579SRiver Riddle values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
945bf4692dcSAlex Zinenko p << ')';
946bf4692dcSAlex Zinenko }
947bf4692dcSAlex Zinenko
print(OpAsmPrinter & p)9482418cd92SRiver Riddle void GPUFuncOp::print(OpAsmPrinter &p) {
949c41b16c2SMehdi Amini p << ' ';
9502418cd92SRiver Riddle p.printSymbolName(getName());
951bf4692dcSAlex Zinenko
9524a3460a7SRiver Riddle FunctionType type = getFunctionType();
9532418cd92SRiver Riddle function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
9542418cd92SRiver Riddle /*isVariadic=*/false,
9552418cd92SRiver Riddle type.getResults());
956bf4692dcSAlex Zinenko
9572418cd92SRiver Riddle printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
9582418cd92SRiver Riddle printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
9592418cd92SRiver Riddle if (isKernel())
9602418cd92SRiver Riddle p << ' ' << getKernelKeyword();
961bf4692dcSAlex Zinenko
9627ceffae1SRiver Riddle function_interface_impl::printFunctionAttributes(
9632418cd92SRiver Riddle p, *this, type.getNumInputs(), type.getNumResults(),
9642418cd92SRiver Riddle {getNumWorkgroupAttributionsAttrName(),
965bf4692dcSAlex Zinenko GPUDialect::getKernelFuncAttrName()});
9665c36ee8dSMogball p << ' ';
9672418cd92SRiver Riddle p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
968bf4692dcSAlex Zinenko }
969bf4692dcSAlex Zinenko
verifyType()970bf4692dcSAlex Zinenko LogicalResult GPUFuncOp::verifyType() {
9714a3460a7SRiver Riddle Type type = getFunctionTypeAttr().getValue();
972bf4692dcSAlex Zinenko if (!type.isa<FunctionType>())
973bf4692dcSAlex Zinenko return emitOpError("requires '" + getTypeAttrName() +
974bf4692dcSAlex Zinenko "' attribute of function type");
97526927518SStephan Herhut
9764a3460a7SRiver Riddle if (isKernel() && getFunctionType().getNumResults() != 0)
97726927518SStephan Herhut return emitOpError() << "expected void return type for kernel function";
97826927518SStephan Herhut
979bf4692dcSAlex Zinenko return success();
980bf4692dcSAlex Zinenko }
981bf4692dcSAlex Zinenko
verifyAttributions(Operation * op,ArrayRef<BlockArgument> attributions,unsigned memorySpace)98240ef46fbSAlex Zinenko static LogicalResult verifyAttributions(Operation *op,
983e62a6956SRiver Riddle ArrayRef<BlockArgument> attributions,
98440ef46fbSAlex Zinenko unsigned memorySpace) {
985e62a6956SRiver Riddle for (Value v : attributions) {
9862bdf33ccSRiver Riddle auto type = v.getType().dyn_cast<MemRefType>();
98740ef46fbSAlex Zinenko if (!type)
98840ef46fbSAlex Zinenko return op->emitOpError() << "expected memref type in attribution";
98940ef46fbSAlex Zinenko
99037eca08eSVladislav Vinogradov if (type.getMemorySpaceAsInt() != memorySpace) {
99140ef46fbSAlex Zinenko return op->emitOpError()
99240ef46fbSAlex Zinenko << "expected memory space " << memorySpace << " in attribution";
99340ef46fbSAlex Zinenko }
99440ef46fbSAlex Zinenko }
99540ef46fbSAlex Zinenko return success();
99640ef46fbSAlex Zinenko }
99740ef46fbSAlex Zinenko
998bf4692dcSAlex Zinenko /// Verifies the body of the function.
verifyBody()999bf4692dcSAlex Zinenko LogicalResult GPUFuncOp::verifyBody() {
1000bf4692dcSAlex Zinenko unsigned numFuncArguments = getNumArguments();
1001bf4692dcSAlex Zinenko unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1002bf4692dcSAlex Zinenko unsigned numBlockArguments = front().getNumArguments();
1003603b974cSWen-Heng (Jack) Chung if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1004bf4692dcSAlex Zinenko return emitOpError() << "expected at least "
1005603b974cSWen-Heng (Jack) Chung << numFuncArguments + numWorkgroupAttributions
1006bf4692dcSAlex Zinenko << " arguments to body region";
1007bf4692dcSAlex Zinenko
10084a3460a7SRiver Riddle ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1009bf4692dcSAlex Zinenko for (unsigned i = 0; i < numFuncArguments; ++i) {
10102bdf33ccSRiver Riddle Type blockArgType = front().getArgument(i).getType();
1011bf4692dcSAlex Zinenko if (funcArgTypes[i] != blockArgType)
1012bf4692dcSAlex Zinenko return emitOpError() << "expected body region argument #" << i
1013bf4692dcSAlex Zinenko << " to be of type " << funcArgTypes[i] << ", got "
1014bf4692dcSAlex Zinenko << blockArgType;
1015bf4692dcSAlex Zinenko }
1016bf4692dcSAlex Zinenko
101740ef46fbSAlex Zinenko if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
101840ef46fbSAlex Zinenko GPUDialect::getWorkgroupAddressSpace())) ||
101940ef46fbSAlex Zinenko failed(verifyAttributions(getOperation(), getPrivateAttributions(),
102040ef46fbSAlex Zinenko GPUDialect::getPrivateAddressSpace())))
102140ef46fbSAlex Zinenko return failure();
102240ef46fbSAlex Zinenko
1023bf4692dcSAlex Zinenko return success();
1024bf4692dcSAlex Zinenko }
1025ccc767d6SAlex Zinenko
10269a52ea5cSTres Popp //===----------------------------------------------------------------------===//
102726927518SStephan Herhut // ReturnOp
102826927518SStephan Herhut //===----------------------------------------------------------------------===//
102926927518SStephan Herhut
verify()1030094ede6dSRiver Riddle LogicalResult gpu::ReturnOp::verify() {
1031094ede6dSRiver Riddle GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
103226927518SStephan Herhut
10334a3460a7SRiver Riddle FunctionType funType = function.getFunctionType();
103426927518SStephan Herhut
1035094ede6dSRiver Riddle if (funType.getNumResults() != operands().size())
1036094ede6dSRiver Riddle return emitOpError()
103726927518SStephan Herhut .append("expected ", funType.getNumResults(), " result operands")
103826927518SStephan Herhut .attachNote(function.getLoc())
103926927518SStephan Herhut .append("return type declared here");
104026927518SStephan Herhut
1041e4853be2SMehdi Amini for (const auto &pair : llvm::enumerate(
10424a3460a7SRiver Riddle llvm::zip(function.getFunctionType().getResults(), operands()))) {
104326927518SStephan Herhut Type type;
104426927518SStephan Herhut Value operand;
104526927518SStephan Herhut std::tie(type, operand) = pair.value();
104626927518SStephan Herhut if (type != operand.getType())
1047094ede6dSRiver Riddle return emitOpError() << "unexpected type `" << operand.getType()
104826927518SStephan Herhut << "' for operand #" << pair.index();
104926927518SStephan Herhut }
105026927518SStephan Herhut return success();
105126927518SStephan Herhut }
105226927518SStephan Herhut
105326927518SStephan Herhut //===----------------------------------------------------------------------===//
10549a52ea5cSTres Popp // GPUModuleOp
10559a52ea5cSTres Popp //===----------------------------------------------------------------------===//
10569a52ea5cSTres Popp
build(OpBuilder & builder,OperationState & result,StringRef name)1057bb1d976fSAlex Zinenko void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
10589a52ea5cSTres Popp StringRef name) {
1059bb1d976fSAlex Zinenko ensureTerminator(*result.addRegion(), builder, result.location);
1060bb1d976fSAlex Zinenko result.attributes.push_back(builder.getNamedAttr(
1061bb1d976fSAlex Zinenko ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
10629a52ea5cSTres Popp }
10639a52ea5cSTres Popp
parse(OpAsmParser & parser,OperationState & result)10642418cd92SRiver Riddle ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
10659a52ea5cSTres Popp StringAttr nameAttr;
10662418cd92SRiver Riddle if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
1067d85eb4e2SChris Lattner result.attributes) ||
10689a52ea5cSTres Popp // If module attributes are present, parse them.
1069d85eb4e2SChris Lattner parser.parseOptionalAttrDictWithKeyword(result.attributes))
10709a52ea5cSTres Popp return failure();
10719a52ea5cSTres Popp
10729a52ea5cSTres Popp // Parse the module body.
10739a52ea5cSTres Popp auto *body = result.addRegion();
1074d85eb4e2SChris Lattner if (parser.parseRegion(*body, {}))
10759a52ea5cSTres Popp return failure();
10769a52ea5cSTres Popp
10779a52ea5cSTres Popp // Ensure that this module has a valid terminator.
10789a52ea5cSTres Popp GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
10799a52ea5cSTres Popp return success();
10809a52ea5cSTres Popp }
10819a52ea5cSTres Popp
print(OpAsmPrinter & p)10822418cd92SRiver Riddle void GPUModuleOp::print(OpAsmPrinter &p) {
1083c41b16c2SMehdi Amini p << ' ';
10842418cd92SRiver Riddle p.printSymbolName(getName());
10852418cd92SRiver Riddle p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
10862418cd92SRiver Riddle {mlir::SymbolTable::getSymbolAttrName()});
10875c36ee8dSMogball p << ' ';
10882418cd92SRiver Riddle p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
10899a52ea5cSTres Popp /*printBlockTerminators=*/false);
10909a52ea5cSTres Popp }
10919a52ea5cSTres Popp
10920955d8dfSChristian Sigg //===----------------------------------------------------------------------===//
10930955d8dfSChristian Sigg // GPUMemcpyOp
10940955d8dfSChristian Sigg //===----------------------------------------------------------------------===//
10950955d8dfSChristian Sigg
verify()1096094ede6dSRiver Riddle LogicalResult MemcpyOp::verify() {
1097094ede6dSRiver Riddle auto srcType = src().getType();
1098094ede6dSRiver Riddle auto dstType = dst().getType();
10990955d8dfSChristian Sigg
11000955d8dfSChristian Sigg if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1101094ede6dSRiver Riddle return emitOpError("arguments have incompatible element type");
11020955d8dfSChristian Sigg
11030955d8dfSChristian Sigg if (failed(verifyCompatibleShape(srcType, dstType)))
1104094ede6dSRiver Riddle return emitOpError("arguments have incompatible shape");
11050955d8dfSChristian Sigg
11060955d8dfSChristian Sigg return success();
11070955d8dfSChristian Sigg }
11080955d8dfSChristian Sigg
110916219f8cSArnab Dutta namespace {
111016219f8cSArnab Dutta
111116219f8cSArnab Dutta /// Erases a common case of copy ops where a destination value is used only by
111216219f8cSArnab Dutta /// the copy op, alloc and dealloc ops.
111316219f8cSArnab Dutta struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
111416219f8cSArnab Dutta using OpRewritePattern<MemcpyOp>::OpRewritePattern;
111516219f8cSArnab Dutta
matchAndRewrite__anonf220e5a50911::EraseTrivialCopyOp111616219f8cSArnab Dutta LogicalResult matchAndRewrite(MemcpyOp op,
111716219f8cSArnab Dutta PatternRewriter &rewriter) const override {
111816219f8cSArnab Dutta Value dest = op.dst();
111916219f8cSArnab Dutta Operation *destDefOp = dest.getDefiningOp();
112016219f8cSArnab Dutta // `dest` must be defined by an op having Allocate memory effect in order to
112116219f8cSArnab Dutta // perform the folding.
112216219f8cSArnab Dutta if (!destDefOp ||
112316219f8cSArnab Dutta !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
112416219f8cSArnab Dutta return failure();
112516219f8cSArnab Dutta // We can erase `op` iff `dest` has no other use apart from its
112616219f8cSArnab Dutta // use by `op` and dealloc ops.
112716219f8cSArnab Dutta if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
112816219f8cSArnab Dutta return user != op &&
112916219f8cSArnab Dutta !hasSingleEffect<MemoryEffects::Free>(user, dest);
113016219f8cSArnab Dutta }))
113116219f8cSArnab Dutta return failure();
113216219f8cSArnab Dutta // We can perform the folding if and only if op has a single async
113316219f8cSArnab Dutta // dependency and produces an async token as result, or if it does not have
113416219f8cSArnab Dutta // any async dependency and does not produce any async token result.
113516219f8cSArnab Dutta if (op.asyncDependencies().size() > 1 ||
113616219f8cSArnab Dutta ((op.asyncDependencies().empty() && op.asyncToken()) ||
113716219f8cSArnab Dutta (!op.asyncDependencies().empty() && !op.asyncToken())))
113816219f8cSArnab Dutta return failure();
113916219f8cSArnab Dutta rewriter.replaceOp(op, op.asyncDependencies());
114016219f8cSArnab Dutta return success();
114116219f8cSArnab Dutta }
114216219f8cSArnab Dutta };
114316219f8cSArnab Dutta
114416219f8cSArnab Dutta } // end anonymous namespace
114516219f8cSArnab Dutta
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)114616219f8cSArnab Dutta void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
114716219f8cSArnab Dutta MLIRContext *context) {
114816219f8cSArnab Dutta results.add<EraseTrivialCopyOp>(context);
114916219f8cSArnab Dutta }
115016219f8cSArnab Dutta
1151875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1152875eb523SNavdeep Kumar // GPU_SubgroupMmaLoadMatrixOp
1153875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1154875eb523SNavdeep Kumar
1155d77f4836SThomas Raoux /// Return true if the last dimension of the MemRefType has unit stride. Also
1156d77f4836SThomas Raoux /// return true for memrefs with no strides.
isLastMemrefDimUnitStride(MemRefType type)1157d77f4836SThomas Raoux static bool isLastMemrefDimUnitStride(MemRefType type) {
1158d77f4836SThomas Raoux int64_t offset;
1159d77f4836SThomas Raoux SmallVector<int64_t> strides;
1160d77f4836SThomas Raoux if (failed(getStridesAndOffset(type, strides, offset))) {
1161d77f4836SThomas Raoux return false;
1162d77f4836SThomas Raoux }
1163d77f4836SThomas Raoux return strides.back() == 1;
1164d77f4836SThomas Raoux }
1165d77f4836SThomas Raoux
verify()1166094ede6dSRiver Riddle LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1167094ede6dSRiver Riddle auto srcType = srcMemref().getType();
1168094ede6dSRiver Riddle auto resType = res().getType();
1169875eb523SNavdeep Kumar auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
1170875eb523SNavdeep Kumar auto operand = resMatrixType.getOperand();
1171875eb523SNavdeep Kumar auto srcMemrefType = srcType.cast<MemRefType>();
1172875eb523SNavdeep Kumar auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
1173875eb523SNavdeep Kumar
1174d77f4836SThomas Raoux if (!isLastMemrefDimUnitStride(srcMemrefType))
1175d77f4836SThomas Raoux return emitError(
1176d77f4836SThomas Raoux "expected source memref most minor dim must have unit stride");
1177875eb523SNavdeep Kumar
1178875eb523SNavdeep Kumar if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
1179875eb523SNavdeep Kumar srcMemSpace != kGlobalMemorySpace)
1180094ede6dSRiver Riddle return emitError(
1181875eb523SNavdeep Kumar "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
1182875eb523SNavdeep Kumar "kGlobalMemorySpace only allowed");
1183875eb523SNavdeep Kumar
1184875eb523SNavdeep Kumar if (!operand.equals("AOp") && !operand.equals("BOp") &&
1185875eb523SNavdeep Kumar !operand.equals("COp"))
1186094ede6dSRiver Riddle return emitError("only AOp, BOp and COp can be loaded");
1187875eb523SNavdeep Kumar
1188875eb523SNavdeep Kumar return success();
1189875eb523SNavdeep Kumar }
1190875eb523SNavdeep Kumar
1191875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1192875eb523SNavdeep Kumar // GPU_SubgroupMmaStoreMatrixOp
1193875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1194875eb523SNavdeep Kumar
verify()1195094ede6dSRiver Riddle LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1196094ede6dSRiver Riddle auto srcType = src().getType();
1197094ede6dSRiver Riddle auto dstType = dstMemref().getType();
1198875eb523SNavdeep Kumar auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
1199875eb523SNavdeep Kumar auto dstMemrefType = dstType.cast<MemRefType>();
1200875eb523SNavdeep Kumar auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
1201d77f4836SThomas Raoux
1202d77f4836SThomas Raoux if (!isLastMemrefDimUnitStride(dstMemrefType))
1203d77f4836SThomas Raoux return emitError(
1204d77f4836SThomas Raoux "expected destination memref most minor dim must have unit stride");
1205875eb523SNavdeep Kumar
1206875eb523SNavdeep Kumar if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
1207875eb523SNavdeep Kumar dstMemSpace != kGlobalMemorySpace)
1208094ede6dSRiver Riddle return emitError("destination memorySpace of kGenericMemorySpace, "
1209875eb523SNavdeep Kumar "kGlobalMemorySpace or kSharedMemorySpace only allowed");
1210875eb523SNavdeep Kumar
1211b44007beSthomasraoux if (!srcMatrixType.getOperand().equals("COp"))
1212094ede6dSRiver Riddle return emitError(
1213b44007beSthomasraoux "expected the operand matrix being stored to have 'COp' operand type");
1214875eb523SNavdeep Kumar
1215875eb523SNavdeep Kumar return success();
1216875eb523SNavdeep Kumar }
1217875eb523SNavdeep Kumar
1218875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1219875eb523SNavdeep Kumar // GPU_SubgroupMmaComputeOp
1220875eb523SNavdeep Kumar //===----------------------------------------------------------------------===//
1221875eb523SNavdeep Kumar
verify()1222094ede6dSRiver Riddle LogicalResult SubgroupMmaComputeOp::verify() {
1223875eb523SNavdeep Kumar enum OperandMap { A, B, C };
1224875eb523SNavdeep Kumar SmallVector<MMAMatrixType, 3> opTypes;
1225094ede6dSRiver Riddle opTypes.push_back(opA().getType().cast<MMAMatrixType>());
1226094ede6dSRiver Riddle opTypes.push_back(opB().getType().cast<MMAMatrixType>());
1227094ede6dSRiver Riddle opTypes.push_back(opC().getType().cast<MMAMatrixType>());
1228875eb523SNavdeep Kumar
1229875eb523SNavdeep Kumar if (!opTypes[A].getOperand().equals("AOp") ||
1230875eb523SNavdeep Kumar !opTypes[B].getOperand().equals("BOp") ||
1231875eb523SNavdeep Kumar !opTypes[C].getOperand().equals("COp"))
1232094ede6dSRiver Riddle return emitError("operands must be in the order AOp, BOp, COp");
1233875eb523SNavdeep Kumar
1234875eb523SNavdeep Kumar ArrayRef<int64_t> aShape, bShape, cShape;
1235875eb523SNavdeep Kumar aShape = opTypes[A].getShape();
1236875eb523SNavdeep Kumar bShape = opTypes[B].getShape();
1237875eb523SNavdeep Kumar cShape = opTypes[C].getShape();
1238875eb523SNavdeep Kumar
1239875eb523SNavdeep Kumar if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1240875eb523SNavdeep Kumar bShape[1] != cShape[1])
1241094ede6dSRiver Riddle return emitError("operand shapes do not satisfy matmul constraints");
1242875eb523SNavdeep Kumar
1243875eb523SNavdeep Kumar return success();
1244875eb523SNavdeep Kumar }
1245875eb523SNavdeep Kumar
124600b6463bSWilliam S. Moses /// This is a common class used for patterns of the form
124700b6463bSWilliam S. Moses /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
124800b6463bSWilliam S. Moses /// into the root operation directly.
foldMemRefCast(Operation * op)124900b6463bSWilliam S. Moses static LogicalResult foldMemRefCast(Operation *op) {
125000b6463bSWilliam S. Moses bool folded = false;
125100b6463bSWilliam S. Moses for (OpOperand &operand : op->getOpOperands()) {
125200b6463bSWilliam S. Moses auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
125300b6463bSWilliam S. Moses if (cast) {
125400b6463bSWilliam S. Moses operand.set(cast.getOperand());
125500b6463bSWilliam S. Moses folded = true;
125600b6463bSWilliam S. Moses }
125700b6463bSWilliam S. Moses }
125800b6463bSWilliam S. Moses return success(folded);
125900b6463bSWilliam S. Moses }
126000b6463bSWilliam S. Moses
fold(ArrayRef<Attribute> operands,SmallVectorImpl<::mlir::OpFoldResult> & results)126100b6463bSWilliam S. Moses LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
126200b6463bSWilliam S. Moses SmallVectorImpl<::mlir::OpFoldResult> &results) {
126300b6463bSWilliam S. Moses return foldMemRefCast(*this);
126400b6463bSWilliam S. Moses }
126500b6463bSWilliam S. Moses
fold(ArrayRef<Attribute> operands,SmallVectorImpl<::mlir::OpFoldResult> & results)1266361458b1SLoren Maggiore LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
1267361458b1SLoren Maggiore SmallVectorImpl<::mlir::OpFoldResult> &results) {
1268361458b1SLoren Maggiore return foldMemRefCast(*this);
1269361458b1SLoren Maggiore }
1270361458b1SLoren Maggiore
12710080d2aaSmarina kolpakova a.k.a. geexie //===----------------------------------------------------------------------===//
1272392d55c1SArnab Dutta // GPU_WaitOp
1273392d55c1SArnab Dutta //===----------------------------------------------------------------------===//
1274392d55c1SArnab Dutta
1275392d55c1SArnab Dutta namespace {
1276392d55c1SArnab Dutta
1277392d55c1SArnab Dutta /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
1278392d55c1SArnab Dutta /// %t = gpu.wait async [] // No async dependencies.
1279392d55c1SArnab Dutta /// ... gpu.wait ... [%t, ...] // %t can be removed.
1280392d55c1SArnab Dutta struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1281392d55c1SArnab Dutta public:
1282392d55c1SArnab Dutta using OpRewritePattern::OpRewritePattern;
1283392d55c1SArnab Dutta
matchAndRewrite__anonf220e5a50b11::EraseRedundantGpuWaitOpPairs1284392d55c1SArnab Dutta LogicalResult matchAndRewrite(WaitOp op,
1285392d55c1SArnab Dutta PatternRewriter &rewriter) const final {
1286392d55c1SArnab Dutta auto predicate = [](Value value) {
128721b25162SMehdi Amini auto waitOp = value.getDefiningOp<WaitOp>();
128821b25162SMehdi Amini return waitOp && waitOp->getNumOperands() == 0;
1289392d55c1SArnab Dutta };
1290392d55c1SArnab Dutta if (llvm::none_of(op.asyncDependencies(), predicate))
1291392d55c1SArnab Dutta return failure();
1292392d55c1SArnab Dutta SmallVector<Value> validOperands;
1293392d55c1SArnab Dutta for (Value operand : op->getOperands()) {
1294392d55c1SArnab Dutta if (predicate(operand))
1295392d55c1SArnab Dutta continue;
1296392d55c1SArnab Dutta validOperands.push_back(operand);
1297392d55c1SArnab Dutta }
1298392d55c1SArnab Dutta op->setOperands(validOperands);
1299392d55c1SArnab Dutta return success();
1300392d55c1SArnab Dutta }
1301392d55c1SArnab Dutta };
1302392d55c1SArnab Dutta
1303392d55c1SArnab Dutta /// Simplify trivial gpu.wait ops for the following patterns.
1304392d55c1SArnab Dutta /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
1305392d55c1SArnab Dutta /// dependencies).
1306392d55c1SArnab Dutta /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
1307392d55c1SArnab Dutta /// %t0.
1308392d55c1SArnab Dutta /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
1309392d55c1SArnab Dutta /// dependencies nor return any token.
1310392d55c1SArnab Dutta struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
1311392d55c1SArnab Dutta public:
1312392d55c1SArnab Dutta using OpRewritePattern::OpRewritePattern;
1313392d55c1SArnab Dutta
matchAndRewrite__anonf220e5a50b11::SimplifyGpuWaitOp1314392d55c1SArnab Dutta LogicalResult matchAndRewrite(WaitOp op,
1315392d55c1SArnab Dutta PatternRewriter &rewriter) const final {
1316392d55c1SArnab Dutta // Erase gpu.wait ops that neither have any async dependencies nor return
1317392d55c1SArnab Dutta // any async token.
1318392d55c1SArnab Dutta if (op.asyncDependencies().empty() && !op.asyncToken()) {
1319392d55c1SArnab Dutta rewriter.eraseOp(op);
1320392d55c1SArnab Dutta return success();
1321392d55c1SArnab Dutta }
1322392d55c1SArnab Dutta // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
1323392d55c1SArnab Dutta if (llvm::hasSingleElement(op.asyncDependencies()) && op.asyncToken()) {
1324392d55c1SArnab Dutta rewriter.replaceOp(op, op.asyncDependencies());
1325392d55c1SArnab Dutta return success();
1326392d55c1SArnab Dutta }
1327392d55c1SArnab Dutta // Erase %t = gpu.wait async ... ops, where %t has no uses.
1328392d55c1SArnab Dutta if (op.asyncToken() && op.asyncToken().use_empty()) {
1329392d55c1SArnab Dutta rewriter.eraseOp(op);
1330392d55c1SArnab Dutta return success();
1331392d55c1SArnab Dutta }
1332392d55c1SArnab Dutta return failure();
1333392d55c1SArnab Dutta }
1334392d55c1SArnab Dutta };
1335392d55c1SArnab Dutta
1336392d55c1SArnab Dutta } // end anonymous namespace
1337392d55c1SArnab Dutta
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1338392d55c1SArnab Dutta void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
1339392d55c1SArnab Dutta MLIRContext *context) {
1340392d55c1SArnab Dutta results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
1341392d55c1SArnab Dutta }
1342392d55c1SArnab Dutta
1343392d55c1SArnab Dutta //===----------------------------------------------------------------------===//
13440080d2aaSmarina kolpakova a.k.a. geexie // GPU_AllocOp
13450080d2aaSmarina kolpakova a.k.a. geexie //===----------------------------------------------------------------------===//
1346f1efac7fSAkshay Baviskar
verify()1347bbfec2a1SRiver Riddle LogicalResult AllocOp::verify() {
1348bbfec2a1SRiver Riddle auto memRefType = memref().getType().cast<MemRefType>();
1349f1efac7fSAkshay Baviskar
1350bbfec2a1SRiver Riddle if (static_cast<int64_t>(dynamicSizes().size()) !=
1351f1efac7fSAkshay Baviskar memRefType.getNumDynamicDims())
1352bbfec2a1SRiver Riddle return emitOpError("dimension operand count does not equal memref "
1353f1efac7fSAkshay Baviskar "dynamic dimension count");
1354f1efac7fSAkshay Baviskar
1355f1efac7fSAkshay Baviskar unsigned numSymbols = 0;
1356f1efac7fSAkshay Baviskar if (!memRefType.getLayout().isIdentity())
1357f1efac7fSAkshay Baviskar numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
1358bbfec2a1SRiver Riddle if (symbolOperands().size() != numSymbols) {
1359bbfec2a1SRiver Riddle return emitOpError(
1360bbfec2a1SRiver Riddle "symbol operand count does not equal memref symbol count");
1361bbfec2a1SRiver Riddle }
1362f1efac7fSAkshay Baviskar
1363f1efac7fSAkshay Baviskar return success();
1364f1efac7fSAkshay Baviskar }
1365f1efac7fSAkshay Baviskar
13660080d2aaSmarina kolpakova a.k.a. geexie namespace {
13670080d2aaSmarina kolpakova a.k.a. geexie
13680080d2aaSmarina kolpakova a.k.a. geexie /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
13690080d2aaSmarina kolpakova a.k.a. geexie /// `memref::AllocOp`.
13700080d2aaSmarina kolpakova a.k.a. geexie struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
13710080d2aaSmarina kolpakova a.k.a. geexie using OpRewritePattern<memref::DimOp>::OpRewritePattern;
13720080d2aaSmarina kolpakova a.k.a. geexie
matchAndRewrite__anonf220e5a50d11::SimplifyDimOfAllocOp13730080d2aaSmarina kolpakova a.k.a. geexie LogicalResult matchAndRewrite(memref::DimOp dimOp,
13740080d2aaSmarina kolpakova a.k.a. geexie PatternRewriter &rewriter) const override {
1375*136d746eSJacques Pienaar auto index = dimOp.getIndex().getDefiningOp<arith::ConstantIndexOp>();
13760080d2aaSmarina kolpakova a.k.a. geexie if (!index)
13770080d2aaSmarina kolpakova a.k.a. geexie return failure();
13780080d2aaSmarina kolpakova a.k.a. geexie
1379*136d746eSJacques Pienaar auto memrefType = dimOp.getSource().getType().dyn_cast<MemRefType>();
1380a54f4eaeSMogball if (!memrefType || !memrefType.isDynamicDim(index.value()))
13810080d2aaSmarina kolpakova a.k.a. geexie return failure();
13820080d2aaSmarina kolpakova a.k.a. geexie
1383*136d746eSJacques Pienaar auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
13840080d2aaSmarina kolpakova a.k.a. geexie if (!alloc)
13850080d2aaSmarina kolpakova a.k.a. geexie return failure();
13860080d2aaSmarina kolpakova a.k.a. geexie
13870080d2aaSmarina kolpakova a.k.a. geexie Value substituteOp = *(alloc.dynamicSizes().begin() +
1388a54f4eaeSMogball memrefType.getDynamicDimIndex(index.value()));
13890080d2aaSmarina kolpakova a.k.a. geexie rewriter.replaceOp(dimOp, substituteOp);
13900080d2aaSmarina kolpakova a.k.a. geexie return success();
13910080d2aaSmarina kolpakova a.k.a. geexie }
13920080d2aaSmarina kolpakova a.k.a. geexie };
13930080d2aaSmarina kolpakova a.k.a. geexie
1394be0a7e9fSMehdi Amini } // namespace
13950080d2aaSmarina kolpakova a.k.a. geexie
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)13960080d2aaSmarina kolpakova a.k.a. geexie void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
13970080d2aaSmarina kolpakova a.k.a. geexie MLIRContext *context) {
13980080d2aaSmarina kolpakova a.k.a. geexie results.add<SimplifyDimOfAllocOp>(context);
13990080d2aaSmarina kolpakova a.k.a. geexie }
14000080d2aaSmarina kolpakova a.k.a. geexie
1401d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
1402d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
1403473b364aSChristian Sigg
1404aae51255SMogball #define GET_ATTRDEF_CLASSES
1405d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
1406aae51255SMogball
1407ccc767d6SAlex Zinenko #define GET_OP_CLASSES
1408d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
1409