1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
16 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
20 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
21 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
22 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 
35 using namespace mlir;
36 
/// Prefix prepended to a GPU kernel module's name to form the symbol name of
/// the SPIR-V module that is looked up for it.
static constexpr const char kSPIRVModule[] = "__spv__";
38 
39 //===----------------------------------------------------------------------===//
40 // Utility functions
41 //===----------------------------------------------------------------------===//
42 
43 /// Returns the string name of the `DescriptorSet` decoration.
descriptorSetName()44 static std::string descriptorSetName() {
45   return llvm::convertToSnakeFromCamelCase(
46       stringifyDecoration(spirv::Decoration::DescriptorSet));
47 }
48 
49 /// Returns the string name of the `Binding` decoration.
bindingName()50 static std::string bindingName() {
51   return llvm::convertToSnakeFromCamelCase(
52       stringifyDecoration(spirv::Decoration::Binding));
53 }
54 
55 /// Calculates the index of the kernel's operand that is represented by the
56 /// given global variable with the `bind` attribute. We assume that the index of
57 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
58 ///   i -> (0, i)
59 /// which is implemented under `LowerABIAttributesPass`.
calculateGlobalIndex(spirv::GlobalVariableOp op)60 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
61   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
62   return binding.getInt();
63 }
64 
65 /// Copies the given number of bytes from src to dst pointers.
copy(Location loc,Value dst,Value src,Value size,OpBuilder & builder)66 static void copy(Location loc, Value dst, Value src, Value size,
67                  OpBuilder &builder) {
68   MLIRContext *context = builder.getContext();
69   auto llvmI1Type = IntegerType::get(context, 1);
70   Value isVolatile = builder.create<LLVM::ConstantOp>(
71       loc, llvmI1Type, builder.getBoolAttr(false));
72   builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
73 }
74 
75 /// Encodes the binding and descriptor set numbers into a new symbolic name.
76 /// The name is specified by
77 ///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
78 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
79 /// binding numbers.
80 static std::string
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,StringRef kernelModuleName)81 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
82                                  StringRef kernelModuleName) {
83   IntegerAttr descriptorSet =
84       op->getAttrOfType<IntegerAttr>(descriptorSetName());
85   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
86   return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
87                        kernelModuleName.str(), op.sym_name().str(),
88                        std::to_string(descriptorSet.getInt()),
89                        std::to_string(binding.getInt()));
90 }
91 
92 /// Returns true if the given global variable has both a descriptor set number
93 /// and a binding number.
hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)94 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
95   IntegerAttr descriptorSet =
96       op->getAttrOfType<IntegerAttr>(descriptorSetName());
97   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
98   return descriptorSet && binding;
99 }
100 
101 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
102 /// arguments from the given SPIR-V module. We assume that the module contains a
103 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind
104 /// attribute are kernel arguments.
getKernelGlobalVariables(spirv::ModuleOp module,DenseMap<uint32_t,spirv::GlobalVariableOp> & globalVariableMap)105 static LogicalResult getKernelGlobalVariables(
106     spirv::ModuleOp module,
107     DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
108   auto entryPoints = module.getOps<spirv::EntryPointOp>();
109   if (!llvm::hasSingleElement(entryPoints)) {
110     return module.emitError(
111         "The module must contain exactly one entry point function");
112   }
113   auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
114   for (auto globalOp : globalVariables) {
115     if (hasDescriptorSetAndBinding(globalOp))
116       globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
117   }
118   return success();
119 }
120 
121 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
122 /// function.
encodeKernelName(spirv::ModuleOp module)123 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
124   StringRef spvModuleName = *module.sym_name();
125   // We already know that the module contains exactly one entry point function
126   // based on `getKernelGlobalVariables()` call. Update this function's name
127   // to:
128   //   {spv_module_name}_{function_name}
129   auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
130   StringRef funcName = entryPoint.fn();
131   auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
132   StringAttr newFuncName =
133       StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
134   if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
135     return failure();
136   SymbolTable::setSymbolName(funcOp, newFuncName);
137   return success();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // Conversion patterns
142 //===----------------------------------------------------------------------===//
143 
144 namespace {
145 
/// Structure to group information about the variables being copied. One entry
/// is recorded per kernel operand so that the data can be copied back once the
/// emulated kernel call has returned.
struct CopyInfo {
  // Address of the LLVM dialect global variable that receives the operand.
  Value dst;
  // Allocated pointer taken from the operand's memref descriptor.
  Value src;
  // Number of bytes to copy, as computed for the operand's memref type.
  Value size;
};
152 
153 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
154 /// copy the data to the global variable (emulating device side), call the
155 /// kernel as a normal void LLVM function, and copy the data back (emulating the
156 /// host side).
157 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
158   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
159 
160   LogicalResult
matchAndRewrite(gpu::LaunchFuncOp launchOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const161   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
162                   ConversionPatternRewriter &rewriter) const override {
163     auto *op = launchOp.getOperation();
164     MLIRContext *context = rewriter.getContext();
165     auto module = launchOp->getParentOfType<ModuleOp>();
166 
167     // Get the SPIR-V module that represents the gpu kernel module. The module
168     // is named:
169     //   __spv__{kernel_module_name}
170     // based on GPU to SPIR-V conversion.
171     StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
172     std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
173     auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
174         StringAttr::get(context, spvModuleName));
175     if (!spvModule) {
176       return launchOp.emitOpError("SPIR-V kernel module '")
177              << spvModuleName << "' is not found";
178     }
179 
180     // Declare kernel function in the main module so that it later can be linked
181     // with its definition from the kernel module. We know that the kernel
182     // function would have no arguments and the data is passed via global
183     // variables. The name of the kernel will be
184     //   {spv_module_name}_{kernel_function_name}
185     // to avoid symbolic name conflicts.
186     StringRef kernelFuncName = launchOp.getKernelName().getValue();
187     std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
188     auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
189         StringAttr::get(context, newKernelFuncName));
190     if (!kernelFunc) {
191       OpBuilder::InsertionGuard guard(rewriter);
192       rewriter.setInsertionPointToStart(module.getBody());
193       kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
194           rewriter.getUnknownLoc(), newKernelFuncName,
195           LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
196                                       ArrayRef<Type>()));
197       rewriter.setInsertionPoint(launchOp);
198     }
199 
200     // Get all global variables associated with the kernel operands.
201     DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
202     if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
203       return failure();
204 
205     // Traverse kernel operands that were converted to MemRefDescriptors. For
206     // each operand, create a global variable and copy data from operand to it.
207     Location loc = launchOp.getLoc();
208     SmallVector<CopyInfo, 4> copyInfo;
209     auto numKernelOperands = launchOp.getNumKernelOperands();
210     auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
211     for (const auto &operand : llvm::enumerate(kernelOperands)) {
212       // Check if the kernel's operand is a ranked memref.
213       auto memRefType = launchOp.getKernelOperand(operand.index())
214                             .getType()
215                             .dyn_cast<MemRefType>();
216       if (!memRefType)
217         return failure();
218 
219       // Calculate the size of the memref and get the pointer to the allocated
220       // buffer.
221       SmallVector<Value, 4> sizes;
222       SmallVector<Value, 4> strides;
223       Value sizeBytes;
224       getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
225                                sizeBytes);
226       MemRefDescriptor descriptor(operand.value());
227       Value src = descriptor.allocatedPtr(rewriter, loc);
228 
229       // Get the global variable in the SPIR-V module that is associated with
230       // the kernel operand. Construct its new name and create a corresponding
231       // LLVM dialect global variable.
232       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
233       auto pointeeType =
234           spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
235       auto dstGlobalType = typeConverter->convertType(pointeeType);
236       if (!dstGlobalType)
237         return failure();
238       std::string name =
239           createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
240       // Check if this variable has already been created.
241       auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
242       if (!dstGlobal) {
243         OpBuilder::InsertionGuard guard(rewriter);
244         rewriter.setInsertionPointToStart(module.getBody());
245         dstGlobal = rewriter.create<LLVM::GlobalOp>(
246             loc, dstGlobalType,
247             /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
248             /*alignment=*/0);
249         rewriter.setInsertionPoint(launchOp);
250       }
251 
252       // Copy the data from src operand pointer to dst global variable. Save
253       // src, dst and size so that we can copy data back after emulating the
254       // kernel call.
255       Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
256       copy(loc, dst, src, sizeBytes, rewriter);
257 
258       CopyInfo info;
259       info.dst = dst;
260       info.src = src;
261       info.size = sizeBytes;
262       copyInfo.push_back(info);
263     }
264     // Create a call to the kernel and copy the data back.
265     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
266                                               ArrayRef<Value>());
267     for (CopyInfo info : copyInfo)
268       copy(loc, info.src, info.dst, info.size, rewriter);
269     return success();
270   }
271 };
272 
273 class LowerHostCodeToLLVM
274     : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
275 public:
runOnOperation()276   void runOnOperation() override {
277     ModuleOp module = getOperation();
278 
279     // Erase the GPU module.
280     for (auto gpuModule :
281          llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
282       gpuModule.erase();
283 
284     // Request C wrapper emission.
285     for (auto func : module.getOps<func::FuncOp>()) {
286       func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
287                     UnitAttr::get(&getContext()));
288     }
289 
290     // Specify options to lower to LLVM and pull in the conversion patterns.
291     LowerToLLVMOptions options(module.getContext());
292     auto *context = module.getContext();
293     RewritePatternSet patterns(context);
294     LLVMTypeConverter typeConverter(context, options);
295     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
296                                                             patterns);
297     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
298     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
299     patterns.add<GPULaunchLowering>(typeConverter);
300 
301     // Pull in SPIR-V type conversion patterns to convert SPIR-V global
302     // variable's type to LLVM dialect type.
303     populateSPIRVToLLVMTypeConversion(typeConverter);
304 
305     ConversionTarget target(*context);
306     target.addLegalDialect<LLVM::LLVMDialect>();
307     if (failed(applyPartialConversion(module, target, std::move(patterns))))
308       signalPassFailure();
309 
310     // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
311     // conflicts.
312     for (auto spvModule : module.getOps<spirv::ModuleOp>())
313       (void)encodeKernelName(spvModule);
314   }
315 };
316 } // namespace
317 
318 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLowerHostCodeToLLVMPass()319 mlir::createLowerHostCodeToLLVMPass() {
320   return std::make_unique<LowerHostCodeToLLVM>();
321 }
322