//===- LowerABIAttributesPass.cpp - Lower ABI attributes ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
16 #include "mlir/Dialect/SPIRV/Passes.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
25 /// Creates a global variable for an argument based on the ABI info.
26 static spirv::GlobalVariableOp
27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
28                                      unsigned argIndex,
29                                      spirv::InterfaceVarABIAttr abiInfo) {
30   auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
31   if (!spirvModule)
32     return nullptr;
33 
34   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
35   builder.setInsertionPoint(funcOp.getOperation());
36   std::string varName =
37       funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
38 
39   // Get the type of variable. If this is a scalar/vector type and has an ABI
40   // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
41   // it must already be a !spv.ptr<!spv.struct<...>>.
42   auto varType = funcOp.getType().getInput(argIndex);
43   if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
44     auto storageClass = abiInfo.getStorageClass();
45     if (!storageClass)
46       return nullptr;
47     varType =
48         spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
49   }
50   auto varPtrType = varType.cast<spirv::PointerType>();
51   auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
52 
53   // Set the offset information.
54   varPointeeType =
55       VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
56   varType =
57       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
58 
59   return builder.create<spirv::GlobalVariableOp>(
60       funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
61       abiInfo.getBinding());
62 }
63 
64 /// Gets the global variables that need to be specified as interface variable
65 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
66 static LogicalResult
67 getInterfaceVariables(spirv::FuncOp funcOp,
68                       SmallVectorImpl<Attribute> &interfaceVars) {
69   auto module = funcOp.getParentOfType<spirv::ModuleOp>();
70   if (!module) {
71     return failure();
72   }
73   llvm::SetVector<Operation *> interfaceVarSet;
74 
75   // TODO(ravishankarm) : This should in reality traverse the entry function
76   // call graph and collect all the interfaces. For now, just traverse the
77   // instructions in this function.
78   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
79     auto var =
80         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
81     // TODO(antiagainst): Per SPIR-V spec: "Before version 1.4, the interface’s
82     // storage classes are limited to the Input and Output storage classes.
83     // Starting with version 1.4, the interface’s storage classes are all
84     // storage classes used in declaring all global variables referenced by the
85     // entry point’s call tree." We should consider the target environment here.
86     switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
87     case spirv::StorageClass::Input:
88     case spirv::StorageClass::Output:
89       interfaceVarSet.insert(var.getOperation());
90       break;
91     default:
92       break;
93     }
94   });
95   for (auto &var : interfaceVarSet) {
96     interfaceVars.push_back(SymbolRefAttr::get(
97         cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
98   }
99   return success();
100 }
101 
102 /// Lowers the entry point attribute.
103 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
104                                             OpBuilder &builder) {
105   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
106   auto entryPointAttr =
107       funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
108   if (!entryPointAttr) {
109     return failure();
110   }
111 
112   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
113   auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
114   builder.setInsertionPoint(spirvModule.body().front().getTerminator());
115 
116   // Adds the spv.EntryPointOp after collecting all the interface variables
117   // needed.
118   SmallVector<Attribute, 1> interfaceVars;
119   if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
120     return failure();
121   }
122   builder.create<spirv::EntryPointOp>(
123       funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);
124   // Specifies the spv.ExecutionModeOp.
125   auto localSizeAttr = entryPointAttr.local_size();
126   SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
127   builder.create<spirv::ExecutionModeOp>(
128       funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
129   funcOp.removeAttr(entryPointAttrName);
130   return success();
131 }
132 
133 namespace {
134 /// A pattern to convert function signature according to interface variable ABI
135 /// attributes.
136 ///
137 /// Specifically, this pattern creates global variables according to interface
138 /// variable ABI attributes attached to function arguments and converts all
139 /// function argument uses to those global variables. This is necessary because
140 /// Vulkan requires all shader entry points to be of void(void) type.
141 class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
142 public:
143   using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
144   LogicalResult
145   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
146                   ConversionPatternRewriter &rewriter) const override;
147 };
148 
149 /// Pass to implement the ABI information specified as attributes.
150 class LowerABIAttributesPass final
151     : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
152   void runOnOperation() override;
153 };
154 } // namespace
155 
156 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
157     spirv::FuncOp funcOp, ArrayRef<Value> operands,
158     ConversionPatternRewriter &rewriter) const {
159   if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
160           spirv::getEntryPointABIAttrName())) {
161     // TODO(ravishankarm) : Non-entry point functions are not handled.
162     return failure();
163   }
164   TypeConverter::SignatureConversion signatureConverter(
165       funcOp.getType().getNumInputs());
166 
167   auto attrName = spirv::getInterfaceVarABIAttrName();
168   for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
169     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
170         argType.index(), attrName);
171     if (!abiInfo) {
172       // TODO(ravishankarm) : For non-entry point functions, it should be legal
173       // to pass around scalar/vector values and return a scalar/vector. For now
174       // non-entry point functions are not handled in this ABI lowering and will
175       // produce an error.
176       return failure();
177     }
178     spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
179         rewriter, funcOp, argType.index(), abiInfo);
180     if (!var)
181       return failure();
182 
183     OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
184     rewriter.setInsertionPointToStart(&funcOp.front());
185     // Insert spirv::AddressOf and spirv::AccessChain operations.
186     Value replacement =
187         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
188     // Check if the arg is a scalar or vector type. In that case, the value
189     // needs to be loaded into registers.
190     // TODO(ravishankarm) : This is loading value of the scalar into registers
191     // at the start of the function. It is probably better to do the load just
192     // before the use. There might be multiple loads and currently there is no
193     // easy way to replace all uses with a sequence of operations.
194     if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
195       auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
196       auto zero =
197           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
198       auto loadPtr = rewriter.create<spirv::AccessChainOp>(
199           funcOp.getLoc(), replacement, zero.constant());
200       replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
201     }
202     signatureConverter.remapInput(argType.index(), replacement);
203   }
204 
205   // Creates a new function with the update signature.
206   rewriter.updateRootInPlace(funcOp, [&] {
207     funcOp.setType(rewriter.getFunctionType(
208         signatureConverter.getConvertedTypes(), llvm::None));
209     rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
210   });
211   return success();
212 }
213 
214 void LowerABIAttributesPass::runOnOperation() {
215   // Uses the signature conversion methodology of the dialect conversion
216   // framework to implement the conversion.
217   spirv::ModuleOp module = getOperation();
218   MLIRContext *context = &getContext();
219 
220   spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
221 
222   SPIRVTypeConverter typeConverter(targetEnv);
223   OwningRewritePatternList patterns;
224   patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
225 
226   ConversionTarget target(*context);
227   // "Legal" function ops should have no interface variable ABI attributes.
228   target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
229     StringRef attrName = spirv::getInterfaceVarABIAttrName();
230     for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
231       if (op.getArgAttr(i, attrName))
232         return false;
233     return true;
234   });
235   // All other SPIR-V ops are legal.
236   target.markUnknownOpDynamicallyLegal([](Operation *op) {
237     return op->getDialect()->getNamespace() ==
238            spirv::SPIRVDialect::getDialectNamespace();
239   });
240   if (failed(
241           applyPartialConversion(module, target, patterns, &typeConverter))) {
242     return signalPassFailure();
243   }
244 
245   // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
246   // attributes.
247   OpBuilder builder(context);
248   SmallVector<spirv::FuncOp, 1> entryPointFns;
249   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
250   module.walk([&](spirv::FuncOp funcOp) {
251     if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
252       entryPointFns.push_back(funcOp);
253     }
254   });
255   for (auto fn : entryPointFns) {
256     if (failed(lowerEntryPointABIAttr(fn, builder))) {
257       return signalPassFailure();
258     }
259   }
260 }
261 
262 std::unique_ptr<OperationPass<spirv::ModuleOp>>
263 mlir::spirv::createLowerABIAttributesPass() {
264   return std::make_unique<LowerABIAttributesPass>();
265 }
266