//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
15 #include "mlir/Dialect/SPIRV/Passes.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/StandardOps/Ops.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
25 /// Checks if the `type` is a scalar or vector type. It is assumed that they are
26 /// valid for SPIR-V dialect already.
27 static bool isScalarOrVectorType(Type type) {
28   return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>();
29 }
30 
31 /// Creates a global variable for an argument based on the ABI info.
32 static spirv::GlobalVariableOp
33 createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum,
34                            spirv::InterfaceVarABIAttr abiInfo) {
35   auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
36   if (!spirvModule) {
37     return nullptr;
38   }
39   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
40   builder.setInsertionPoint(funcOp.getOperation());
41   std::string varName =
42       funcOp.getName().str() + "_arg_" + std::to_string(argNum);
43 
44   // Get the type of variable. If this is a scalar/vector type and has an ABI
45   // info create a variable of type !spv.ptr<!spv.struct<elementTYpe>>. If not
46   // it must already be a !spv.ptr<!spv.struct<...>>.
47   auto varType = funcOp.getType().getInput(argNum);
48   auto storageClass =
49       static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
50   if (isScalarOrVectorType(varType)) {
51     varType =
52         spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
53   }
54   auto varPtrType = varType.cast<spirv::PointerType>();
55   auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
56 
57   // Set the offset information.
58   VulkanLayoutUtils::Size size = 0, alignment = 0;
59   varPointeeType =
60       VulkanLayoutUtils::decorateType(varPointeeType, size, alignment)
61           .cast<spirv::StructType>();
62   varType =
63       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
64 
65   return builder.create<spirv::GlobalVariableOp>(
66       funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(),
67       abiInfo.binding().getInt());
68 }
69 
70 /// Gets the global variables that need to be specified as interface variable
71 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
72 static LogicalResult
73 getInterfaceVariables(FuncOp funcOp,
74                       SmallVectorImpl<Attribute> &interfaceVars) {
75   auto module = funcOp.getParentOfType<spirv::ModuleOp>();
76   if (!module) {
77     return failure();
78   }
79   llvm::SetVector<Operation *> interfaceVarSet;
80 
81   // TODO(ravishankarm) : This should in reality traverse the entry function
82   // call graph and collect all the interfaces. For now, just traverse the
83   // instructions in this function.
84   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
85     auto var =
86         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
87     if (var.type().cast<spirv::PointerType>().getStorageClass() !=
88         spirv::StorageClass::StorageBuffer) {
89       interfaceVarSet.insert(var.getOperation());
90     }
91   });
92   for (auto &var : interfaceVarSet) {
93     interfaceVars.push_back(SymbolRefAttr::get(
94         cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
95   }
96   return success();
97 }
98 
99 /// Lowers the entry point attribute.
100 static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) {
101   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
102   auto entryPointAttr =
103       funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
104   if (!entryPointAttr) {
105     return failure();
106   }
107 
108   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
109   auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
110   builder.setInsertionPoint(spirvModule.body().front().getTerminator());
111 
112   // Adds the spv.EntryPointOp after collecting all the interface variables
113   // needed.
114   SmallVector<Attribute, 1> interfaceVars;
115   if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
116     return failure();
117   }
118   builder.create<spirv::EntryPointOp>(
119       funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);
120   // Specifies the spv.ExecutionModeOp.
121   auto localSizeAttr = entryPointAttr.local_size();
122   SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
123   builder.create<spirv::ExecutionModeOp>(
124       funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
125   funcOp.removeAttr(entryPointAttrName);
126   return success();
127 }
128 
129 namespace {
130 /// Pattern rewriter for changing function signature to match the ABI specified
131 /// in attributes.
132 class FuncOpLowering final : public SPIRVOpLowering<FuncOp> {
133 public:
134   using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
135   PatternMatchResult
136   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
137                   ConversionPatternRewriter &rewriter) const override;
138 };
139 
140 /// Pass to implement the ABI information specified as attributes.
141 class LowerABIAttributesPass final
142     : public OperationPass<LowerABIAttributesPass, spirv::ModuleOp> {
143 private:
144   void runOnOperation() override;
145 };
146 } // namespace
147 
148 PatternMatchResult
149 FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
150                                 ConversionPatternRewriter &rewriter) const {
151   if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
152           spirv::getEntryPointABIAttrName())) {
153     // TODO(ravishankarm) : Non-entry point functions are not handled.
154     return matchFailure();
155   }
156   TypeConverter::SignatureConversion signatureConverter(
157       funcOp.getType().getNumInputs());
158 
159   auto attrName = spirv::getInterfaceVarABIAttrName();
160   for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
161     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
162         argType.index(), attrName);
163     if (!abiInfo) {
164       // TODO(ravishankarm) : For non-entry point functions, it should be legal
165       // to pass around scalar/vector values and return a scalar/vector. For now
166       // non-entry point functions are not handled in this ABI lowering and will
167       // produce an error.
168       return matchFailure();
169     }
170     auto var =
171         createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
172     if (!var) {
173       return matchFailure();
174     }
175 
176     OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
177     rewriter.setInsertionPointToStart(&funcOp.front());
178     // Insert spirv::AddressOf and spirv::AccessChain operations.
179     Value replacement =
180         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
181     // Check if the arg is a scalar or vector type. In that case, the value
182     // needs to be loaded into registers.
183     // TODO(ravishankarm) : This is loading value of the scalar into registers
184     // at the start of the function. It is probably better to do the load just
185     // before the use. There might be multiple loads and currently there is no
186     // easy way to replace all uses with a sequence of operations.
187     if (isScalarOrVectorType(argType.value())) {
188       auto indexType =
189           typeConverter.convertType(IndexType::get(funcOp.getContext()));
190       auto zero =
191           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
192       auto loadPtr = rewriter.create<spirv::AccessChainOp>(
193           funcOp.getLoc(), replacement, zero.constant());
194       replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,
195                                                    /*memory_access=*/nullptr,
196                                                    /*alignment=*/nullptr);
197     }
198     signatureConverter.remapInput(argType.index(), replacement);
199   }
200 
201   // Creates a new function with the update signature.
202   rewriter.updateRootInPlace(funcOp, [&] {
203     funcOp.setType(rewriter.getFunctionType(
204         signatureConverter.getConvertedTypes(), llvm::None));
205     rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
206   });
207   return matchSuccess();
208 }
209 
210 void LowerABIAttributesPass::runOnOperation() {
211   // Uses the signature conversion methodology of the dialect conversion
212   // framework to implement the conversion.
213   spirv::ModuleOp module = getOperation();
214   MLIRContext *context = &getContext();
215 
216   SPIRVTypeConverter typeConverter;
217   OwningRewritePatternList patterns;
218   patterns.insert<FuncOpLowering>(context, typeConverter);
219 
220   ConversionTarget target(*context);
221   target.addLegalDialect<spirv::SPIRVDialect>();
222   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
223   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
224     return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) &&
225            op.getNumResults() == 0 && op.getNumArguments() == 0;
226   });
227   target.addLegalOp<ReturnOp>();
228   if (failed(
229           applyPartialConversion(module, target, patterns, &typeConverter))) {
230     return signalPassFailure();
231   }
232 
233   // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
234   // attributes.
235   OpBuilder builder(context);
236   SmallVector<FuncOp, 1> entryPointFns;
237   module.walk([&](FuncOp funcOp) {
238     if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
239       entryPointFns.push_back(funcOp);
240     }
241   });
242   for (auto fn : entryPointFns) {
243     if (failed(lowerEntryPointABIAttr(fn, builder))) {
244       return signalPassFailure();
245     }
246   }
247 }
248 
249 std::unique_ptr<OpPassBase<spirv::ModuleOp>>
250 mlir::spirv::createLowerABIAttributesPass() {
251   return std::make_unique<LowerABIAttributesPass>();
252 }
253 
254 static PassRegistration<LowerABIAttributesPass>
255     pass("spirv-lower-abi-attrs", "Lower SPIR-V ABI Attributes");
256