1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
16 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
18 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
20 #include "mlir/Dialect/GPU/GPUDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/Support/FormatVariadic.h"
31 
32 using namespace mlir;
33 
/// Prefix that, prepended to a GPU kernel module's name, gives the name of the
/// SPIR-V module holding that kernel.
static constexpr char kSPIRVModule[] = "__spv__";
35 
36 //===----------------------------------------------------------------------===//
37 // Utility functions
38 //===----------------------------------------------------------------------===//
39 
40 /// Returns the string name of the `DescriptorSet` decoration.
41 static std::string descriptorSetName() {
42   return llvm::convertToSnakeFromCamelCase(
43       stringifyDecoration(spirv::Decoration::DescriptorSet));
44 }
45 
46 /// Returns the string name of the `Binding` decoration.
47 static std::string bindingName() {
48   return llvm::convertToSnakeFromCamelCase(
49       stringifyDecoration(spirv::Decoration::Binding));
50 }
51 
52 /// Calculates the index of the kernel's operand that is represented by the
53 /// given global variable with the `bind` attribute. We assume that the index of
54 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
55 ///   i -> (0, i)
56 /// which is implemented under `LowerABIAttributesPass`.
57 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
58   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
59   return binding.getInt();
60 }
61 
62 /// Copies the given number of bytes from src to dst pointers.
63 static void copy(Location loc, Value dst, Value src, Value size,
64                  OpBuilder &builder) {
65   MLIRContext *context = builder.getContext();
66   auto llvmI1Type = IntegerType::get(context, 1);
67   Value isVolatile = builder.create<LLVM::ConstantOp>(
68       loc, llvmI1Type, builder.getBoolAttr(false));
69   builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
70 }
71 
72 /// Encodes the binding and descriptor set numbers into a new symbolic name.
73 /// The name is specified by
74 ///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
75 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
76 /// binding numbers.
77 static std::string
78 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
79                                  StringRef kernelModuleName) {
80   IntegerAttr descriptorSet =
81       op->getAttrOfType<IntegerAttr>(descriptorSetName());
82   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
83   return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
84                        kernelModuleName.str(), op.sym_name().str(),
85                        std::to_string(descriptorSet.getInt()),
86                        std::to_string(binding.getInt()));
87 }
88 
89 /// Returns true if the given global variable has both a descriptor set number
90 /// and a binding number.
91 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
92   IntegerAttr descriptorSet =
93       op->getAttrOfType<IntegerAttr>(descriptorSetName());
94   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
95   return descriptorSet && binding;
96 }
97 
98 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
99 /// arguments from the given SPIR-V module. We assume that the module contains a
100 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind
101 /// attribute are kernel arguments.
102 static LogicalResult getKernelGlobalVariables(
103     spirv::ModuleOp module,
104     DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
105   auto entryPoints = module.getOps<spirv::EntryPointOp>();
106   if (!llvm::hasSingleElement(entryPoints)) {
107     return module.emitError(
108         "The module must contain exactly one entry point function");
109   }
110   auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
111   for (auto globalOp : globalVariables) {
112     if (hasDescriptorSetAndBinding(globalOp))
113       globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
114   }
115   return success();
116 }
117 
118 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
119 /// function.
120 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
121   StringRef spvModuleName = module.sym_name().getValue();
122   // We already know that the module contains exactly one entry point function
123   // based on `getKernelGlobalVariables()` call. Update this function's name
124   // to:
125   //   {spv_module_name}_{function_name}
126   auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
127   StringRef funcName = entryPoint.fn();
128   auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
129   std::string newFuncName = spvModuleName.str() + "_" + funcName.str();
130   if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
131     return failure();
132   SymbolTable::setSymbolName(funcOp, newFuncName);
133   return success();
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Conversion patterns
138 //===----------------------------------------------------------------------===//
139 
140 namespace {
141 
/// Structure to group information about the variables being copied. One entry
/// records a single copy into a global variable so that the reverse copy can be
/// emitted after the kernel call.
struct CopyInfo {
  /// Address of the LLVM dialect global variable the data is copied into.
  Value dst;
  /// Allocated pointer of the memref kernel operand the data is copied from.
  Value src;
  /// Size of the copied data in bytes.
  Value size;
};
148 
149 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
150 /// copy the data to the global variable (emulating device side), call the
151 /// kernel as a normal void LLVM function, and copy the data back (emulating the
152 /// host side).
153 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
154   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
155 
156   LogicalResult
157   matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
158                   ConversionPatternRewriter &rewriter) const override {
159     auto *op = launchOp.getOperation();
160     MLIRContext *context = rewriter.getContext();
161     auto module = launchOp->getParentOfType<ModuleOp>();
162 
163     // Get the SPIR-V module that represents the gpu kernel module. The module
164     // is named:
165     //   __spv__{kernel_module_name}
166     // based on GPU to SPIR-V conversion.
167     StringRef kernelModuleName = launchOp.getKernelModuleName();
168     std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
169     auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
170     if (!spvModule) {
171       return launchOp.emitOpError("SPIR-V kernel module '")
172              << spvModuleName << "' is not found";
173     }
174 
175     // Declare kernel function in the main module so that it later can be linked
176     // with its definition from the kernel module. We know that the kernel
177     // function would have no arguments and the data is passed via global
178     // variables. The name of the kernel will be
179     //   {spv_module_name}_{kernel_function_name}
180     // to avoid symbolic name conflicts.
181     StringRef kernelFuncName = launchOp.getKernelName();
182     std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
183     auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
184     if (!kernelFunc) {
185       OpBuilder::InsertionGuard guard(rewriter);
186       rewriter.setInsertionPointToStart(module.getBody());
187       kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
188           rewriter.getUnknownLoc(), newKernelFuncName,
189           LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
190                                       ArrayRef<Type>()));
191       rewriter.setInsertionPoint(launchOp);
192     }
193 
194     // Get all global variables associated with the kernel operands.
195     DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
196     if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
197       return failure();
198 
199     // Traverse kernel operands that were converted to MemRefDescriptors. For
200     // each operand, create a global variable and copy data from operand to it.
201     Location loc = launchOp.getLoc();
202     SmallVector<CopyInfo, 4> copyInfo;
203     auto numKernelOperands = launchOp.getNumKernelOperands();
204     auto kernelOperands = operands.take_back(numKernelOperands);
205     for (auto operand : llvm::enumerate(kernelOperands)) {
206       // Check if the kernel's operand is a ranked memref.
207       auto memRefType = launchOp.getKernelOperand(operand.index())
208                             .getType()
209                             .dyn_cast<MemRefType>();
210       if (!memRefType)
211         return failure();
212 
213       // Calculate the size of the memref and get the pointer to the allocated
214       // buffer.
215       SmallVector<Value, 4> sizes;
216       SmallVector<Value, 4> strides;
217       Value sizeBytes;
218       getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
219                                sizeBytes);
220       MemRefDescriptor descriptor(operand.value());
221       Value src = descriptor.allocatedPtr(rewriter, loc);
222 
223       // Get the global variable in the SPIR-V module that is associated with
224       // the kernel operand. Construct its new name and create a corresponding
225       // LLVM dialect global variable.
226       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
227       auto pointeeType =
228           spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
229       auto dstGlobalType = typeConverter->convertType(pointeeType);
230       if (!dstGlobalType)
231         return failure();
232       std::string name =
233           createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
234       // Check if this variable has already been created.
235       auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
236       if (!dstGlobal) {
237         OpBuilder::InsertionGuard guard(rewriter);
238         rewriter.setInsertionPointToStart(module.getBody());
239         dstGlobal = rewriter.create<LLVM::GlobalOp>(
240             loc, dstGlobalType,
241             /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
242             /*alignment=*/0);
243         rewriter.setInsertionPoint(launchOp);
244       }
245 
246       // Copy the data from src operand pointer to dst global variable. Save
247       // src, dst and size so that we can copy data back after emulating the
248       // kernel call.
249       Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
250       copy(loc, dst, src, sizeBytes, rewriter);
251 
252       CopyInfo info;
253       info.dst = dst;
254       info.src = src;
255       info.size = sizeBytes;
256       copyInfo.push_back(info);
257     }
258     // Create a call to the kernel and copy the data back.
259     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
260                                               ArrayRef<Value>());
261     for (CopyInfo info : copyInfo)
262       copy(loc, info.src, info.dst, info.size, rewriter);
263     return success();
264   }
265 };
266 
267 class LowerHostCodeToLLVM
268     : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
269 public:
270   void runOnOperation() override {
271     ModuleOp module = getOperation();
272 
273     // Erase the GPU module.
274     for (auto gpuModule :
275          llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
276       gpuModule.erase();
277 
278     // Specify options to lower Standard to LLVM and pull in the conversion
279     // patterns.
280     LowerToLLVMOptions options(module.getContext());
281     options.emitCWrappers = true;
282     auto *context = module.getContext();
283     RewritePatternSet patterns(context);
284     LLVMTypeConverter typeConverter(context, options);
285     populateStdToLLVMConversionPatterns(typeConverter, patterns);
286     patterns.add<GPULaunchLowering>(typeConverter);
287 
288     // Pull in SPIR-V type conversion patterns to convert SPIR-V global
289     // variable's type to LLVM dialect type.
290     populateSPIRVToLLVMTypeConversion(typeConverter);
291 
292     ConversionTarget target(*context);
293     target.addLegalDialect<LLVM::LLVMDialect>();
294     if (failed(applyPartialConversion(module, target, std::move(patterns))))
295       signalPassFailure();
296 
297     // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
298     // conflicts.
299     for (auto spvModule : module.getOps<spirv::ModuleOp>())
300       (void)encodeKernelName(spvModule);
301   }
302 };
303 } // namespace
304 
305 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
306 mlir::createLowerHostCodeToLLVMPass() {
307   return std::make_unique<LowerHostCodeToLLVM>();
308 }
309