11090a830SDenis Khalikov //===- ConvertGPULaunchFuncToVulkanLaunchFunc.cpp - MLIR conversion pass --===//
21090a830SDenis Khalikov //
31090a830SDenis Khalikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41090a830SDenis Khalikov // See https://llvm.org/LICENSE.txt for license information.
51090a830SDenis Khalikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61090a830SDenis Khalikov //
71090a830SDenis Khalikov //===----------------------------------------------------------------------===//
81090a830SDenis Khalikov //
// This file implements a pass to convert a gpu launch function into a vulkan
// launch function. It creates a SPIR-V binary shader from the
// `spirv::ModuleOp` using the `spirv::serialize` function and attaches the
// binary data and the entry point name as attributes to the vulkan launch call
// op.
131090a830SDenis Khalikov //
141090a830SDenis Khalikov //===----------------------------------------------------------------------===//
151090a830SDenis Khalikov
161834ad4aSRiver Riddle #include "../PassDetail.h"
171090a830SDenis Khalikov #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
19*d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
2001178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
221090a830SDenis Khalikov #include "mlir/IR/Attributes.h"
231090a830SDenis Khalikov #include "mlir/IR/Builders.h"
2465fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
2509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
26ecab6389Sergawy #include "mlir/Target/SPIRV/Serialization.h"
271090a830SDenis Khalikov
281090a830SDenis Khalikov using namespace mlir;
291090a830SDenis Khalikov
/// Name of the attribute on the vulkan launch call that holds the serialized
/// SPIR-V module bytes.
static constexpr char kSPIRVBlobAttrName[] = "spirv_blob";
/// Name of the attribute on the vulkan launch call that holds the kernel
/// (entry point) name.
static constexpr char kSPIRVEntryPointAttrName[] = "spirv_entry_point";
/// Symbol name of the declared vulkan launch function.
static constexpr char kVulkanLaunch[] = "vulkanLaunch";
331090a830SDenis Khalikov
341090a830SDenis Khalikov namespace {
351090a830SDenis Khalikov
36bfb2ce02SDenis Khalikov /// A pass to convert gpu launch op to vulkan launch call op, by creating a
37bfb2ce02SDenis Khalikov /// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
38bfb2ce02SDenis Khalikov /// function and attaching binary data and entry point name as an attributes to
39bfb2ce02SDenis Khalikov /// created vulkan launch call op.
401090a830SDenis Khalikov class ConvertGpuLaunchFuncToVulkanLaunchFunc
411834ad4aSRiver Riddle : public ConvertGpuLaunchFuncToVulkanLaunchFuncBase<
421834ad4aSRiver Riddle ConvertGpuLaunchFuncToVulkanLaunchFunc> {
431090a830SDenis Khalikov public:
44722f909fSRiver Riddle void runOnOperation() override;
451090a830SDenis Khalikov
461090a830SDenis Khalikov private:
471090a830SDenis Khalikov /// Creates a SPIR-V binary shader from the given `module` using
481090a830SDenis Khalikov /// `spirv::serialize` function.
491090a830SDenis Khalikov LogicalResult createBinaryShader(ModuleOp module,
501090a830SDenis Khalikov std::vector<char> &binaryShader);
511090a830SDenis Khalikov
52e5a85126SKazuaki Ishizaki /// Converts the given `launchOp` to vulkan launch call.
531090a830SDenis Khalikov void convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp);
541090a830SDenis Khalikov
551090a830SDenis Khalikov /// Checks where the given type is supported by Vulkan runtime.
isSupportedType(Type type)561090a830SDenis Khalikov bool isSupportedType(Type type) {
571009177dSDenis Khalikov if (auto memRefType = type.dyn_cast_or_null<MemRefType>()) {
581009177dSDenis Khalikov auto elementType = memRefType.getElementType();
598f4ab8c7SDenis Khalikov return memRefType.hasRank() &&
601009177dSDenis Khalikov (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
611009177dSDenis Khalikov (elementType.isIntOrFloat());
621009177dSDenis Khalikov }
631090a830SDenis Khalikov return false;
641090a830SDenis Khalikov }
651090a830SDenis Khalikov
661090a830SDenis Khalikov /// Declares the vulkan launch function. Returns an error if the any type of
671090a830SDenis Khalikov /// operand is unsupported by Vulkan runtime.
681090a830SDenis Khalikov LogicalResult declareVulkanLaunchFunc(Location loc,
691090a830SDenis Khalikov gpu::LaunchFuncOp launchOp);
70a48f0a3cSDenis Khalikov
71a48f0a3cSDenis Khalikov private:
72a48f0a3cSDenis Khalikov /// The number of vulkan launch configuration operands, placed at the leading
73a48f0a3cSDenis Khalikov /// positions of the operand list.
74a48f0a3cSDenis Khalikov static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
751090a830SDenis Khalikov };
761090a830SDenis Khalikov
77be0a7e9fSMehdi Amini } // namespace
781090a830SDenis Khalikov
runOnOperation()79722f909fSRiver Riddle void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
801090a830SDenis Khalikov bool done = false;
81722f909fSRiver Riddle getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
821090a830SDenis Khalikov if (done) {
831090a830SDenis Khalikov op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
841090a830SDenis Khalikov return signalPassFailure();
851090a830SDenis Khalikov }
861090a830SDenis Khalikov done = true;
871090a830SDenis Khalikov convertGpuLaunchFunc(op);
881090a830SDenis Khalikov });
891090a830SDenis Khalikov
901090a830SDenis Khalikov // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
911090a830SDenis Khalikov for (auto gpuModule :
92722f909fSRiver Riddle llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
931090a830SDenis Khalikov gpuModule.erase();
941090a830SDenis Khalikov
951090a830SDenis Khalikov for (auto spirvModule :
96722f909fSRiver Riddle llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
971090a830SDenis Khalikov spirvModule.erase();
981090a830SDenis Khalikov }
991090a830SDenis Khalikov
declareVulkanLaunchFunc(Location loc,gpu::LaunchFuncOp launchOp)1001090a830SDenis Khalikov LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
1011090a830SDenis Khalikov Location loc, gpu::LaunchFuncOp launchOp) {
102973ddb7dSMehdi Amini auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
1031090a830SDenis Khalikov
104a48f0a3cSDenis Khalikov // Workgroup size is written into the kernel. So to properly modelling
105a48f0a3cSDenis Khalikov // vulkan launch, we have to skip local workgroup size configuration here.
106a48f0a3cSDenis Khalikov SmallVector<Type, 8> gpuLaunchTypes(launchOp.getOperandTypes());
107a48f0a3cSDenis Khalikov // The first kVulkanLaunchNumConfigOperands of the gpu.launch_func op are the
108a48f0a3cSDenis Khalikov // same as the config operands for the vulkan launch call op.
109a48f0a3cSDenis Khalikov SmallVector<Type, 8> vulkanLaunchTypes(gpuLaunchTypes.begin(),
110a48f0a3cSDenis Khalikov gpuLaunchTypes.begin() +
111a48f0a3cSDenis Khalikov kVulkanLaunchNumConfigOperands);
112a48f0a3cSDenis Khalikov vulkanLaunchTypes.append(gpuLaunchTypes.begin() +
113a48f0a3cSDenis Khalikov gpu::LaunchOp::kNumConfigOperands,
114a48f0a3cSDenis Khalikov gpuLaunchTypes.end());
115a48f0a3cSDenis Khalikov
116a48f0a3cSDenis Khalikov // Check that all operands have supported types except those for the
117a48f0a3cSDenis Khalikov // launch configuration.
1188f4ab8c7SDenis Khalikov for (auto type :
119a48f0a3cSDenis Khalikov llvm::drop_begin(vulkanLaunchTypes, kVulkanLaunchNumConfigOperands)) {
1201090a830SDenis Khalikov if (!isSupportedType(type))
1211090a830SDenis Khalikov return launchOp.emitError() << type << " is unsupported to run on Vulkan";
1221090a830SDenis Khalikov }
1231090a830SDenis Khalikov
1241090a830SDenis Khalikov // Declare vulkan launch function.
1251b97cdf8SRiver Riddle auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
12658ceae95SRiver Riddle builder.create<func::FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
1271090a830SDenis Khalikov
1281090a830SDenis Khalikov return success();
1291090a830SDenis Khalikov }
1301090a830SDenis Khalikov
createBinaryShader(ModuleOp module,std::vector<char> & binaryShader)1311090a830SDenis Khalikov LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
1321090a830SDenis Khalikov ModuleOp module, std::vector<char> &binaryShader) {
1331090a830SDenis Khalikov bool done = false;
1341090a830SDenis Khalikov SmallVector<uint32_t, 0> binary;
1351090a830SDenis Khalikov for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
1361090a830SDenis Khalikov if (done)
1371090a830SDenis Khalikov return spirvModule.emitError("should only contain one 'spv.module' op");
1381090a830SDenis Khalikov done = true;
1391090a830SDenis Khalikov
1401090a830SDenis Khalikov if (failed(spirv::serialize(spirvModule, binary)))
1411090a830SDenis Khalikov return failure();
1421090a830SDenis Khalikov }
1431090a830SDenis Khalikov binaryShader.resize(binary.size() * sizeof(uint32_t));
1441090a830SDenis Khalikov std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
1451090a830SDenis Khalikov binaryShader.size());
1461090a830SDenis Khalikov return success();
1471090a830SDenis Khalikov }
1481090a830SDenis Khalikov
convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp)1491090a830SDenis Khalikov void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
1501090a830SDenis Khalikov gpu::LaunchFuncOp launchOp) {
151722f909fSRiver Riddle ModuleOp module = getOperation();
1521090a830SDenis Khalikov OpBuilder builder(launchOp);
1531090a830SDenis Khalikov Location loc = launchOp.getLoc();
1541090a830SDenis Khalikov
1551090a830SDenis Khalikov // Serialize `spirv::Module` into binary form.
1561090a830SDenis Khalikov std::vector<char> binary;
1571090a830SDenis Khalikov if (failed(createBinaryShader(module, binary)))
1581090a830SDenis Khalikov return signalPassFailure();
1591090a830SDenis Khalikov
1601090a830SDenis Khalikov // Declare vulkan launch function.
1611090a830SDenis Khalikov if (failed(declareVulkanLaunchFunc(loc, launchOp)))
1621090a830SDenis Khalikov return signalPassFailure();
1631090a830SDenis Khalikov
164a48f0a3cSDenis Khalikov SmallVector<Value, 8> gpuLaunchOperands(launchOp.getOperands());
165a48f0a3cSDenis Khalikov SmallVector<Value, 8> vulkanLaunchOperands(
166a48f0a3cSDenis Khalikov gpuLaunchOperands.begin(),
167a48f0a3cSDenis Khalikov gpuLaunchOperands.begin() + kVulkanLaunchNumConfigOperands);
168a48f0a3cSDenis Khalikov vulkanLaunchOperands.append(gpuLaunchOperands.begin() +
169a48f0a3cSDenis Khalikov gpu::LaunchOp::kNumConfigOperands,
170a48f0a3cSDenis Khalikov gpuLaunchOperands.end());
171a48f0a3cSDenis Khalikov
1721090a830SDenis Khalikov // Create vulkan launch call op.
17323aa5a74SRiver Riddle auto vulkanLaunchCallOp = builder.create<func::CallOp>(
174faf1c224SChris Lattner loc, TypeRange{}, SymbolRefAttr::get(builder.getContext(), kVulkanLaunch),
175a48f0a3cSDenis Khalikov vulkanLaunchOperands);
1761090a830SDenis Khalikov
1771090a830SDenis Khalikov // Set SPIR-V binary shader data as an attribute.
1781ffc1aaaSChristian Sigg vulkanLaunchCallOp->setAttr(
1791090a830SDenis Khalikov kSPIRVBlobAttrName,
180faf1c224SChris Lattner builder.getStringAttr(StringRef(binary.data(), binary.size())));
1811090a830SDenis Khalikov
1821090a830SDenis Khalikov // Set entry point name as an attribute.
18341d4aa7dSChris Lattner vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
18441d4aa7dSChris Lattner launchOp.getKernelName());
1851090a830SDenis Khalikov
1861090a830SDenis Khalikov launchOp.erase();
1871090a830SDenis Khalikov }
1881090a830SDenis Khalikov
18980aca1eaSRiver Riddle std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertGpuLaunchFuncToVulkanLaunchFuncPass()1901090a830SDenis Khalikov mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() {
1911090a830SDenis Khalikov return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>();
1921090a830SDenis Khalikov }
193