1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements passes to convert `gpu.launch_func` op into a sequence 10 // of LLVM calls that emulate the host and device sides. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 16 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 17 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 18 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 20 #include "mlir/Dialect/GPU/GPUDialect.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 23 #include "mlir/Dialect/StandardOps/IR/Ops.h" 24 #include "mlir/IR/BuiltinOps.h" 25 #include "mlir/IR/SymbolTable.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 28 #include "llvm/ADT/DenseMap.h" 29 #include "llvm/ADT/StringExtras.h" 30 #include "llvm/Support/FormatVariadic.h" 31 32 using namespace mlir; 33 34 static constexpr const char kSPIRVModule[] = "__spv__"; 35 36 //===----------------------------------------------------------------------===// 37 // Utility functions 38 //===----------------------------------------------------------------------===// 39 40 /// Returns the string name of the `DescriptorSet` decoration. 41 static std::string descriptorSetName() { 42 return llvm::convertToSnakeFromCamelCase( 43 stringifyDecoration(spirv::Decoration::DescriptorSet)); 44 } 45 46 /// Returns the string name of the `Binding` decoration. 47 static std::string bindingName() { 48 return llvm::convertToSnakeFromCamelCase( 49 stringifyDecoration(spirv::Decoration::Binding)); 50 } 51 52 /// Calculates the index of the kernel's operand that is represented by the 53 /// given global variable with the `bind` attribute. We assume that the index of 54 /// each kernel's operand is mapped to (descriptorSet, binding) by the map: 55 /// i -> (0, i) 56 /// which is implemented under `LowerABIAttributesPass`. 57 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { 58 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 59 return binding.getInt(); 60 } 61 62 /// Copies the given number of bytes from src to dst pointers. 63 static void copy(Location loc, Value dst, Value src, Value size, 64 OpBuilder &builder) { 65 MLIRContext *context = builder.getContext(); 66 auto llvmI1Type = IntegerType::get(context, 1); 67 Value isVolatile = builder.create<LLVM::ConstantOp>( 68 loc, llvmI1Type, builder.getBoolAttr(false)); 69 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile); 70 } 71 72 /// Encodes the binding and descriptor set numbers into a new symbolic name. 73 /// The name is specified by 74 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} 75 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and 76 /// binding numbers. 77 static std::string 78 createGlobalVariableWithBindName(spirv::GlobalVariableOp op, 79 StringRef kernelModuleName) { 80 IntegerAttr descriptorSet = 81 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 82 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 83 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", 84 kernelModuleName.str(), op.sym_name().str(), 85 std::to_string(descriptorSet.getInt()), 86 std::to_string(binding.getInt())); 87 } 88 89 /// Returns true if the given global variable has both a descriptor set number 90 /// and a binding number. 91 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { 92 IntegerAttr descriptorSet = 93 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 94 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 95 return descriptorSet && binding; 96 } 97 98 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel 99 /// arguments from the given SPIR-V module. We assume that the module contains a 100 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind 101 /// attribute are kernel arguments. 102 static LogicalResult getKernelGlobalVariables( 103 spirv::ModuleOp module, 104 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) { 105 auto entryPoints = module.getOps<spirv::EntryPointOp>(); 106 if (!llvm::hasSingleElement(entryPoints)) { 107 return module.emitError( 108 "The module must contain exactly one entry point function"); 109 } 110 auto globalVariables = module.getOps<spirv::GlobalVariableOp>(); 111 for (auto globalOp : globalVariables) { 112 if (hasDescriptorSetAndBinding(globalOp)) 113 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; 114 } 115 return success(); 116 } 117 118 /// Encodes the SPIR-V module's symbolic name into the name of the entry point 119 /// function. 120 static LogicalResult encodeKernelName(spirv::ModuleOp module) { 121 StringRef spvModuleName = module.sym_name().getValue(); 122 // We already know that the module contains exactly one entry point function 123 // based on `getKernelGlobalVariables()` call. Update this function's name 124 // to: 125 // {spv_module_name}_{function_name} 126 auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin(); 127 StringRef funcName = entryPoint.fn(); 128 auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName); 129 std::string newFuncName = spvModuleName.str() + "_" + funcName.str(); 130 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) 131 return failure(); 132 SymbolTable::setSymbolName(funcOp, newFuncName); 133 return success(); 134 } 135 136 //===----------------------------------------------------------------------===// 137 // Conversion patterns 138 //===----------------------------------------------------------------------===// 139 140 namespace { 141 142 /// Structure to group information about the variables being copied. 143 struct CopyInfo { 144 Value dst; 145 Value src; 146 Value size; 147 }; 148 149 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we 150 /// copy the data to the global variable (emulating device side), call the 151 /// kernel as a normal void LLVM function, and copy the data back (emulating the 152 /// host side). 153 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { 154 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern; 155 156 LogicalResult 157 matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands, 158 ConversionPatternRewriter &rewriter) const override { 159 auto *op = launchOp.getOperation(); 160 MLIRContext *context = rewriter.getContext(); 161 auto module = launchOp->getParentOfType<ModuleOp>(); 162 163 // Get the SPIR-V module that represents the gpu kernel module. The module 164 // is named: 165 // __spv__{kernel_module_name} 166 // based on GPU to SPIR-V conversion. 167 StringRef kernelModuleName = launchOp.getKernelModuleName(); 168 std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); 169 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName); 170 if (!spvModule) { 171 return launchOp.emitOpError("SPIR-V kernel module '") 172 << spvModuleName << "' is not found"; 173 } 174 175 // Declare kernel function in the main module so that it later can be linked 176 // with its definition from the kernel module. We know that the kernel 177 // function would have no arguments and the data is passed via global 178 // variables. The name of the kernel will be 179 // {spv_module_name}_{kernel_function_name} 180 // to avoid symbolic name conflicts. 181 StringRef kernelFuncName = launchOp.getKernelName(); 182 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); 183 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName); 184 if (!kernelFunc) { 185 OpBuilder::InsertionGuard guard(rewriter); 186 rewriter.setInsertionPointToStart(module.getBody()); 187 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( 188 rewriter.getUnknownLoc(), newKernelFuncName, 189 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), 190 ArrayRef<Type>())); 191 rewriter.setInsertionPoint(launchOp); 192 } 193 194 // Get all global variables associated with the kernel operands. 195 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; 196 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) 197 return failure(); 198 199 // Traverse kernel operands that were converted to MemRefDescriptors. For 200 // each operand, create a global variable and copy data from operand to it. 201 Location loc = launchOp.getLoc(); 202 SmallVector<CopyInfo, 4> copyInfo; 203 auto numKernelOperands = launchOp.getNumKernelOperands(); 204 auto kernelOperands = operands.take_back(numKernelOperands); 205 for (auto operand : llvm::enumerate(kernelOperands)) { 206 // Check if the kernel's operand is a ranked memref. 207 auto memRefType = launchOp.getKernelOperand(operand.index()) 208 .getType() 209 .dyn_cast<MemRefType>(); 210 if (!memRefType) 211 return failure(); 212 213 // Calculate the size of the memref and get the pointer to the allocated 214 // buffer. 215 SmallVector<Value, 4> sizes; 216 SmallVector<Value, 4> strides; 217 Value sizeBytes; 218 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides, 219 sizeBytes); 220 MemRefDescriptor descriptor(operand.value()); 221 Value src = descriptor.allocatedPtr(rewriter, loc); 222 223 // Get the global variable in the SPIR-V module that is associated with 224 // the kernel operand. Construct its new name and create a corresponding 225 // LLVM dialect global variable. 226 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; 227 auto pointeeType = 228 spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); 229 auto dstGlobalType = typeConverter->convertType(pointeeType); 230 if (!dstGlobalType) 231 return failure(); 232 std::string name = 233 createGlobalVariableWithBindName(spirvGlobal, spvModuleName); 234 // Check if this variable has already been created. 235 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name); 236 if (!dstGlobal) { 237 OpBuilder::InsertionGuard guard(rewriter); 238 rewriter.setInsertionPointToStart(module.getBody()); 239 dstGlobal = rewriter.create<LLVM::GlobalOp>( 240 loc, dstGlobalType, 241 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), 242 /*alignment=*/0); 243 rewriter.setInsertionPoint(launchOp); 244 } 245 246 // Copy the data from src operand pointer to dst global variable. Save 247 // src, dst and size so that we can copy data back after emulating the 248 // kernel call. 249 Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal); 250 copy(loc, dst, src, sizeBytes, rewriter); 251 252 CopyInfo info; 253 info.dst = dst; 254 info.src = src; 255 info.size = sizeBytes; 256 copyInfo.push_back(info); 257 } 258 // Create a call to the kernel and copy the data back. 259 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, 260 ArrayRef<Value>()); 261 for (CopyInfo info : copyInfo) 262 copy(loc, info.src, info.dst, info.size, rewriter); 263 return success(); 264 } 265 }; 266 267 class LowerHostCodeToLLVM 268 : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> { 269 public: 270 void runOnOperation() override { 271 ModuleOp module = getOperation(); 272 273 // Erase the GPU module. 274 for (auto gpuModule : 275 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>())) 276 gpuModule.erase(); 277 278 // Specify options to lower Standard to LLVM and pull in the conversion 279 // patterns. 280 LowerToLLVMOptions options(module.getContext()); 281 options.emitCWrappers = true; 282 auto *context = module.getContext(); 283 RewritePatternSet patterns(context); 284 LLVMTypeConverter typeConverter(context, options); 285 populateStdToLLVMConversionPatterns(typeConverter, patterns); 286 patterns.add<GPULaunchLowering>(typeConverter); 287 288 // Pull in SPIR-V type conversion patterns to convert SPIR-V global 289 // variable's type to LLVM dialect type. 290 populateSPIRVToLLVMTypeConversion(typeConverter); 291 292 ConversionTarget target(*context); 293 target.addLegalDialect<LLVM::LLVMDialect>(); 294 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 295 signalPassFailure(); 296 297 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic 298 // conflicts. 299 for (auto spvModule : module.getOps<spirv::ModuleOp>()) 300 (void)encodeKernelName(spvModule); 301 } 302 }; 303 } // namespace 304 305 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 306 mlir::createLowerHostCodeToLLVMPass() { 307 return std::make_unique<LowerHostCodeToLLVM>(); 308 } 309