1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to lower attributes that specify the shader ABI 10 // for the functions in the generated SPIR-V module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/LayoutUtils.h" 15 #include "mlir/Dialect/SPIRV/Passes.h" 16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 18 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 19 #include "mlir/Dialect/StandardOps/Ops.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/SetVector.h" 22 23 using namespace mlir; 24 25 /// Checks if the `type` is a scalar or vector type. It is assumed that they are 26 /// valid for SPIR-V dialect already. 27 static bool isScalarOrVectorType(Type type) { 28 return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>(); 29 } 30 31 /// Creates a global variable for an argument based on the ABI info. 32 static spirv::GlobalVariableOp 33 createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum, 34 spirv::InterfaceVarABIAttr abiInfo) { 35 auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>(); 36 if (!spirvModule) { 37 return nullptr; 38 } 39 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 40 builder.setInsertionPoint(funcOp.getOperation()); 41 std::string varName = 42 funcOp.getName().str() + "_arg_" + std::to_string(argNum); 43 44 // Get the type of variable. If this is a scalar/vector type and has an ABI 45 // info create a variable of type !spv.ptr<!spv.struct<elementTYpe>>. If not 46 // it must already be a !spv.ptr<!spv.struct<...>>. 47 auto varType = funcOp.getType().getInput(argNum); 48 auto storageClass = 49 static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt()); 50 if (isScalarOrVectorType(varType)) { 51 varType = 52 spirv::PointerType::get(spirv::StructType::get(varType), storageClass); 53 } 54 auto varPtrType = varType.cast<spirv::PointerType>(); 55 auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>(); 56 57 // Set the offset information. 58 VulkanLayoutUtils::Size size = 0, alignment = 0; 59 varPointeeType = 60 VulkanLayoutUtils::decorateType(varPointeeType, size, alignment) 61 .cast<spirv::StructType>(); 62 varType = 63 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); 64 65 return builder.create<spirv::GlobalVariableOp>( 66 funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(), 67 abiInfo.binding().getInt()); 68 } 69 70 /// Gets the global variables that need to be specified as interface variable 71 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. 72 static LogicalResult 73 getInterfaceVariables(FuncOp funcOp, 74 SmallVectorImpl<Attribute> &interfaceVars) { 75 auto module = funcOp.getParentOfType<spirv::ModuleOp>(); 76 if (!module) { 77 return failure(); 78 } 79 llvm::SetVector<Operation *> interfaceVarSet; 80 81 // TODO(ravishankarm) : This should in reality traverse the entry function 82 // call graph and collect all the interfaces. For now, just traverse the 83 // instructions in this function. 84 funcOp.walk([&](spirv::AddressOfOp addressOfOp) { 85 auto var = 86 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); 87 if (var.type().cast<spirv::PointerType>().getStorageClass() != 88 spirv::StorageClass::StorageBuffer) { 89 interfaceVarSet.insert(var.getOperation()); 90 } 91 }); 92 for (auto &var : interfaceVarSet) { 93 interfaceVars.push_back(SymbolRefAttr::get( 94 cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext())); 95 } 96 return success(); 97 } 98 99 /// Lowers the entry point attribute. 100 static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) { 101 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 102 auto entryPointAttr = 103 funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName); 104 if (!entryPointAttr) { 105 return failure(); 106 } 107 108 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 109 auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>(); 110 builder.setInsertionPoint(spirvModule.body().front().getTerminator()); 111 112 // Adds the spv.EntryPointOp after collecting all the interface variables 113 // needed. 114 SmallVector<Attribute, 1> interfaceVars; 115 if (failed(getInterfaceVariables(funcOp, interfaceVars))) { 116 return failure(); 117 } 118 builder.create<spirv::EntryPointOp>( 119 funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars); 120 // Specifies the spv.ExecutionModeOp. 121 auto localSizeAttr = entryPointAttr.local_size(); 122 SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>()); 123 builder.create<spirv::ExecutionModeOp>( 124 funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize); 125 funcOp.removeAttr(entryPointAttrName); 126 return success(); 127 } 128 129 namespace { 130 /// Pattern rewriter for changing function signature to match the ABI specified 131 /// in attributes. 132 class FuncOpLowering final : public SPIRVOpLowering<FuncOp> { 133 public: 134 using SPIRVOpLowering<FuncOp>::SPIRVOpLowering; 135 PatternMatchResult 136 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 137 ConversionPatternRewriter &rewriter) const override; 138 }; 139 140 /// Pass to implement the ABI information specified as attributes. 141 class LowerABIAttributesPass final 142 : public OperationPass<LowerABIAttributesPass, spirv::ModuleOp> { 143 private: 144 void runOnOperation() override; 145 }; 146 } // namespace 147 148 PatternMatchResult 149 FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 150 ConversionPatternRewriter &rewriter) const { 151 if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>( 152 spirv::getEntryPointABIAttrName())) { 153 // TODO(ravishankarm) : Non-entry point functions are not handled. 154 return matchFailure(); 155 } 156 TypeConverter::SignatureConversion signatureConverter( 157 funcOp.getType().getNumInputs()); 158 159 auto attrName = spirv::getInterfaceVarABIAttrName(); 160 for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) { 161 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>( 162 argType.index(), attrName); 163 if (!abiInfo) { 164 // TODO(ravishankarm) : For non-entry point functions, it should be legal 165 // to pass around scalar/vector values and return a scalar/vector. For now 166 // non-entry point functions are not handled in this ABI lowering and will 167 // produce an error. 168 return matchFailure(); 169 } 170 auto var = 171 createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo); 172 if (!var) { 173 return matchFailure(); 174 } 175 176 OpBuilder::InsertionGuard funcInsertionGuard(rewriter); 177 rewriter.setInsertionPointToStart(&funcOp.front()); 178 // Insert spirv::AddressOf and spirv::AccessChain operations. 179 Value replacement = 180 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var); 181 // Check if the arg is a scalar or vector type. In that case, the value 182 // needs to be loaded into registers. 183 // TODO(ravishankarm) : This is loading value of the scalar into registers 184 // at the start of the function. It is probably better to do the load just 185 // before the use. There might be multiple loads and currently there is no 186 // easy way to replace all uses with a sequence of operations. 187 if (isScalarOrVectorType(argType.value())) { 188 auto indexType = 189 typeConverter.convertType(IndexType::get(funcOp.getContext())); 190 auto zero = 191 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); 192 auto loadPtr = rewriter.create<spirv::AccessChainOp>( 193 funcOp.getLoc(), replacement, zero.constant()); 194 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr, 195 /*memory_access=*/nullptr, 196 /*alignment=*/nullptr); 197 } 198 signatureConverter.remapInput(argType.index(), replacement); 199 } 200 201 // Creates a new function with the update signature. 202 rewriter.updateRootInPlace(funcOp, [&] { 203 funcOp.setType(rewriter.getFunctionType( 204 signatureConverter.getConvertedTypes(), llvm::None)); 205 rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); 206 }); 207 return matchSuccess(); 208 } 209 210 void LowerABIAttributesPass::runOnOperation() { 211 // Uses the signature conversion methodology of the dialect conversion 212 // framework to implement the conversion. 213 spirv::ModuleOp module = getOperation(); 214 MLIRContext *context = &getContext(); 215 216 SPIRVTypeConverter typeConverter; 217 OwningRewritePatternList patterns; 218 patterns.insert<FuncOpLowering>(context, typeConverter); 219 220 ConversionTarget target(*context); 221 target.addLegalDialect<spirv::SPIRVDialect>(); 222 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 223 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 224 return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) && 225 op.getNumResults() == 0 && op.getNumArguments() == 0; 226 }); 227 target.addLegalOp<ReturnOp>(); 228 if (failed( 229 applyPartialConversion(module, target, patterns, &typeConverter))) { 230 return signalPassFailure(); 231 } 232 233 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point 234 // attributes. 235 OpBuilder builder(context); 236 SmallVector<FuncOp, 1> entryPointFns; 237 module.walk([&](FuncOp funcOp) { 238 if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) { 239 entryPointFns.push_back(funcOp); 240 } 241 }); 242 for (auto fn : entryPointFns) { 243 if (failed(lowerEntryPointABIAttr(fn, builder))) { 244 return signalPassFailure(); 245 } 246 } 247 } 248 249 std::unique_ptr<OpPassBase<spirv::ModuleOp>> 250 mlir::spirv::createLowerABIAttributesPass() { 251 return std::make_unique<LowerABIAttributesPass>(); 252 } 253 254 static PassRegistration<LowerABIAttributesPass> 255 pass("spirv-lower-abi-attrs", "Lower SPIR-V ABI Attributes"); 256