1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements passes to convert `gpu.launch_func` op into a sequence 10 // of LLVM calls that emulate the host and device sides. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" 16 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" 17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/SymbolTable.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char kSPIRVModule[] = "__spv__"; 32 33 //===----------------------------------------------------------------------===// 34 // Utility functions 35 //===----------------------------------------------------------------------===// 36 37 /// Returns the string name of the `DescriptorSet` decoration. 38 static std::string descriptorSetName() { 39 return llvm::convertToSnakeFromCamelCase( 40 stringifyDecoration(spirv::Decoration::DescriptorSet)); 41 } 42 43 /// Returns the string name of the `Binding` decoration. 44 static std::string bindingName() { 45 return llvm::convertToSnakeFromCamelCase( 46 stringifyDecoration(spirv::Decoration::Binding)); 47 } 48 49 /// Calculates the index of the kernel's operand that is represented by the 50 /// given global variable with the `bind` attribute. We assume that the index of 51 /// each kernel's operand is mapped to (descriptorSet, binding) by the map: 52 /// i -> (0, i) 53 /// which is implemented under `LowerABIAttributesPass`. 54 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { 55 IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName()); 56 return binding.getInt(); 57 } 58 59 /// Copies the given number of bytes from src to dst pointers. 60 static void copy(Location loc, Value dst, Value src, Value size, 61 OpBuilder &builder) { 62 MLIRContext *context = builder.getContext(); 63 auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context); 64 Value isVolatile = builder.create<LLVM::ConstantOp>( 65 loc, llvmI1Type, builder.getBoolAttr(false)); 66 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile); 67 } 68 69 /// Encodes the binding and descriptor set numbers into a new symbolic name. 70 /// The name is specified by 71 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} 72 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and 73 /// binding numbers. 74 static std::string 75 createGlobalVariableWithBindName(spirv::GlobalVariableOp op, 76 StringRef kernelModuleName) { 77 IntegerAttr descriptorSet = 78 op.getAttrOfType<IntegerAttr>(descriptorSetName()); 79 IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName()); 80 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", 81 kernelModuleName.str(), op.sym_name().str(), 82 std::to_string(descriptorSet.getInt()), 83 std::to_string(binding.getInt())); 84 } 85 86 /// Returns true if the given global variable has both a descriptor set number 87 /// and a binding number. 88 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { 89 IntegerAttr descriptorSet = 90 op.getAttrOfType<IntegerAttr>(descriptorSetName()); 91 IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName()); 92 return descriptorSet && binding; 93 } 94 95 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel 96 /// arguments from the given SPIR-V module. We assume that the module contains a 97 /// single entry point function. Hence, all `spv.globalVariable`s with a bind 98 /// attribute are kernel arguments. 99 static LogicalResult getKernelGlobalVariables( 100 spirv::ModuleOp module, 101 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) { 102 auto entryPoints = module.getOps<spirv::EntryPointOp>(); 103 if (!llvm::hasSingleElement(entryPoints)) { 104 return module.emitError( 105 "The module must contain exactly one entry point function"); 106 } 107 auto globalVariables = module.getOps<spirv::GlobalVariableOp>(); 108 for (auto globalOp : globalVariables) { 109 if (hasDescriptorSetAndBinding(globalOp)) 110 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; 111 } 112 return success(); 113 } 114 115 /// Encodes the SPIR-V module's symbolic name into the name of the entry point 116 /// function. 117 static LogicalResult encodeKernelName(spirv::ModuleOp module) { 118 StringRef spvModuleName = module.sym_name().getValue(); 119 // We already know that the module contains exactly one entry point function 120 // based on `getKernelGlobalVariables()` call. Update this function's name 121 // to: 122 // {spv_module_name}_{function_name} 123 auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin(); 124 StringRef funcName = entryPoint.fn(); 125 auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName); 126 std::string newFuncName = spvModuleName.str() + "_" + funcName.str(); 127 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) 128 return failure(); 129 SymbolTable::setSymbolName(funcOp, newFuncName); 130 return success(); 131 } 132 133 //===----------------------------------------------------------------------===// 134 // Conversion patterns 135 //===----------------------------------------------------------------------===// 136 137 namespace { 138 139 /// Structure to group information about the variables being copied. 140 struct CopyInfo { 141 Value dst; 142 Value src; 143 Value size; 144 }; 145 146 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we 147 /// copy the data to the global variable (emulating device side), call the 148 /// kernel as a normal void LLVM function, and copy the data back (emulating the 149 /// host side). 150 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { 151 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern; 152 153 LogicalResult 154 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 155 ConversionPatternRewriter &rewriter) const override { 156 gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op); 157 MLIRContext *context = rewriter.getContext(); 158 auto module = launchOp.getParentOfType<ModuleOp>(); 159 160 // Get the SPIR-V module that represents the gpu kernel module. The module 161 // is named: 162 // __spv__{kernel_module_name} 163 // based on GPU to SPIR-V conversion. 164 StringRef kernelModuleName = launchOp.getKernelModuleName(); 165 std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); 166 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName); 167 if (!spvModule) { 168 return launchOp.emitOpError("SPIR-V kernel module '") 169 << spvModuleName << "' is not found"; 170 } 171 172 // Declare kernel function in the main module so that it later can be linked 173 // with its definition from the kernel module. We know that the kernel 174 // function would have no arguments and the data is passed via global 175 // variables. The name of the kernel will be 176 // {spv_module_name}_{kernel_function_name} 177 // to avoid symbolic name conflicts. 178 StringRef kernelFuncName = launchOp.getKernelName(); 179 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); 180 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName); 181 if (!kernelFunc) { 182 OpBuilder::InsertionGuard guard(rewriter); 183 rewriter.setInsertionPointToStart(module.getBody()); 184 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( 185 rewriter.getUnknownLoc(), newKernelFuncName, 186 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context), 187 ArrayRef<LLVM::LLVMType>(), 188 /*isVarArg=*/false)); 189 rewriter.setInsertionPoint(launchOp); 190 } 191 192 // Get all global variables associated with the kernel operands. 193 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; 194 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) 195 return failure(); 196 197 // Traverse kernel operands that were converted to MemRefDescriptors. For 198 // each operand, create a global variable and copy data from operand to it. 199 Location loc = launchOp.getLoc(); 200 SmallVector<CopyInfo, 4> copyInfo; 201 auto numKernelOperands = launchOp.getNumKernelOperands(); 202 auto kernelOperands = operands.take_back(numKernelOperands); 203 for (auto operand : llvm::enumerate(kernelOperands)) { 204 // Check if the kernel's opernad is a ranked memref. 205 auto memRefType = launchOp.getKernelOperand(operand.index()) 206 .getType() 207 .dyn_cast<MemRefType>(); 208 if (!memRefType) 209 return failure(); 210 211 // Calculate the size of the memref and get the pointer to the allocated 212 // buffer. 213 SmallVector<Value, 4> sizes; 214 SmallVector<Value, 4> strides; 215 Value sizeBytes; 216 getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter, 217 sizes, strides, sizeBytes); 218 MemRefDescriptor descriptor(operand.value()); 219 Value src = descriptor.allocatedPtr(rewriter, loc); 220 221 // Get the global variable in the SPIR-V module that is associated with 222 // the kernel operand. Construct its new name and create a corresponding 223 // LLVM dialect global variable. 224 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; 225 auto pointeeType = 226 spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); 227 auto dstGlobalType = typeConverter.convertType(pointeeType); 228 if (!dstGlobalType) 229 return failure(); 230 std::string name = 231 createGlobalVariableWithBindName(spirvGlobal, spvModuleName); 232 // Check if this variable has already been created. 233 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name); 234 if (!dstGlobal) { 235 OpBuilder::InsertionGuard guard(rewriter); 236 rewriter.setInsertionPointToStart(module.getBody()); 237 dstGlobal = rewriter.create<LLVM::GlobalOp>( 238 loc, dstGlobalType.cast<LLVM::LLVMType>(), 239 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute()); 240 rewriter.setInsertionPoint(launchOp); 241 } 242 243 // Copy the data from src operand pointer to dst global variable. Save 244 // src, dst and size so that we can copy data back after emulating the 245 // kernel call. 246 Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal); 247 copy(loc, dst, src, sizeBytes, rewriter); 248 249 CopyInfo info; 250 info.dst = dst; 251 info.src = src; 252 info.size = sizeBytes; 253 copyInfo.push_back(info); 254 } 255 // Create a call to the kernel and copy the data back. 256 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, 257 ArrayRef<Value>()); 258 for (CopyInfo info : copyInfo) 259 copy(loc, info.src, info.dst, info.size, rewriter); 260 return success(); 261 } 262 }; 263 264 class LowerHostCodeToLLVM 265 : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> { 266 public: 267 void runOnOperation() override { 268 ModuleOp module = getOperation(); 269 270 // Erase the GPU module. 271 for (auto gpuModule : 272 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>())) 273 gpuModule.erase(); 274 275 // Specify options to lower Standard to LLVM and pull in the conversion 276 // patterns. 277 LowerToLLVMOptions options = { 278 /*useBarePtrCallConv=*/false, 279 /*emitCWrappers=*/true, 280 /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout}; 281 auto *context = module.getContext(); 282 OwningRewritePatternList patterns; 283 LLVMTypeConverter typeConverter(context, options); 284 populateStdToLLVMConversionPatterns(typeConverter, patterns); 285 patterns.insert<GPULaunchLowering>(typeConverter); 286 287 // Pull in SPIR-V type conversion patterns to convert SPIR-V global 288 // variable's type to LLVM dialect type. 289 populateSPIRVToLLVMTypeConversion(typeConverter); 290 291 ConversionTarget target(*context); 292 target.addLegalDialect<LLVM::LLVMDialect>(); 293 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 294 signalPassFailure(); 295 296 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic 297 // conflicts. 298 for (auto spvModule : module.getOps<spirv::ModuleOp>()) 299 encodeKernelName(spvModule); 300 } 301 }; 302 } // namespace 303 304 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 305 mlir::createLowerHostCodeToLLVMPass() { 306 return std::make_unique<LowerHostCodeToLLVM>(); 307 } 308