1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
16 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
21 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
22 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
23 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
24 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
25 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
26 #include "mlir/IR/BuiltinOps.h"
27 #include "mlir/IR/SymbolTable.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/Support/FormatVariadic.h"
33 
34 using namespace mlir;
35 
/// Prefix that, prepended to a GPU kernel module's name, yields the symbol
/// name of the SPIR-V module holding that kernel (`__spv__{module_name}`).
static constexpr char kSPIRVModule[] = "__spv__";
37 
38 //===----------------------------------------------------------------------===//
39 // Utility functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns the string name of the `DescriptorSet` decoration.
43 static std::string descriptorSetName() {
44   return llvm::convertToSnakeFromCamelCase(
45       stringifyDecoration(spirv::Decoration::DescriptorSet));
46 }
47 
48 /// Returns the string name of the `Binding` decoration.
49 static std::string bindingName() {
50   return llvm::convertToSnakeFromCamelCase(
51       stringifyDecoration(spirv::Decoration::Binding));
52 }
53 
54 /// Calculates the index of the kernel's operand that is represented by the
55 /// given global variable with the `bind` attribute. We assume that the index of
56 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
57 ///   i -> (0, i)
58 /// which is implemented under `LowerABIAttributesPass`.
59 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
60   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
61   return binding.getInt();
62 }
63 
64 /// Copies the given number of bytes from src to dst pointers.
65 static void copy(Location loc, Value dst, Value src, Value size,
66                  OpBuilder &builder) {
67   MLIRContext *context = builder.getContext();
68   auto llvmI1Type = IntegerType::get(context, 1);
69   Value isVolatile = builder.create<LLVM::ConstantOp>(
70       loc, llvmI1Type, builder.getBoolAttr(false));
71   builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
72 }
73 
74 /// Encodes the binding and descriptor set numbers into a new symbolic name.
75 /// The name is specified by
76 ///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
77 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
78 /// binding numbers.
79 static std::string
80 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
81                                  StringRef kernelModuleName) {
82   IntegerAttr descriptorSet =
83       op->getAttrOfType<IntegerAttr>(descriptorSetName());
84   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
85   return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
86                        kernelModuleName.str(), op.sym_name().str(),
87                        std::to_string(descriptorSet.getInt()),
88                        std::to_string(binding.getInt()));
89 }
90 
91 /// Returns true if the given global variable has both a descriptor set number
92 /// and a binding number.
93 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
94   IntegerAttr descriptorSet =
95       op->getAttrOfType<IntegerAttr>(descriptorSetName());
96   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
97   return descriptorSet && binding;
98 }
99 
100 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
101 /// arguments from the given SPIR-V module. We assume that the module contains a
102 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind
103 /// attribute are kernel arguments.
104 static LogicalResult getKernelGlobalVariables(
105     spirv::ModuleOp module,
106     DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
107   auto entryPoints = module.getOps<spirv::EntryPointOp>();
108   if (!llvm::hasSingleElement(entryPoints)) {
109     return module.emitError(
110         "The module must contain exactly one entry point function");
111   }
112   auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
113   for (auto globalOp : globalVariables) {
114     if (hasDescriptorSetAndBinding(globalOp))
115       globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
116   }
117   return success();
118 }
119 
120 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
121 /// function.
122 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
123   StringRef spvModuleName = module.sym_name().getValue();
124   // We already know that the module contains exactly one entry point function
125   // based on `getKernelGlobalVariables()` call. Update this function's name
126   // to:
127   //   {spv_module_name}_{function_name}
128   auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
129   StringRef funcName = entryPoint.fn();
130   auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
131   StringAttr newFuncName =
132       StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
133   if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
134     return failure();
135   SymbolTable::setSymbolName(funcOp, newFuncName);
136   return success();
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Conversion patterns
141 //===----------------------------------------------------------------------===//
142 
143 namespace {
144 
145 /// Structure to group information about the variables being copied.
146 struct CopyInfo {
147   Value dst;
148   Value src;
149   Value size;
150 };
151 
152 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
153 /// copy the data to the global variable (emulating device side), call the
154 /// kernel as a normal void LLVM function, and copy the data back (emulating the
155 /// host side).
156 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
157   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
158 
159   LogicalResult
160   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
161                   ConversionPatternRewriter &rewriter) const override {
162     auto *op = launchOp.getOperation();
163     MLIRContext *context = rewriter.getContext();
164     auto module = launchOp->getParentOfType<ModuleOp>();
165 
166     // Get the SPIR-V module that represents the gpu kernel module. The module
167     // is named:
168     //   __spv__{kernel_module_name}
169     // based on GPU to SPIR-V conversion.
170     StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
171     std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
172     auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
173         StringAttr::get(context, spvModuleName));
174     if (!spvModule) {
175       return launchOp.emitOpError("SPIR-V kernel module '")
176              << spvModuleName << "' is not found";
177     }
178 
179     // Declare kernel function in the main module so that it later can be linked
180     // with its definition from the kernel module. We know that the kernel
181     // function would have no arguments and the data is passed via global
182     // variables. The name of the kernel will be
183     //   {spv_module_name}_{kernel_function_name}
184     // to avoid symbolic name conflicts.
185     StringRef kernelFuncName = launchOp.getKernelName().getValue();
186     std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
187     auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
188         StringAttr::get(context, newKernelFuncName));
189     if (!kernelFunc) {
190       OpBuilder::InsertionGuard guard(rewriter);
191       rewriter.setInsertionPointToStart(module.getBody());
192       kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
193           rewriter.getUnknownLoc(), newKernelFuncName,
194           LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
195                                       ArrayRef<Type>()));
196       rewriter.setInsertionPoint(launchOp);
197     }
198 
199     // Get all global variables associated with the kernel operands.
200     DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
201     if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
202       return failure();
203 
204     // Traverse kernel operands that were converted to MemRefDescriptors. For
205     // each operand, create a global variable and copy data from operand to it.
206     Location loc = launchOp.getLoc();
207     SmallVector<CopyInfo, 4> copyInfo;
208     auto numKernelOperands = launchOp.getNumKernelOperands();
209     auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
210     for (const auto &operand : llvm::enumerate(kernelOperands)) {
211       // Check if the kernel's operand is a ranked memref.
212       auto memRefType = launchOp.getKernelOperand(operand.index())
213                             .getType()
214                             .dyn_cast<MemRefType>();
215       if (!memRefType)
216         return failure();
217 
218       // Calculate the size of the memref and get the pointer to the allocated
219       // buffer.
220       SmallVector<Value, 4> sizes;
221       SmallVector<Value, 4> strides;
222       Value sizeBytes;
223       getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
224                                sizeBytes);
225       MemRefDescriptor descriptor(operand.value());
226       Value src = descriptor.allocatedPtr(rewriter, loc);
227 
228       // Get the global variable in the SPIR-V module that is associated with
229       // the kernel operand. Construct its new name and create a corresponding
230       // LLVM dialect global variable.
231       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
232       auto pointeeType =
233           spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
234       auto dstGlobalType = typeConverter->convertType(pointeeType);
235       if (!dstGlobalType)
236         return failure();
237       std::string name =
238           createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
239       // Check if this variable has already been created.
240       auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
241       if (!dstGlobal) {
242         OpBuilder::InsertionGuard guard(rewriter);
243         rewriter.setInsertionPointToStart(module.getBody());
244         dstGlobal = rewriter.create<LLVM::GlobalOp>(
245             loc, dstGlobalType,
246             /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
247             /*alignment=*/0);
248         rewriter.setInsertionPoint(launchOp);
249       }
250 
251       // Copy the data from src operand pointer to dst global variable. Save
252       // src, dst and size so that we can copy data back after emulating the
253       // kernel call.
254       Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
255       copy(loc, dst, src, sizeBytes, rewriter);
256 
257       CopyInfo info;
258       info.dst = dst;
259       info.src = src;
260       info.size = sizeBytes;
261       copyInfo.push_back(info);
262     }
263     // Create a call to the kernel and copy the data back.
264     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
265                                               ArrayRef<Value>());
266     for (CopyInfo info : copyInfo)
267       copy(loc, info.src, info.dst, info.size, rewriter);
268     return success();
269   }
270 };
271 
272 class LowerHostCodeToLLVM
273     : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
274 public:
275   void runOnOperation() override {
276     ModuleOp module = getOperation();
277 
278     // Erase the GPU module.
279     for (auto gpuModule :
280          llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
281       gpuModule.erase();
282 
283     // Specify options to lower to LLVM and pull in the conversion patterns.
284     LowerToLLVMOptions options(module.getContext());
285     options.emitCWrappers = true;
286     auto *context = module.getContext();
287     RewritePatternSet patterns(context);
288     LLVMTypeConverter typeConverter(context, options);
289     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
290                                                             patterns);
291     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
292     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
293     patterns.add<GPULaunchLowering>(typeConverter);
294 
295     // Pull in SPIR-V type conversion patterns to convert SPIR-V global
296     // variable's type to LLVM dialect type.
297     populateSPIRVToLLVMTypeConversion(typeConverter);
298 
299     ConversionTarget target(*context);
300     target.addLegalDialect<LLVM::LLVMDialect>();
301     if (failed(applyPartialConversion(module, target, std::move(patterns))))
302       signalPassFailure();
303 
304     // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
305     // conflicts.
306     for (auto spvModule : module.getOps<spirv::ModuleOp>())
307       (void)encodeKernelName(spvModule);
308   }
309 };
310 } // namespace
311 
312 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
313 mlir::createLowerHostCodeToLLVMPass() {
314   return std::make_unique<LowerHostCodeToLLVM>();
315 }
316