1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to lower attributes that specify the shader ABI 10 // for the functions in the generated SPIR-V module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/Passes.h" 18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/SetVector.h" 22 23 using namespace mlir; 24 25 /// Creates a global variable for an argument based on the ABI info. 26 static spirv::GlobalVariableOp 27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, 28 unsigned argIndex, 29 spirv::InterfaceVarABIAttr abiInfo) { 30 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>(); 31 if (!spirvModule) 32 return nullptr; 33 34 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 35 builder.setInsertionPoint(funcOp.getOperation()); 36 std::string varName = 37 funcOp.getName().str() + "_arg_" + std::to_string(argIndex); 38 39 // Get the type of variable. If this is a scalar/vector type and has an ABI 40 // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not 41 // it must already be a !spv.ptr<!spv.struct<...>>. 42 auto varType = funcOp.getFunctionType().getInput(argIndex); 43 if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) { 44 auto storageClass = abiInfo.getStorageClass(); 45 if (!storageClass) 46 return nullptr; 47 varType = 48 spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); 49 } 50 auto varPtrType = varType.cast<spirv::PointerType>(); 51 auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>(); 52 53 // Set the offset information. 54 varPointeeType = 55 VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>(); 56 57 if (!varPointeeType) 58 return nullptr; 59 60 varType = 61 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); 62 63 return builder.create<spirv::GlobalVariableOp>( 64 funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), 65 abiInfo.getBinding()); 66 } 67 68 /// Gets the global variables that need to be specified as interface variable 69 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. 70 static LogicalResult 71 getInterfaceVariables(spirv::FuncOp funcOp, 72 SmallVectorImpl<Attribute> &interfaceVars) { 73 auto module = funcOp->getParentOfType<spirv::ModuleOp>(); 74 if (!module) { 75 return failure(); 76 } 77 SetVector<Operation *> interfaceVarSet; 78 79 // TODO: This should in reality traverse the entry function 80 // call graph and collect all the interfaces. For now, just traverse the 81 // instructions in this function. 82 funcOp.walk([&](spirv::AddressOfOp addressOfOp) { 83 auto var = 84 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); 85 // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s 86 // storage classes are limited to the Input and Output storage classes. 87 // Starting with version 1.4, the interface’s storage classes are all 88 // storage classes used in declaring all global variables referenced by the 89 // entry point’s call tree." We should consider the target environment here. 90 switch (var.type().cast<spirv::PointerType>().getStorageClass()) { 91 case spirv::StorageClass::Input: 92 case spirv::StorageClass::Output: 93 interfaceVarSet.insert(var.getOperation()); 94 break; 95 default: 96 break; 97 } 98 }); 99 for (auto &var : interfaceVarSet) { 100 interfaceVars.push_back(SymbolRefAttr::get( 101 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name())); 102 } 103 return success(); 104 } 105 106 /// Lowers the entry point attribute. 107 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, 108 OpBuilder &builder) { 109 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 110 auto entryPointAttr = 111 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName); 112 if (!entryPointAttr) { 113 return failure(); 114 } 115 116 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 117 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>(); 118 builder.setInsertionPointToEnd(spirvModule.getBody()); 119 120 // Adds the spv.EntryPointOp after collecting all the interface variables 121 // needed. 122 SmallVector<Attribute, 1> interfaceVars; 123 if (failed(getInterfaceVariables(funcOp, interfaceVars))) { 124 return failure(); 125 } 126 127 spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(funcOp); 128 FailureOr<spirv::ExecutionModel> executionModel = 129 spirv::getExecutionModel(targetEnv); 130 if (failed(executionModel)) 131 return funcOp.emitRemark("lower entry point failure: could not select " 132 "execution model based on 'spv.target_env'"); 133 134 builder.create<spirv::EntryPointOp>( 135 funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars); 136 137 // Specifies the spv.ExecutionModeOp. 138 auto localSizeAttr = entryPointAttr.getLocal_size(); 139 if (localSizeAttr) { 140 auto values = localSizeAttr.getValues<int32_t>(); 141 SmallVector<int32_t, 3> localSize(values); 142 builder.create<spirv::ExecutionModeOp>( 143 funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize); 144 funcOp->removeAttr(entryPointAttrName); 145 } 146 return success(); 147 } 148 149 namespace { 150 /// A pattern to convert function signature according to interface variable ABI 151 /// attributes. 152 /// 153 /// Specifically, this pattern creates global variables according to interface 154 /// variable ABI attributes attached to function arguments and converts all 155 /// function argument uses to those global variables. This is necessary because 156 /// Vulkan requires all shader entry points to be of void(void) type. 157 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> { 158 public: 159 using OpConversionPattern<spirv::FuncOp>::OpConversionPattern; 160 161 LogicalResult 162 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const override; 164 }; 165 166 /// Pass to implement the ABI information specified as attributes. 167 class LowerABIAttributesPass final 168 : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> { 169 void runOnOperation() override; 170 }; 171 } // namespace 172 173 LogicalResult ProcessInterfaceVarABI::matchAndRewrite( 174 spirv::FuncOp funcOp, OpAdaptor adaptor, 175 ConversionPatternRewriter &rewriter) const { 176 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>( 177 spirv::getEntryPointABIAttrName())) { 178 // TODO: Non-entry point functions are not handled. 179 return failure(); 180 } 181 TypeConverter::SignatureConversion signatureConverter( 182 funcOp.getFunctionType().getNumInputs()); 183 184 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 185 auto indexType = typeConverter.getIndexType(); 186 187 auto attrName = spirv::getInterfaceVarABIAttrName(); 188 for (const auto &argType : 189 llvm::enumerate(funcOp.getFunctionType().getInputs())) { 190 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>( 191 argType.index(), attrName); 192 if (!abiInfo) { 193 // TODO: For non-entry point functions, it should be legal 194 // to pass around scalar/vector values and return a scalar/vector. For now 195 // non-entry point functions are not handled in this ABI lowering and will 196 // produce an error. 197 return failure(); 198 } 199 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument( 200 rewriter, funcOp, argType.index(), abiInfo); 201 if (!var) 202 return failure(); 203 204 OpBuilder::InsertionGuard funcInsertionGuard(rewriter); 205 rewriter.setInsertionPointToStart(&funcOp.front()); 206 // Insert spirv::AddressOf and spirv::AccessChain operations. 207 Value replacement = 208 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var); 209 // Check if the arg is a scalar or vector type. In that case, the value 210 // needs to be loaded into registers. 211 // TODO: This is loading value of the scalar into registers 212 // at the start of the function. It is probably better to do the load just 213 // before the use. There might be multiple loads and currently there is no 214 // easy way to replace all uses with a sequence of operations. 215 if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) { 216 auto zero = 217 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); 218 auto loadPtr = rewriter.create<spirv::AccessChainOp>( 219 funcOp.getLoc(), replacement, zero.constant()); 220 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr); 221 } 222 signatureConverter.remapInput(argType.index(), replacement); 223 } 224 if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(), 225 &signatureConverter))) 226 return failure(); 227 228 // Creates a new function with the update signature. 229 rewriter.updateRootInPlace(funcOp, [&] { 230 funcOp.setType(rewriter.getFunctionType( 231 signatureConverter.getConvertedTypes(), llvm::None)); 232 }); 233 return success(); 234 } 235 236 void LowerABIAttributesPass::runOnOperation() { 237 // Uses the signature conversion methodology of the dialect conversion 238 // framework to implement the conversion. 239 spirv::ModuleOp module = getOperation(); 240 MLIRContext *context = &getContext(); 241 242 spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module)); 243 244 SPIRVTypeConverter typeConverter(targetEnv); 245 246 // Insert a bitcast in the case of a pointer type change. 247 typeConverter.addSourceMaterialization([](OpBuilder &builder, 248 spirv::PointerType type, 249 ValueRange inputs, Location loc) { 250 if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>()) 251 return Value(); 252 return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult(); 253 }); 254 255 RewritePatternSet patterns(context); 256 patterns.add<ProcessInterfaceVarABI>(typeConverter, context); 257 258 ConversionTarget target(*context); 259 // "Legal" function ops should have no interface variable ABI attributes. 260 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) { 261 StringRef attrName = spirv::getInterfaceVarABIAttrName(); 262 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) 263 if (op.getArgAttr(i, attrName)) 264 return false; 265 return true; 266 }); 267 // All other SPIR-V ops are legal. 268 target.markUnknownOpDynamicallyLegal([](Operation *op) { 269 return op->getDialect()->getNamespace() == 270 spirv::SPIRVDialect::getDialectNamespace(); 271 }); 272 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 273 return signalPassFailure(); 274 275 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point 276 // attributes. 277 OpBuilder builder(context); 278 SmallVector<spirv::FuncOp, 1> entryPointFns; 279 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 280 module.walk([&](spirv::FuncOp funcOp) { 281 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) { 282 entryPointFns.push_back(funcOp); 283 } 284 }); 285 for (auto fn : entryPointFns) { 286 if (failed(lowerEntryPointABIAttr(fn, builder))) { 287 return signalPassFailure(); 288 } 289 } 290 } 291 292 std::unique_ptr<OperationPass<spirv::ModuleOp>> 293 mlir::spirv::createLowerABIAttributesPass() { 294 return std::make_unique<LowerABIAttributesPass>(); 295 } 296