//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
25 /// Creates a global variable for an argument based on the ABI info.
26 static spirv::GlobalVariableOp
createGlobalVarForEntryPointArgument(OpBuilder & builder,spirv::FuncOp funcOp,unsigned argIndex,spirv::InterfaceVarABIAttr abiInfo)27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
28                                      unsigned argIndex,
29                                      spirv::InterfaceVarABIAttr abiInfo) {
30   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
31   if (!spirvModule)
32     return nullptr;
33 
34   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
35   builder.setInsertionPoint(funcOp.getOperation());
36   std::string varName =
37       funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
38 
39   // Get the type of variable. If this is a scalar/vector type and has an ABI
40   // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
41   // it must already be a !spv.ptr<!spv.struct<...>>.
42   auto varType = funcOp.getFunctionType().getInput(argIndex);
43   if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
44     auto storageClass = abiInfo.getStorageClass();
45     if (!storageClass)
46       return nullptr;
47     varType =
48         spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
49   }
50   auto varPtrType = varType.cast<spirv::PointerType>();
51   auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
52 
53   // Set the offset information.
54   varPointeeType =
55       VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
56 
57   if (!varPointeeType)
58     return nullptr;
59 
60   varType =
61       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
62 
63   return builder.create<spirv::GlobalVariableOp>(
64       funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
65       abiInfo.getBinding());
66 }
67 
68 /// Gets the global variables that need to be specified as interface variable
69 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
70 static LogicalResult
getInterfaceVariables(spirv::FuncOp funcOp,SmallVectorImpl<Attribute> & interfaceVars)71 getInterfaceVariables(spirv::FuncOp funcOp,
72                       SmallVectorImpl<Attribute> &interfaceVars) {
73   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
74   if (!module) {
75     return failure();
76   }
77   SetVector<Operation *> interfaceVarSet;
78 
79   // TODO: This should in reality traverse the entry function
80   // call graph and collect all the interfaces. For now, just traverse the
81   // instructions in this function.
82   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
83     auto var =
84         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
85     // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
86     // storage classes are limited to the Input and Output storage classes.
87     // Starting with version 1.4, the interface’s storage classes are all
88     // storage classes used in declaring all global variables referenced by the
89     // entry point’s call tree." We should consider the target environment here.
90     switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
91     case spirv::StorageClass::Input:
92     case spirv::StorageClass::Output:
93       interfaceVarSet.insert(var.getOperation());
94       break;
95     default:
96       break;
97     }
98   });
99   for (auto &var : interfaceVarSet) {
100     interfaceVars.push_back(SymbolRefAttr::get(
101         funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
102   }
103   return success();
104 }
105 
106 /// Lowers the entry point attribute.
lowerEntryPointABIAttr(spirv::FuncOp funcOp,OpBuilder & builder)107 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
108                                             OpBuilder &builder) {
109   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
110   auto entryPointAttr =
111       funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
112   if (!entryPointAttr) {
113     return failure();
114   }
115 
116   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
117   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
118   builder.setInsertionPointToEnd(spirvModule.getBody());
119 
120   // Adds the spv.EntryPointOp after collecting all the interface variables
121   // needed.
122   SmallVector<Attribute, 1> interfaceVars;
123   if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
124     return failure();
125   }
126 
127   spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(funcOp);
128   FailureOr<spirv::ExecutionModel> executionModel =
129       spirv::getExecutionModel(targetEnv);
130   if (failed(executionModel))
131     return funcOp.emitRemark("lower entry point failure: could not select "
132                              "execution model based on 'spv.target_env'");
133 
134   builder.create<spirv::EntryPointOp>(funcOp.getLoc(), executionModel.value(),
135                                       funcOp, interfaceVars);
136 
137   // Specifies the spv.ExecutionModeOp.
138   auto localSizeAttr = entryPointAttr.getLocalSize();
139   if (localSizeAttr) {
140     auto values = localSizeAttr.getValues<int32_t>();
141     SmallVector<int32_t, 3> localSize(values);
142     builder.create<spirv::ExecutionModeOp>(
143         funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
144     funcOp->removeAttr(entryPointAttrName);
145   }
146   return success();
147 }
148 
149 namespace {
150 /// A pattern to convert function signature according to interface variable ABI
151 /// attributes.
152 ///
153 /// Specifically, this pattern creates global variables according to interface
154 /// variable ABI attributes attached to function arguments and converts all
155 /// function argument uses to those global variables. This is necessary because
156 /// Vulkan requires all shader entry points to be of void(void) type.
157 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
158 public:
159   using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
160 
161   LogicalResult
162   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
163                   ConversionPatternRewriter &rewriter) const override;
164 };
165 
166 /// Pass to implement the ABI information specified as attributes.
167 class LowerABIAttributesPass final
168     : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
169   void runOnOperation() override;
170 };
171 } // namespace
172 
matchAndRewrite(spirv::FuncOp funcOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const173 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
174     spirv::FuncOp funcOp, OpAdaptor adaptor,
175     ConversionPatternRewriter &rewriter) const {
176   if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
177           spirv::getEntryPointABIAttrName())) {
178     // TODO: Non-entry point functions are not handled.
179     return failure();
180   }
181   TypeConverter::SignatureConversion signatureConverter(
182       funcOp.getFunctionType().getNumInputs());
183 
184   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
185   auto indexType = typeConverter.getIndexType();
186 
187   auto attrName = spirv::getInterfaceVarABIAttrName();
188   for (const auto &argType :
189        llvm::enumerate(funcOp.getFunctionType().getInputs())) {
190     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
191         argType.index(), attrName);
192     if (!abiInfo) {
193       // TODO: For non-entry point functions, it should be legal
194       // to pass around scalar/vector values and return a scalar/vector. For now
195       // non-entry point functions are not handled in this ABI lowering and will
196       // produce an error.
197       return failure();
198     }
199     spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
200         rewriter, funcOp, argType.index(), abiInfo);
201     if (!var)
202       return failure();
203 
204     OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
205     rewriter.setInsertionPointToStart(&funcOp.front());
206     // Insert spirv::AddressOf and spirv::AccessChain operations.
207     Value replacement =
208         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
209     // Check if the arg is a scalar or vector type. In that case, the value
210     // needs to be loaded into registers.
211     // TODO: This is loading value of the scalar into registers
212     // at the start of the function. It is probably better to do the load just
213     // before the use. There might be multiple loads and currently there is no
214     // easy way to replace all uses with a sequence of operations.
215     if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
216       auto zero =
217           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
218       auto loadPtr = rewriter.create<spirv::AccessChainOp>(
219           funcOp.getLoc(), replacement, zero.constant());
220       replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
221     }
222     signatureConverter.remapInput(argType.index(), replacement);
223   }
224   if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
225                                          &signatureConverter)))
226     return failure();
227 
228   // Creates a new function with the update signature.
229   rewriter.updateRootInPlace(funcOp, [&] {
230     funcOp.setType(rewriter.getFunctionType(
231         signatureConverter.getConvertedTypes(), llvm::None));
232   });
233   return success();
234 }
235 
runOnOperation()236 void LowerABIAttributesPass::runOnOperation() {
237   // Uses the signature conversion methodology of the dialect conversion
238   // framework to implement the conversion.
239   spirv::ModuleOp module = getOperation();
240   MLIRContext *context = &getContext();
241 
242   spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
243 
244   SPIRVTypeConverter typeConverter(targetEnv);
245 
246   // Insert a bitcast in the case of a pointer type change.
247   typeConverter.addSourceMaterialization([](OpBuilder &builder,
248                                             spirv::PointerType type,
249                                             ValueRange inputs, Location loc) {
250     if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
251       return Value();
252     return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
253   });
254 
255   RewritePatternSet patterns(context);
256   patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
257 
258   ConversionTarget target(*context);
259   // "Legal" function ops should have no interface variable ABI attributes.
260   target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
261     StringRef attrName = spirv::getInterfaceVarABIAttrName();
262     for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
263       if (op.getArgAttr(i, attrName))
264         return false;
265     return true;
266   });
267   // All other SPIR-V ops are legal.
268   target.markUnknownOpDynamicallyLegal([](Operation *op) {
269     return op->getDialect()->getNamespace() ==
270            spirv::SPIRVDialect::getDialectNamespace();
271   });
272   if (failed(applyPartialConversion(module, target, std::move(patterns))))
273     return signalPassFailure();
274 
275   // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
276   // attributes.
277   OpBuilder builder(context);
278   SmallVector<spirv::FuncOp, 1> entryPointFns;
279   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
280   module.walk([&](spirv::FuncOp funcOp) {
281     if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
282       entryPointFns.push_back(funcOp);
283     }
284   });
285   for (auto fn : entryPointFns) {
286     if (failed(lowerEntryPointABIAttr(fn, builder))) {
287       return signalPassFailure();
288     }
289   }
290 }
291 
292 std::unique_ptr<OperationPass<spirv::ModuleOp>>
createLowerABIAttributesPass()293 mlir::spirv::createLowerABIAttributesPass() {
294   return std::make_unique<LowerABIAttributesPass>();
295 }
296