11090a830SDenis Khalikov //===- ConvertGPULaunchFuncToVulkanLaunchFunc.cpp - MLIR conversion pass --===//
21090a830SDenis Khalikov //
31090a830SDenis Khalikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41090a830SDenis Khalikov // See https://llvm.org/LICENSE.txt for license information.
51090a830SDenis Khalikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61090a830SDenis Khalikov //
71090a830SDenis Khalikov //===----------------------------------------------------------------------===//
81090a830SDenis Khalikov //
// This file implements a pass to convert a gpu launch function into a vulkan
// launch function. It creates a SPIR-V binary shader from the
// `spirv::ModuleOp` using the `spirv::serialize` function and attaches the
// binary data and entry point name as attributes to the vulkan launch call op.
131090a830SDenis Khalikov //
141090a830SDenis Khalikov //===----------------------------------------------------------------------===//
151090a830SDenis Khalikov 
161090a830SDenis Khalikov #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
171090a830SDenis Khalikov #include "mlir/Dialect/GPU/GPUDialect.h"
181090a830SDenis Khalikov #include "mlir/Dialect/SPIRV/SPIRVOps.h"
191090a830SDenis Khalikov #include "mlir/Dialect/SPIRV/Serialization.h"
201090a830SDenis Khalikov #include "mlir/Dialect/StandardOps/IR/Ops.h"
211090a830SDenis Khalikov #include "mlir/IR/Attributes.h"
221090a830SDenis Khalikov #include "mlir/IR/Builders.h"
231090a830SDenis Khalikov #include "mlir/IR/Function.h"
241090a830SDenis Khalikov #include "mlir/IR/Module.h"
251090a830SDenis Khalikov #include "mlir/IR/StandardTypes.h"
261090a830SDenis Khalikov #include "mlir/Pass/Pass.h"
271090a830SDenis Khalikov 
281090a830SDenis Khalikov using namespace mlir;
291090a830SDenis Khalikov 
// Attribute names attached to the generated `vulkanLaunch` call, and the
// symbol name of the external launch function itself.
static constexpr const char kSPIRVBlobAttrName[] = "spirv_blob";
static constexpr const char kSPIRVEntryPointAttrName[] = "spirv_entry_point";
static constexpr const char kVulkanLaunch[] = "vulkanLaunch";
331090a830SDenis Khalikov 
341090a830SDenis Khalikov namespace {
351090a830SDenis Khalikov 
36*bfb2ce02SDenis Khalikov /// A pass to convert gpu launch op to vulkan launch call op, by creating a
37*bfb2ce02SDenis Khalikov /// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
38*bfb2ce02SDenis Khalikov /// function and attaching binary data and entry point name as an attributes to
39*bfb2ce02SDenis Khalikov /// created vulkan launch call op.
401090a830SDenis Khalikov class ConvertGpuLaunchFuncToVulkanLaunchFunc
411090a830SDenis Khalikov     : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
421090a830SDenis Khalikov public:
431090a830SDenis Khalikov   void runOnModule() override;
441090a830SDenis Khalikov 
451090a830SDenis Khalikov private:
461090a830SDenis Khalikov   /// Creates a SPIR-V binary shader from the given `module` using
471090a830SDenis Khalikov   /// `spirv::serialize` function.
481090a830SDenis Khalikov   LogicalResult createBinaryShader(ModuleOp module,
491090a830SDenis Khalikov                                    std::vector<char> &binaryShader);
501090a830SDenis Khalikov 
511090a830SDenis Khalikov   /// Converts the given `luanchOp` to vulkan launch call.
521090a830SDenis Khalikov   void convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp);
531090a830SDenis Khalikov 
541090a830SDenis Khalikov   /// Checks where the given type is supported by Vulkan runtime.
551090a830SDenis Khalikov   bool isSupportedType(Type type) {
561090a830SDenis Khalikov     // TODO(denis0x0D): Handle other types.
571090a830SDenis Khalikov     if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
581090a830SDenis Khalikov       return memRefType.hasRank() && memRefType.getRank() == 1;
591090a830SDenis Khalikov     return false;
601090a830SDenis Khalikov   }
611090a830SDenis Khalikov 
621090a830SDenis Khalikov   /// Declares the vulkan launch function. Returns an error if the any type of
631090a830SDenis Khalikov   /// operand is unsupported by Vulkan runtime.
641090a830SDenis Khalikov   LogicalResult declareVulkanLaunchFunc(Location loc,
651090a830SDenis Khalikov                                         gpu::LaunchFuncOp launchOp);
661090a830SDenis Khalikov 
671090a830SDenis Khalikov };
681090a830SDenis Khalikov 
691090a830SDenis Khalikov } // anonymous namespace
701090a830SDenis Khalikov 
711090a830SDenis Khalikov void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
721090a830SDenis Khalikov   bool done = false;
731090a830SDenis Khalikov   getModule().walk([this, &done](gpu::LaunchFuncOp op) {
741090a830SDenis Khalikov     if (done) {
751090a830SDenis Khalikov       op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
761090a830SDenis Khalikov       return signalPassFailure();
771090a830SDenis Khalikov     }
781090a830SDenis Khalikov     done = true;
791090a830SDenis Khalikov     convertGpuLaunchFunc(op);
801090a830SDenis Khalikov   });
811090a830SDenis Khalikov 
821090a830SDenis Khalikov   // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
831090a830SDenis Khalikov   for (auto gpuModule :
841090a830SDenis Khalikov        llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
851090a830SDenis Khalikov     gpuModule.erase();
861090a830SDenis Khalikov 
871090a830SDenis Khalikov   for (auto spirvModule :
881090a830SDenis Khalikov        llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
891090a830SDenis Khalikov     spirvModule.erase();
901090a830SDenis Khalikov }
911090a830SDenis Khalikov 
921090a830SDenis Khalikov LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
931090a830SDenis Khalikov     Location loc, gpu::LaunchFuncOp launchOp) {
941090a830SDenis Khalikov   OpBuilder builder(getModule().getBody()->getTerminator());
951090a830SDenis Khalikov   // TODO: Workgroup size is written into the kernel. So to properly modelling
961090a830SDenis Khalikov   // vulkan launch, we cannot have the local workgroup size configuration here.
971090a830SDenis Khalikov   SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
981090a830SDenis Khalikov 
991090a830SDenis Khalikov   // Check that all operands have supported types except those for the launch
1001090a830SDenis Khalikov   // configuration.
1011090a830SDenis Khalikov   for (auto type : llvm::drop_begin(vulkanLaunchTypes, 6)) {
1021090a830SDenis Khalikov     if (!isSupportedType(type))
1031090a830SDenis Khalikov       return launchOp.emitError() << type << " is unsupported to run on Vulkan";
1041090a830SDenis Khalikov   }
1051090a830SDenis Khalikov 
1061090a830SDenis Khalikov   // Declare vulkan launch function.
1071090a830SDenis Khalikov   builder.create<FuncOp>(
1081090a830SDenis Khalikov       loc, kVulkanLaunch,
1091090a830SDenis Khalikov       FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{}, loc->getContext()),
1101090a830SDenis Khalikov       ArrayRef<NamedAttribute>{});
1111090a830SDenis Khalikov 
1121090a830SDenis Khalikov   return success();
1131090a830SDenis Khalikov }
1141090a830SDenis Khalikov 
1151090a830SDenis Khalikov LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
1161090a830SDenis Khalikov     ModuleOp module, std::vector<char> &binaryShader) {
1171090a830SDenis Khalikov   bool done = false;
1181090a830SDenis Khalikov   SmallVector<uint32_t, 0> binary;
1191090a830SDenis Khalikov   for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
1201090a830SDenis Khalikov     if (done)
1211090a830SDenis Khalikov       return spirvModule.emitError("should only contain one 'spv.module' op");
1221090a830SDenis Khalikov     done = true;
1231090a830SDenis Khalikov 
1241090a830SDenis Khalikov     if (failed(spirv::serialize(spirvModule, binary)))
1251090a830SDenis Khalikov       return failure();
1261090a830SDenis Khalikov   }
1271090a830SDenis Khalikov   binaryShader.resize(binary.size() * sizeof(uint32_t));
1281090a830SDenis Khalikov   std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
1291090a830SDenis Khalikov               binaryShader.size());
1301090a830SDenis Khalikov   return success();
1311090a830SDenis Khalikov }
1321090a830SDenis Khalikov 
1331090a830SDenis Khalikov void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
1341090a830SDenis Khalikov     gpu::LaunchFuncOp launchOp) {
1351090a830SDenis Khalikov   ModuleOp module = getModule();
1361090a830SDenis Khalikov   OpBuilder builder(launchOp);
1371090a830SDenis Khalikov   Location loc = launchOp.getLoc();
1381090a830SDenis Khalikov 
1391090a830SDenis Khalikov   // Serialize `spirv::Module` into binary form.
1401090a830SDenis Khalikov   std::vector<char> binary;
1411090a830SDenis Khalikov   if (failed(createBinaryShader(module, binary)))
1421090a830SDenis Khalikov     return signalPassFailure();
1431090a830SDenis Khalikov 
1441090a830SDenis Khalikov   // Declare vulkan launch function.
1451090a830SDenis Khalikov   if (failed(declareVulkanLaunchFunc(loc, launchOp)))
1461090a830SDenis Khalikov     return signalPassFailure();
1471090a830SDenis Khalikov 
1481090a830SDenis Khalikov   // Create vulkan launch call op.
1491090a830SDenis Khalikov   auto vulkanLaunchCallOp = builder.create<CallOp>(
1501090a830SDenis Khalikov       loc, ArrayRef<Type>{}, builder.getSymbolRefAttr(kVulkanLaunch),
1511090a830SDenis Khalikov       launchOp.getOperands());
1521090a830SDenis Khalikov 
1531090a830SDenis Khalikov   // Set SPIR-V binary shader data as an attribute.
1541090a830SDenis Khalikov   vulkanLaunchCallOp.setAttr(
1551090a830SDenis Khalikov       kSPIRVBlobAttrName,
1561090a830SDenis Khalikov       StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
1571090a830SDenis Khalikov 
1581090a830SDenis Khalikov   // Set entry point name as an attribute.
1591090a830SDenis Khalikov   vulkanLaunchCallOp.setAttr(
1601090a830SDenis Khalikov       kSPIRVEntryPointAttrName,
1611090a830SDenis Khalikov       StringAttr::get(launchOp.kernel(), loc->getContext()));
1621090a830SDenis Khalikov 
1631090a830SDenis Khalikov   launchOp.erase();
1641090a830SDenis Khalikov }
1651090a830SDenis Khalikov 
1661090a830SDenis Khalikov std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
1671090a830SDenis Khalikov mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() {
1681090a830SDenis Khalikov   return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>();
1691090a830SDenis Khalikov }
1701090a830SDenis Khalikov 
1711090a830SDenis Khalikov static PassRegistration<ConvertGpuLaunchFuncToVulkanLaunchFunc>
1721090a830SDenis Khalikov     pass("convert-gpu-launch-to-vulkan-launch",
1731090a830SDenis Khalikov          "Convert gpu.launch_func to vulkanLaunch external call");
174