//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to convert `gpu.launch_func` op into a sequence
// of LLVM calls that emulate the host and device sides.
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

/// Prefix prepended to a GPU kernel module's name to form the symbol name of
/// the SPIR-V module that is looked up for it, i.e.
///   __spv__{kernel_module_name}
static constexpr const char kSPIRVModule[] = "__spv__";

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the string name of the `DescriptorSet` decoration.
43 static std::string descriptorSetName() { 44 return llvm::convertToSnakeFromCamelCase( 45 stringifyDecoration(spirv::Decoration::DescriptorSet)); 46 } 47 48 /// Returns the string name of the `Binding` decoration. 49 static std::string bindingName() { 50 return llvm::convertToSnakeFromCamelCase( 51 stringifyDecoration(spirv::Decoration::Binding)); 52 } 53 54 /// Calculates the index of the kernel's operand that is represented by the 55 /// given global variable with the `bind` attribute. We assume that the index of 56 /// each kernel's operand is mapped to (descriptorSet, binding) by the map: 57 /// i -> (0, i) 58 /// which is implemented under `LowerABIAttributesPass`. 59 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { 60 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 61 return binding.getInt(); 62 } 63 64 /// Copies the given number of bytes from src to dst pointers. 65 static void copy(Location loc, Value dst, Value src, Value size, 66 OpBuilder &builder) { 67 MLIRContext *context = builder.getContext(); 68 auto llvmI1Type = IntegerType::get(context, 1); 69 Value isVolatile = builder.create<LLVM::ConstantOp>( 70 loc, llvmI1Type, builder.getBoolAttr(false)); 71 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile); 72 } 73 74 /// Encodes the binding and descriptor set numbers into a new symbolic name. 75 /// The name is specified by 76 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} 77 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and 78 /// binding numbers. 
79 static std::string 80 createGlobalVariableWithBindName(spirv::GlobalVariableOp op, 81 StringRef kernelModuleName) { 82 IntegerAttr descriptorSet = 83 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 84 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 85 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", 86 kernelModuleName.str(), op.sym_name().str(), 87 std::to_string(descriptorSet.getInt()), 88 std::to_string(binding.getInt())); 89 } 90 91 /// Returns true if the given global variable has both a descriptor set number 92 /// and a binding number. 93 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { 94 IntegerAttr descriptorSet = 95 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 96 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 97 return descriptorSet && binding; 98 } 99 100 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel 101 /// arguments from the given SPIR-V module. We assume that the module contains a 102 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind 103 /// attribute are kernel arguments. 104 static LogicalResult getKernelGlobalVariables( 105 spirv::ModuleOp module, 106 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) { 107 auto entryPoints = module.getOps<spirv::EntryPointOp>(); 108 if (!llvm::hasSingleElement(entryPoints)) { 109 return module.emitError( 110 "The module must contain exactly one entry point function"); 111 } 112 auto globalVariables = module.getOps<spirv::GlobalVariableOp>(); 113 for (auto globalOp : globalVariables) { 114 if (hasDescriptorSetAndBinding(globalOp)) 115 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; 116 } 117 return success(); 118 } 119 120 /// Encodes the SPIR-V module's symbolic name into the name of the entry point 121 /// function. 
122 static LogicalResult encodeKernelName(spirv::ModuleOp module) { 123 StringRef spvModuleName = module.sym_name().getValue(); 124 // We already know that the module contains exactly one entry point function 125 // based on `getKernelGlobalVariables()` call. Update this function's name 126 // to: 127 // {spv_module_name}_{function_name} 128 auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin(); 129 StringRef funcName = entryPoint.fn(); 130 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr()); 131 StringAttr newFuncName = 132 StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); 133 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) 134 return failure(); 135 SymbolTable::setSymbolName(funcOp, newFuncName); 136 return success(); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // Conversion patterns 141 //===----------------------------------------------------------------------===// 142 143 namespace { 144 145 /// Structure to group information about the variables being copied. 146 struct CopyInfo { 147 Value dst; 148 Value src; 149 Value size; 150 }; 151 152 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we 153 /// copy the data to the global variable (emulating device side), call the 154 /// kernel as a normal void LLVM function, and copy the data back (emulating the 155 /// host side). 156 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { 157 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern; 158 159 LogicalResult 160 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 161 ConversionPatternRewriter &rewriter) const override { 162 auto *op = launchOp.getOperation(); 163 MLIRContext *context = rewriter.getContext(); 164 auto module = launchOp->getParentOfType<ModuleOp>(); 165 166 // Get the SPIR-V module that represents the gpu kernel module. 
The module 167 // is named: 168 // __spv__{kernel_module_name} 169 // based on GPU to SPIR-V conversion. 170 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue(); 171 std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); 172 auto spvModule = module.lookupSymbol<spirv::ModuleOp>( 173 StringAttr::get(context, spvModuleName)); 174 if (!spvModule) { 175 return launchOp.emitOpError("SPIR-V kernel module '") 176 << spvModuleName << "' is not found"; 177 } 178 179 // Declare kernel function in the main module so that it later can be linked 180 // with its definition from the kernel module. We know that the kernel 181 // function would have no arguments and the data is passed via global 182 // variables. The name of the kernel will be 183 // {spv_module_name}_{kernel_function_name} 184 // to avoid symbolic name conflicts. 185 StringRef kernelFuncName = launchOp.getKernelName().getValue(); 186 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); 187 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>( 188 StringAttr::get(context, newKernelFuncName)); 189 if (!kernelFunc) { 190 OpBuilder::InsertionGuard guard(rewriter); 191 rewriter.setInsertionPointToStart(module.getBody()); 192 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( 193 rewriter.getUnknownLoc(), newKernelFuncName, 194 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), 195 ArrayRef<Type>())); 196 rewriter.setInsertionPoint(launchOp); 197 } 198 199 // Get all global variables associated with the kernel operands. 200 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; 201 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) 202 return failure(); 203 204 // Traverse kernel operands that were converted to MemRefDescriptors. For 205 // each operand, create a global variable and copy data from operand to it. 
206 Location loc = launchOp.getLoc(); 207 SmallVector<CopyInfo, 4> copyInfo; 208 auto numKernelOperands = launchOp.getNumKernelOperands(); 209 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); 210 for (const auto &operand : llvm::enumerate(kernelOperands)) { 211 // Check if the kernel's operand is a ranked memref. 212 auto memRefType = launchOp.getKernelOperand(operand.index()) 213 .getType() 214 .dyn_cast<MemRefType>(); 215 if (!memRefType) 216 return failure(); 217 218 // Calculate the size of the memref and get the pointer to the allocated 219 // buffer. 220 SmallVector<Value, 4> sizes; 221 SmallVector<Value, 4> strides; 222 Value sizeBytes; 223 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides, 224 sizeBytes); 225 MemRefDescriptor descriptor(operand.value()); 226 Value src = descriptor.allocatedPtr(rewriter, loc); 227 228 // Get the global variable in the SPIR-V module that is associated with 229 // the kernel operand. Construct its new name and create a corresponding 230 // LLVM dialect global variable. 231 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; 232 auto pointeeType = 233 spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); 234 auto dstGlobalType = typeConverter->convertType(pointeeType); 235 if (!dstGlobalType) 236 return failure(); 237 std::string name = 238 createGlobalVariableWithBindName(spirvGlobal, spvModuleName); 239 // Check if this variable has already been created. 240 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name); 241 if (!dstGlobal) { 242 OpBuilder::InsertionGuard guard(rewriter); 243 rewriter.setInsertionPointToStart(module.getBody()); 244 dstGlobal = rewriter.create<LLVM::GlobalOp>( 245 loc, dstGlobalType, 246 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), 247 /*alignment=*/0); 248 rewriter.setInsertionPoint(launchOp); 249 } 250 251 // Copy the data from src operand pointer to dst global variable. 
Save 252 // src, dst and size so that we can copy data back after emulating the 253 // kernel call. 254 Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal); 255 copy(loc, dst, src, sizeBytes, rewriter); 256 257 CopyInfo info; 258 info.dst = dst; 259 info.src = src; 260 info.size = sizeBytes; 261 copyInfo.push_back(info); 262 } 263 // Create a call to the kernel and copy the data back. 264 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, 265 ArrayRef<Value>()); 266 for (CopyInfo info : copyInfo) 267 copy(loc, info.src, info.dst, info.size, rewriter); 268 return success(); 269 } 270 }; 271 272 class LowerHostCodeToLLVM 273 : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> { 274 public: 275 void runOnOperation() override { 276 ModuleOp module = getOperation(); 277 278 // Erase the GPU module. 279 for (auto gpuModule : 280 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>())) 281 gpuModule.erase(); 282 283 // Specify options to lower to LLVM and pull in the conversion patterns. 284 LowerToLLVMOptions options(module.getContext()); 285 options.emitCWrappers = true; 286 auto *context = module.getContext(); 287 RewritePatternSet patterns(context); 288 LLVMTypeConverter typeConverter(context, options); 289 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 290 patterns); 291 populateMemRefToLLVMConversionPatterns(typeConverter, patterns); 292 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 293 patterns.add<GPULaunchLowering>(typeConverter); 294 295 // Pull in SPIR-V type conversion patterns to convert SPIR-V global 296 // variable's type to LLVM dialect type. 297 populateSPIRVToLLVMTypeConversion(typeConverter); 298 299 ConversionTarget target(*context); 300 target.addLegalDialect<LLVM::LLVMDialect>(); 301 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 302 signalPassFailure(); 303 304 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic 305 // conflicts. 
306 for (auto spvModule : module.getOps<spirv::ModuleOp>()) 307 (void)encodeKernelName(spvModule); 308 } 309 }; 310 } // namespace 311 312 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 313 mlir::createLowerHostCodeToLLVMPass() { 314 return std::make_unique<LowerHostCodeToLLVM>(); 315 } 316