//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringSwitch.h"

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because, in Vulkan, the workgroup size is exposed to
/// shaders via a constant with the WorkgroupSize decoration. So here we cannot
/// generate a builtin variable; instead the information in the
/// `spv.entry_point_abi` attribute on the surrounding FuncOp is used to replace
/// the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spv.module.
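/// The gpu.module's region is moved into a newly created spv.module whose
/// symbol name is prefixed with kSPIRVModule ("__spv__") to avoid clashing
/// with other symbols in the parent module.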
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUModuleEndConversion final
    : public OpConversionPattern<gpu::ModuleEndOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(endOp);
    return success();
  }
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  // SPIR-V invocation builtin variables are a vector of type <3xi32>.
  auto spirvBuiltin =
      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
      op, indexType, spirvBuiltin,
      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.dimension())}));
  return success();
}

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  auto spirvBuiltin =
      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
  rewriter.replaceOp(op, spirvBuiltin);
  return success();
}

LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  // Bail out if the surrounding function carries no workgroup size
  // information.
  if (!workGroupSizeAttr)
    return failure();
  auto val = workGroupSizeAttr
                 .getValues<int32_t>()[static_cast<int32_t>(op.dimension())];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

// Legalizes a GPU function as an entry SPIR-V function.
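// The resulting spv.func keeps the original name and non-signature attributes,
// uses the converted argument types, and carries the given entry-point and
// interface-variable ABI attributes; the gpu.func body is moved into it.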
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (const auto &argType : enumerate(funcOp.getType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               llvm::None));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}

/// Populates `argABI` with spv.interface_var_abi attributes for lowering
/// gpu.func to spv.func if no arguments already have the attribute set.
/// Returns failure if any argument already has the ABI attribute set.
static LogicalResult
getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(funcOp);
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in
    // a struct, and that struct to be held in a storage buffer.
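    // Non-scalar arguments (e.g. memrefs) are left without an explicit
    // storage class here.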
    Optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark("match failure: could not select memory model "
                               "based on 'spv.target_env'");

  // Prefix the module name to avoid symbol conflicts.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());
  rewriter.eraseOp(moduleOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
      GPUReturnOpConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion>(typeConverter, patterns.getContext());
}
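// Usage sketch (illustrative only, not part of this file's API): a conversion
// pass would typically pair these patterns with a SPIRVTypeConverter and a
// SPIRVConversionTarget derived from the module's 'spv.target_env', roughly as
// below. The name `module` stands for the op being converted and is a
// placeholder.
//
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(module.getContext());
//   populateGPUToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyFullConversion(module, *target, std::move(patterns))))
//     signalPassFailure();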