1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements passes to convert `gpu.launch_func` op into a sequence 10 // of LLVM calls that emulate the host and device sides. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 16 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 20 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 21 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 22 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 26 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/SymbolTable.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 31 #include "llvm/ADT/DenseMap.h" 32 #include "llvm/ADT/StringExtras.h" 33 #include "llvm/Support/FormatVariadic.h" 34 35 using namespace mlir; 36 37 static constexpr const char kSPIRVModule[] = "__spv__"; 38 39 //===----------------------------------------------------------------------===// 40 // Utility functions 41 //===----------------------------------------------------------------------===// 42 43 /// Returns the string name of the `DescriptorSet` decoration. 44 static std::string descriptorSetName() { 45 return llvm::convertToSnakeFromCamelCase( 46 stringifyDecoration(spirv::Decoration::DescriptorSet)); 47 } 48 49 /// Returns the string name of the `Binding` decoration. 50 static std::string bindingName() { 51 return llvm::convertToSnakeFromCamelCase( 52 stringifyDecoration(spirv::Decoration::Binding)); 53 } 54 55 /// Calculates the index of the kernel's operand that is represented by the 56 /// given global variable with the `bind` attribute. We assume that the index of 57 /// each kernel's operand is mapped to (descriptorSet, binding) by the map: 58 /// i -> (0, i) 59 /// which is implemented under `LowerABIAttributesPass`. 60 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { 61 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 62 return binding.getInt(); 63 } 64 65 /// Copies the given number of bytes from src to dst pointers. 66 static void copy(Location loc, Value dst, Value src, Value size, 67 OpBuilder &builder) { 68 MLIRContext *context = builder.getContext(); 69 auto llvmI1Type = IntegerType::get(context, 1); 70 Value isVolatile = builder.create<LLVM::ConstantOp>( 71 loc, llvmI1Type, builder.getBoolAttr(false)); 72 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile); 73 } 74 75 /// Encodes the binding and descriptor set numbers into a new symbolic name. 76 /// The name is specified by 77 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} 78 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and 79 /// binding numbers. 80 static std::string 81 createGlobalVariableWithBindName(spirv::GlobalVariableOp op, 82 StringRef kernelModuleName) { 83 IntegerAttr descriptorSet = 84 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 85 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 86 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", 87 kernelModuleName.str(), op.sym_name().str(), 88 std::to_string(descriptorSet.getInt()), 89 std::to_string(binding.getInt())); 90 } 91 92 /// Returns true if the given global variable has both a descriptor set number 93 /// and a binding number. 94 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { 95 IntegerAttr descriptorSet = 96 op->getAttrOfType<IntegerAttr>(descriptorSetName()); 97 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); 98 return descriptorSet && binding; 99 } 100 101 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel 102 /// arguments from the given SPIR-V module. We assume that the module contains a 103 /// single entry point function. Hence, all `spv.GlobalVariable`s with a bind 104 /// attribute are kernel arguments. 105 static LogicalResult getKernelGlobalVariables( 106 spirv::ModuleOp module, 107 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) { 108 auto entryPoints = module.getOps<spirv::EntryPointOp>(); 109 if (!llvm::hasSingleElement(entryPoints)) { 110 return module.emitError( 111 "The module must contain exactly one entry point function"); 112 } 113 auto globalVariables = module.getOps<spirv::GlobalVariableOp>(); 114 for (auto globalOp : globalVariables) { 115 if (hasDescriptorSetAndBinding(globalOp)) 116 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; 117 } 118 return success(); 119 } 120 121 /// Encodes the SPIR-V module's symbolic name into the name of the entry point 122 /// function. 123 static LogicalResult encodeKernelName(spirv::ModuleOp module) { 124 StringRef spvModuleName = *module.sym_name(); 125 // We already know that the module contains exactly one entry point function 126 // based on `getKernelGlobalVariables()` call. Update this function's name 127 // to: 128 // {spv_module_name}_{function_name} 129 auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin(); 130 StringRef funcName = entryPoint.fn(); 131 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr()); 132 StringAttr newFuncName = 133 StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); 134 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) 135 return failure(); 136 SymbolTable::setSymbolName(funcOp, newFuncName); 137 return success(); 138 } 139 140 //===----------------------------------------------------------------------===// 141 // Conversion patterns 142 //===----------------------------------------------------------------------===// 143 144 namespace { 145 146 /// Structure to group information about the variables being copied. 147 struct CopyInfo { 148 Value dst; 149 Value src; 150 Value size; 151 }; 152 153 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we 154 /// copy the data to the global variable (emulating device side), call the 155 /// kernel as a normal void LLVM function, and copy the data back (emulating the 156 /// host side). 157 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { 158 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 162 ConversionPatternRewriter &rewriter) const override { 163 auto *op = launchOp.getOperation(); 164 MLIRContext *context = rewriter.getContext(); 165 auto module = launchOp->getParentOfType<ModuleOp>(); 166 167 // Get the SPIR-V module that represents the gpu kernel module. The module 168 // is named: 169 // __spv__{kernel_module_name} 170 // based on GPU to SPIR-V conversion. 171 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue(); 172 std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); 173 auto spvModule = module.lookupSymbol<spirv::ModuleOp>( 174 StringAttr::get(context, spvModuleName)); 175 if (!spvModule) { 176 return launchOp.emitOpError("SPIR-V kernel module '") 177 << spvModuleName << "' is not found"; 178 } 179 180 // Declare kernel function in the main module so that it later can be linked 181 // with its definition from the kernel module. We know that the kernel 182 // function would have no arguments and the data is passed via global 183 // variables. The name of the kernel will be 184 // {spv_module_name}_{kernel_function_name} 185 // to avoid symbolic name conflicts. 186 StringRef kernelFuncName = launchOp.getKernelName().getValue(); 187 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); 188 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>( 189 StringAttr::get(context, newKernelFuncName)); 190 if (!kernelFunc) { 191 OpBuilder::InsertionGuard guard(rewriter); 192 rewriter.setInsertionPointToStart(module.getBody()); 193 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( 194 rewriter.getUnknownLoc(), newKernelFuncName, 195 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), 196 ArrayRef<Type>())); 197 rewriter.setInsertionPoint(launchOp); 198 } 199 200 // Get all global variables associated with the kernel operands. 201 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; 202 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) 203 return failure(); 204 205 // Traverse kernel operands that were converted to MemRefDescriptors. For 206 // each operand, create a global variable and copy data from operand to it. 207 Location loc = launchOp.getLoc(); 208 SmallVector<CopyInfo, 4> copyInfo; 209 auto numKernelOperands = launchOp.getNumKernelOperands(); 210 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); 211 for (const auto &operand : llvm::enumerate(kernelOperands)) { 212 // Check if the kernel's operand is a ranked memref. 213 auto memRefType = launchOp.getKernelOperand(operand.index()) 214 .getType() 215 .dyn_cast<MemRefType>(); 216 if (!memRefType) 217 return failure(); 218 219 // Calculate the size of the memref and get the pointer to the allocated 220 // buffer. 221 SmallVector<Value, 4> sizes; 222 SmallVector<Value, 4> strides; 223 Value sizeBytes; 224 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides, 225 sizeBytes); 226 MemRefDescriptor descriptor(operand.value()); 227 Value src = descriptor.allocatedPtr(rewriter, loc); 228 229 // Get the global variable in the SPIR-V module that is associated with 230 // the kernel operand. Construct its new name and create a corresponding 231 // LLVM dialect global variable. 232 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; 233 auto pointeeType = 234 spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); 235 auto dstGlobalType = typeConverter->convertType(pointeeType); 236 if (!dstGlobalType) 237 return failure(); 238 std::string name = 239 createGlobalVariableWithBindName(spirvGlobal, spvModuleName); 240 // Check if this variable has already been created. 241 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name); 242 if (!dstGlobal) { 243 OpBuilder::InsertionGuard guard(rewriter); 244 rewriter.setInsertionPointToStart(module.getBody()); 245 dstGlobal = rewriter.create<LLVM::GlobalOp>( 246 loc, dstGlobalType, 247 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), 248 /*alignment=*/0); 249 rewriter.setInsertionPoint(launchOp); 250 } 251 252 // Copy the data from src operand pointer to dst global variable. Save 253 // src, dst and size so that we can copy data back after emulating the 254 // kernel call. 255 Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal); 256 copy(loc, dst, src, sizeBytes, rewriter); 257 258 CopyInfo info; 259 info.dst = dst; 260 info.src = src; 261 info.size = sizeBytes; 262 copyInfo.push_back(info); 263 } 264 // Create a call to the kernel and copy the data back. 265 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, 266 ArrayRef<Value>()); 267 for (CopyInfo info : copyInfo) 268 copy(loc, info.src, info.dst, info.size, rewriter); 269 return success(); 270 } 271 }; 272 273 class LowerHostCodeToLLVM 274 : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> { 275 public: 276 void runOnOperation() override { 277 ModuleOp module = getOperation(); 278 279 // Erase the GPU module. 280 for (auto gpuModule : 281 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>())) 282 gpuModule.erase(); 283 284 // Request C wrapper emission. 285 for (auto func : module.getOps<func::FuncOp>()) { 286 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), 287 UnitAttr::get(&getContext())); 288 } 289 290 // Specify options to lower to LLVM and pull in the conversion patterns. 291 LowerToLLVMOptions options(module.getContext()); 292 auto *context = module.getContext(); 293 RewritePatternSet patterns(context); 294 LLVMTypeConverter typeConverter(context, options); 295 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 296 patterns); 297 populateMemRefToLLVMConversionPatterns(typeConverter, patterns); 298 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 299 patterns.add<GPULaunchLowering>(typeConverter); 300 301 // Pull in SPIR-V type conversion patterns to convert SPIR-V global 302 // variable's type to LLVM dialect type. 303 populateSPIRVToLLVMTypeConversion(typeConverter); 304 305 ConversionTarget target(*context); 306 target.addLegalDialect<LLVM::LLVMDialect>(); 307 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 308 signalPassFailure(); 309 310 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic 311 // conflicts. 312 for (auto spvModule : module.getOps<spirv::ModuleOp>()) 313 (void)encodeKernelName(spvModule); 314 } 315 }; 316 } // namespace 317 318 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 319 mlir::createLowerHostCodeToLLVMPass() { 320 return std::make_unique<LowerHostCodeToLLVM>(); 321 } 322