//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringSwitch.h"

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUModuleEndConversion final
    : public OpConversionPattern<gpu::ModuleEndOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(endOp);
    return success();
  }
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//

static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
  auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
  if (!dimAttr)
    return llvm::None;

  return llvm::StringSwitch<Optional<int32_t>>(dimAttr.getValue())
      .Case("x", 0)
      .Case("y", 1)
      .Case("z", 2)
      .Default(llvm::None);
}

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto index = getLaunchConfigIndex(op);
  if (!index)
    return failure();

  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  // SPIR-V invocation builtin variables are a vector of type <3xi32>.
  auto spirvBuiltin =
      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
      op, indexType, spirvBuiltin,
      rewriter.getI32ArrayAttr({index.getValue()}));
  return success();
}

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  auto spirvBuiltin =
      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
  rewriter.replaceOp(op, spirvBuiltin);
  return success();
}

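// Illustration of the LaunchConfigConversion output above (a sketch; builtin
// variable names, the index type, and SSA names depend on the target
// environment and are only illustrative): an op such as
//   %0 = "gpu.block_id"() {dimension = "x"} : () -> index
// is expected to be rewritten into roughly
//   %addr = spv.mlir.addressof @__builtin_var_WorkgroupId__
//           : !spv.ptr<vector<3xi32>, Input>
//   %vec  = spv.Load "Input" %addr : vector<3xi32>
//   %0    = spv.CompositeExtract %vec[0 : i32] : vector<3xi32>
// where the builtin variable and load are produced by
// spirv::getBuiltinVariableValue and the extract index comes from the
// "dimension" attribute.
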
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto index = getLaunchConfigIndex(op);
  if (!index)
    return failure();

  auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  auto val = workGroupSizeAttr.getValues<int32_t>()[index.getValue()];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (auto argType : enumerate(funcOp.getType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               llvm::None));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == function_like_impl::getTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}

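// Illustration of what lowerAsEntryFunction produces (a sketch; the kernel
// name, argument types, and local_size values are illustrative):
//   spv.func @my_kernel(%arg0: !spv.ptr<...>) "None" attributes {
//     spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
//   } { ... spv.Return }
// with each argument carrying a spv.interface_var_abi attribute. Both
// attributes are later materialized into interface variables and an
// spv.EntryPoint by the LowerABIAttributesPass.
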
/// Populates `argABI` with spv.interface_var_abi attributes for lowering
/// gpu.func to spv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
static LogicalResult
getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(funcOp);
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in
    // a struct. The struct is held in a storage buffer.
    Optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark("match failure: could not select memory model "
                               "based on 'spv.target_env'");

  // Add a prefix to the module name to avoid a symbol conflict.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());
  rewriter.eraseOp(moduleOp);
  return success();
}

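// Illustration of the GPUModuleConversion above (a sketch): a module such as
//   gpu.module @kernels { ... }
// is expected to become roughly
//   spv.module @__spv__kernels Logical GLSL450 { ... }
// where the addressing and memory models are taken from the 'spv.target_env'
// attribute (Logical/GLSL450 is shown only as an example) and the module body
// is converted by the other patterns in this file.
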
//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
      GPUReturnOpConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion>(typeConverter, patterns.getContext());
}
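
// Usage sketch (hedged; the actual pass wiring lives in the GPUToSPIRV pass,
// and `module` / `kernelModules` below are illustrative names): a pass driving
// these patterns typically looks up the target environment, builds a type
// converter and conversion target from it, and runs a full conversion over the
// collected gpu.module ops:
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(module.getContext());
//   populateGPUToSPIRVPatterns(typeConverter, patterns);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   if (failed(applyFullConversion(kernelModules, *target,
//                                  std::move(patterns))))
//     signalPassFailure();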