1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to lower attributes that specify the shader ABI 10 // for the functions in the generated SPIR-V module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/SPIRV/LayoutUtils.h" 16 #include "mlir/Dialect/SPIRV/Passes.h" 17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 19 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/SetVector.h" 22 23 using namespace mlir; 24 25 /// Creates a global variable for an argument based on the ABI info. 26 static spirv::GlobalVariableOp 27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, 28 unsigned argIndex, 29 spirv::InterfaceVarABIAttr abiInfo) { 30 auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>(); 31 if (!spirvModule) 32 return nullptr; 33 34 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 35 builder.setInsertionPoint(funcOp.getOperation()); 36 std::string varName = 37 funcOp.getName().str() + "_arg_" + std::to_string(argIndex); 38 39 // Get the type of variable. If this is a scalar/vector type and has an ABI 40 // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not 41 // it must already be a !spv.ptr<!spv.struct<...>>. 42 auto varType = funcOp.getType().getInput(argIndex); 43 if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) { 44 auto storageClass = abiInfo.getStorageClass(); 45 if (!storageClass) 46 return nullptr; 47 varType = 48 spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); 49 } 50 auto varPtrType = varType.cast<spirv::PointerType>(); 51 auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>(); 52 53 // Set the offset information. 54 varPointeeType = 55 VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>(); 56 varType = 57 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); 58 59 return builder.create<spirv::GlobalVariableOp>( 60 funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), 61 abiInfo.getBinding()); 62 } 63 64 /// Gets the global variables that need to be specified as interface variable 65 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. 66 static LogicalResult 67 getInterfaceVariables(spirv::FuncOp funcOp, 68 SmallVectorImpl<Attribute> &interfaceVars) { 69 auto module = funcOp.getParentOfType<spirv::ModuleOp>(); 70 if (!module) { 71 return failure(); 72 } 73 llvm::SetVector<Operation *> interfaceVarSet; 74 75 // TODO(ravishankarm) : This should in reality traverse the entry function 76 // call graph and collect all the interfaces. For now, just traverse the 77 // instructions in this function. 78 funcOp.walk([&](spirv::AddressOfOp addressOfOp) { 79 auto var = 80 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); 81 // TODO(antiagainst): Per SPIR-V spec: "Before version 1.4, the interface’s 82 // storage classes are limited to the Input and Output storage classes. 83 // Starting with version 1.4, the interface’s storage classes are all 84 // storage classes used in declaring all global variables referenced by the 85 // entry point’s call tree." We should consider the target environment here. 86 switch (var.type().cast<spirv::PointerType>().getStorageClass()) { 87 case spirv::StorageClass::Input: 88 case spirv::StorageClass::Output: 89 interfaceVarSet.insert(var.getOperation()); 90 break; 91 default: 92 break; 93 } 94 }); 95 for (auto &var : interfaceVarSet) { 96 interfaceVars.push_back(SymbolRefAttr::get( 97 cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext())); 98 } 99 return success(); 100 } 101 102 /// Lowers the entry point attribute. 103 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, 104 OpBuilder &builder) { 105 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 106 auto entryPointAttr = 107 funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName); 108 if (!entryPointAttr) { 109 return failure(); 110 } 111 112 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 113 auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>(); 114 builder.setInsertionPoint(spirvModule.body().front().getTerminator()); 115 116 // Adds the spv.EntryPointOp after collecting all the interface variables 117 // needed. 118 SmallVector<Attribute, 1> interfaceVars; 119 if (failed(getInterfaceVariables(funcOp, interfaceVars))) { 120 return failure(); 121 } 122 builder.create<spirv::EntryPointOp>( 123 funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars); 124 // Specifies the spv.ExecutionModeOp. 125 auto localSizeAttr = entryPointAttr.local_size(); 126 SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>()); 127 builder.create<spirv::ExecutionModeOp>( 128 funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize); 129 funcOp.removeAttr(entryPointAttrName); 130 return success(); 131 } 132 133 namespace { 134 /// A pattern to convert function signature according to interface variable ABI 135 /// attributes. 136 /// 137 /// Specifically, this pattern creates global variables according to interface 138 /// variable ABI attributes attached to function arguments and converts all 139 /// function argument uses to those global variables. This is necessary because 140 /// Vulkan requires all shader entry points to be of void(void) type. 141 class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> { 142 public: 143 using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering; 144 LogicalResult 145 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands, 146 ConversionPatternRewriter &rewriter) const override; 147 }; 148 149 /// Pass to implement the ABI information specified as attributes. 150 class LowerABIAttributesPass final 151 : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> { 152 void runOnOperation() override; 153 }; 154 } // namespace 155 156 LogicalResult ProcessInterfaceVarABI::matchAndRewrite( 157 spirv::FuncOp funcOp, ArrayRef<Value> operands, 158 ConversionPatternRewriter &rewriter) const { 159 if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>( 160 spirv::getEntryPointABIAttrName())) { 161 // TODO(ravishankarm) : Non-entry point functions are not handled. 162 return failure(); 163 } 164 TypeConverter::SignatureConversion signatureConverter( 165 funcOp.getType().getNumInputs()); 166 167 auto attrName = spirv::getInterfaceVarABIAttrName(); 168 for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) { 169 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>( 170 argType.index(), attrName); 171 if (!abiInfo) { 172 // TODO(ravishankarm) : For non-entry point functions, it should be legal 173 // to pass around scalar/vector values and return a scalar/vector. For now 174 // non-entry point functions are not handled in this ABI lowering and will 175 // produce an error. 176 return failure(); 177 } 178 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument( 179 rewriter, funcOp, argType.index(), abiInfo); 180 if (!var) 181 return failure(); 182 183 OpBuilder::InsertionGuard funcInsertionGuard(rewriter); 184 rewriter.setInsertionPointToStart(&funcOp.front()); 185 // Insert spirv::AddressOf and spirv::AccessChain operations. 186 Value replacement = 187 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var); 188 // Check if the arg is a scalar or vector type. In that case, the value 189 // needs to be loaded into registers. 190 // TODO(ravishankarm) : This is loading value of the scalar into registers 191 // at the start of the function. It is probably better to do the load just 192 // before the use. There might be multiple loads and currently there is no 193 // easy way to replace all uses with a sequence of operations. 194 if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) { 195 auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext()); 196 auto zero = 197 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); 198 auto loadPtr = rewriter.create<spirv::AccessChainOp>( 199 funcOp.getLoc(), replacement, zero.constant()); 200 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr); 201 } 202 signatureConverter.remapInput(argType.index(), replacement); 203 } 204 205 // Creates a new function with the update signature. 206 rewriter.updateRootInPlace(funcOp, [&] { 207 funcOp.setType(rewriter.getFunctionType( 208 signatureConverter.getConvertedTypes(), llvm::None)); 209 rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); 210 }); 211 return success(); 212 } 213 214 void LowerABIAttributesPass::runOnOperation() { 215 // Uses the signature conversion methodology of the dialect conversion 216 // framework to implement the conversion. 217 spirv::ModuleOp module = getOperation(); 218 MLIRContext *context = &getContext(); 219 220 spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module)); 221 222 SPIRVTypeConverter typeConverter(targetEnv); 223 OwningRewritePatternList patterns; 224 patterns.insert<ProcessInterfaceVarABI>(context, typeConverter); 225 226 ConversionTarget target(*context); 227 // "Legal" function ops should have no interface variable ABI attributes. 228 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) { 229 StringRef attrName = spirv::getInterfaceVarABIAttrName(); 230 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) 231 if (op.getArgAttr(i, attrName)) 232 return false; 233 return true; 234 }); 235 // All other SPIR-V ops are legal. 236 target.markUnknownOpDynamicallyLegal([](Operation *op) { 237 return op->getDialect()->getNamespace() == 238 spirv::SPIRVDialect::getDialectNamespace(); 239 }); 240 if (failed( 241 applyPartialConversion(module, target, patterns, &typeConverter))) { 242 return signalPassFailure(); 243 } 244 245 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point 246 // attributes. 247 OpBuilder builder(context); 248 SmallVector<spirv::FuncOp, 1> entryPointFns; 249 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 250 module.walk([&](spirv::FuncOp funcOp) { 251 if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) { 252 entryPointFns.push_back(funcOp); 253 } 254 }); 255 for (auto fn : entryPointFns) { 256 if (failed(lowerEntryPointABIAttr(fn, builder))) { 257 return signalPassFailure(); 258 } 259 } 260 } 261 262 std::unique_ptr<OperationPass<spirv::ModuleOp>> 263 mlir::spirv::createLowerABIAttributesPass() { 264 return std::make_unique<LowerABIAttributesPass>(); 265 } 266