1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Pass/AnalysisManager.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 
30 #define DEBUG_TYPE "spirv-unify-aliased-resource"
31 
32 using namespace mlir;
33 
34 //===----------------------------------------------------------------------===//
35 // Utility functions
36 //===----------------------------------------------------------------------===//
37 
38 using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
39 using AliasedResourceMap =
40     DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
41 
42 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
43 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
44   AliasedResourceMap aliasedResoruces;
45   moduleOp->walk([&aliasedResoruces](spirv::GlobalVariableOp varOp) {
46     if (varOp->getAttrOfType<UnitAttr>("aliased")) {
47       Optional<uint32_t> set = varOp.descriptor_set();
48       Optional<uint32_t> binding = varOp.binding();
49       if (set && binding)
50         aliasedResoruces[{*set, *binding}].push_back(varOp);
51     }
52   });
53   return aliasedResoruces;
54 }
55 
56 /// Returns the element type if the given `type` is a runtime array resource:
57 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
58 static Type getRuntimeArrayElementType(Type type) {
59   auto ptrType = type.dyn_cast<spirv::PointerType>();
60   if (!ptrType)
61     return {};
62 
63   auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
64   if (!structType || structType.getNumElements() != 1)
65     return {};
66 
67   auto rtArrayType =
68       structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
69   if (!rtArrayType)
70     return {};
71 
72   return rtArrayType.getElementType();
73 }
74 
75 /// Returns true if all `types`, which can either be scalar or vector types,
76 /// have the same bitwidth base scalar type.
77 static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
78   SmallVector<int64_t> scalarTypes;
79   scalarTypes.reserve(types.size());
80   for (spirv::SPIRVType type : types) {
81     assert(type.isScalarOrVector());
82     if (auto vectorType = type.dyn_cast<VectorType>())
83       scalarTypes.push_back(
84           vectorType.getElementType().getIntOrFloatBitWidth());
85     else
86       scalarTypes.push_back(type.getIntOrFloatBitWidth());
87   }
88   return llvm::is_splat(scalarTypes);
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Analysis
93 //===----------------------------------------------------------------------===//
94 
95 namespace {
96 /// A class for analyzing aliased resources.
97 ///
98 /// Resources are expected to be spv.GlobalVarible that has a descriptor set and
99 /// binding number. Such resources are of the type `!spv.ptr<!spv.struct<...>>`
100 /// per Vulkan requirements.
101 ///
102 /// Right now, we only support the case that there is a single runtime array
103 /// inside the struct.
104 class ResourceAliasAnalysis {
105 public:
106   explicit ResourceAliasAnalysis(Operation *);
107 
108   /// Returns true if the given `op` can be rewritten to use a canonical
109   /// resource.
110   bool shouldUnify(Operation *op) const;
111 
112   /// Returns all descriptors and their corresponding aliased resources.
113   const AliasedResourceMap &getResourceMap() const { return resourceMap; }
114 
115   /// Returns the canonical resource for the given descriptor/variable.
116   spirv::GlobalVariableOp
117   getCanonicalResource(const Descriptor &descriptor) const;
118   spirv::GlobalVariableOp
119   getCanonicalResource(spirv::GlobalVariableOp varOp) const;
120 
121   /// Returns the element type for the given variable.
122   spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
123 
124 private:
125   /// Given the descriptor and aliased resources bound to it, analyze whether we
126   /// can unify them and record if so.
127   void recordIfUnifiable(const Descriptor &descriptor,
128                          ArrayRef<spirv::GlobalVariableOp> resources);
129 
130   /// Mapping from a descriptor to all aliased resources bound to it.
131   AliasedResourceMap resourceMap;
132 
133   /// Mapping from a descriptor to the chosen canonical resource.
134   DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
135 
136   /// Mapping from an aliased resource to its descriptor.
137   DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
138 
139   /// Mapping from an aliased resource to its element (scalar/vector) type.
140   DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
141 };
142 } // namespace
143 
144 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
145   // Collect all aliased resources first and put them into different sets
146   // according to the descriptor.
147   AliasedResourceMap aliasedResoruces =
148       collectAliasedResources(cast<spirv::ModuleOp>(root));
149 
150   // For each resource set, analyze whether we can unify; if so, try to identify
151   // a canonical resource, whose element type has the largest bitwidth.
152   for (const auto &descriptorResoruce : aliasedResoruces) {
153     recordIfUnifiable(descriptorResoruce.first, descriptorResoruce.second);
154   }
155 }
156 
157 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
158   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
159     auto canonicalOp = getCanonicalResource(varOp);
160     return canonicalOp && varOp != canonicalOp;
161   }
162   if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
163     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
164     auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
165     return shouldUnify(varOp);
166   }
167 
168   if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
169     return shouldUnify(acOp.base_ptr().getDefiningOp());
170   if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
171     return shouldUnify(loadOp.ptr().getDefiningOp());
172   if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
173     return shouldUnify(storeOp.ptr().getDefiningOp());
174 
175   return false;
176 }
177 
178 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
179     const Descriptor &descriptor) const {
180   auto varIt = canonicalResourceMap.find(descriptor);
181   if (varIt == canonicalResourceMap.end())
182     return {};
183   return varIt->second;
184 }
185 
186 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
187     spirv::GlobalVariableOp varOp) const {
188   auto descriptorIt = descriptorMap.find(varOp);
189   if (descriptorIt == descriptorMap.end())
190     return {};
191   return getCanonicalResource(descriptorIt->second);
192 }
193 
194 spirv::SPIRVType
195 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
196   auto it = elementTypeMap.find(varOp);
197   if (it == elementTypeMap.end())
198     return {};
199   return it->second;
200 }
201 
202 void ResourceAliasAnalysis::recordIfUnifiable(
203     const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
204   // Collect the element types and byte counts for all resources in the
205   // current set.
206   SmallVector<spirv::SPIRVType> elementTypes;
207   SmallVector<int64_t> numBytes;
208 
209   for (spirv::GlobalVariableOp resource : resources) {
210     Type elementType = getRuntimeArrayElementType(resource.type());
211     if (!elementType)
212       return; // Unexpected resource variable type.
213 
214     auto type = elementType.cast<spirv::SPIRVType>();
215     if (!type.isScalarOrVector())
216       return; // Unexpected resource element type.
217 
218     if (auto vectorType = type.dyn_cast<VectorType>())
219       if (vectorType.getNumElements() % 2 != 0)
220         return; // Odd-sized vector has special layout requirements.
221 
222     Optional<int64_t> count = type.getSizeInBytes();
223     if (!count)
224       return;
225 
226     elementTypes.push_back(type);
227     numBytes.push_back(*count);
228   }
229 
230   // Make sure base scalar types have the same bitwdith, so that we don't need
231   // to handle extracting components for now.
232   if (!hasSameBitwidthScalarType(elementTypes))
233     return;
234 
235   // Make sure that the canonical resource's bitwidth is divisible by others.
236   // With out this, we cannot properly adjust the index later.
237   auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
238   if (llvm::any_of(numBytes, [maxCount](int64_t count) {
239         return *maxCount % count != 0;
240       }))
241     return;
242 
243   spirv::GlobalVariableOp canonicalResource =
244       resources[std::distance(numBytes.begin(), maxCount)];
245 
246   // Update internal data structures for later use.
247   resourceMap[descriptor].assign(resources.begin(), resources.end());
248   canonicalResourceMap[descriptor] = canonicalResource;
249   for (const auto &resource : llvm::enumerate(resources)) {
250     descriptorMap[resource.value()] = descriptor;
251     elementTypeMap[resource.value()] = elementTypes[resource.index()];
252   }
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Patterns
257 //===----------------------------------------------------------------------===//
258 
259 template <typename OpTy>
260 class ConvertAliasResoruce : public OpConversionPattern<OpTy> {
261 public:
262   ConvertAliasResoruce(const ResourceAliasAnalysis &analysis,
263                        MLIRContext *context, PatternBenefit benefit = 1)
264       : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
265 
266 protected:
267   const ResourceAliasAnalysis &analysis;
268 };
269 
270 struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> {
271   using ConvertAliasResoruce::ConvertAliasResoruce;
272 
273   LogicalResult
274   matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
275                   ConversionPatternRewriter &rewriter) const override {
276     // Just remove the aliased resource. Users will be rewritten to use the
277     // canonical one.
278     rewriter.eraseOp(varOp);
279     return success();
280   }
281 };
282 
283 struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> {
284   using ConvertAliasResoruce::ConvertAliasResoruce;
285 
286   LogicalResult
287   matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
288                   ConversionPatternRewriter &rewriter) const override {
289     // Rewrite the AddressOf op to get the address of the canoncical resource.
290     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
291     auto srcVarOp = cast<spirv::GlobalVariableOp>(
292         SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
293     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
294     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
295     return success();
296   }
297 };
298 
299 struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> {
300   using ConvertAliasResoruce::ConvertAliasResoruce;
301 
302   LogicalResult
303   matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
304                   ConversionPatternRewriter &rewriter) const override {
305     auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
306     if (!addressOp)
307       return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
308 
309     auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
310     auto srcVarOp = cast<spirv::GlobalVariableOp>(
311         SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
312     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
313 
314     spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
315     spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
316 
317     if ((srcElemType == dstElemType) ||
318         (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
319       // We have the same bitwidth for source and destination element types.
320       // Thie indices keep the same.
321       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
322           acOp, adaptor.base_ptr(), adaptor.indices());
323       return success();
324     }
325 
326     Location loc = acOp.getLoc();
327     auto i32Type = rewriter.getI32Type();
328 
329     if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
330       // The source indices are for a buffer with scalar element types. Rewrite
331       // them into a buffer with vector element types. We need to scale the last
332       // index for the vector as a whole, then add one level of index for inside
333       // the vector.
334       int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
335       auto ratioValue = rewriter.create<spirv::ConstantOp>(
336           loc, i32Type, rewriter.getI32IntegerAttr(ratio));
337 
338       auto indices = llvm::to_vector<4>(acOp.indices());
339       Value oldIndex = indices.back();
340       indices.back() =
341           rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
342       indices.push_back(
343           rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
344 
345       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
346           acOp, adaptor.base_ptr(), indices);
347       return success();
348     }
349 
350     return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
351   }
352 };
353 
354 struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> {
355   using ConvertAliasResoruce::ConvertAliasResoruce;
356 
357   LogicalResult
358   matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
359                   ConversionPatternRewriter &rewriter) const override {
360     auto srcElemType =
361         loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
362     auto dstElemType =
363         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
364     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
365       return rewriter.notifyMatchFailure(loadOp, "not scalar type");
366 
367     Location loc = loadOp.getLoc();
368     auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
369     if (srcElemType == dstElemType) {
370       rewriter.replaceOp(loadOp, newLoadOp->getResults());
371     } else {
372       auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
373                                                       newLoadOp.value());
374       rewriter.replaceOp(loadOp, castOp->getResults());
375     }
376 
377     return success();
378   }
379 };
380 
381 struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> {
382   using ConvertAliasResoruce::ConvertAliasResoruce;
383 
384   LogicalResult
385   matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
386                   ConversionPatternRewriter &rewriter) const override {
387     auto srcElemType =
388         storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
389     auto dstElemType =
390         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
391     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
392       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
393 
394     Location loc = storeOp.getLoc();
395     Value value = adaptor.value();
396     if (srcElemType != dstElemType)
397       value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
398     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
399                                                 storeOp->getAttrs());
400     return success();
401   }
402 };
403 
404 //===----------------------------------------------------------------------===//
405 // Pass
406 //===----------------------------------------------------------------------===//
407 
408 namespace {
409 class UnifyAliasedResourcePass final
410     : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
411 public:
412   void runOnOperation() override;
413 };
414 } // namespace
415 
416 void UnifyAliasedResourcePass::runOnOperation() {
417   spirv::ModuleOp moduleOp = getOperation();
418   MLIRContext *context = &getContext();
419 
420   // Analyze aliased resources first.
421   ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
422 
423   ConversionTarget target(*context);
424   target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
425                                spirv::AccessChainOp, spirv::LoadOp,
426                                spirv::StoreOp>(
427       [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
428   target.addLegalDialect<spirv::SPIRVDialect>();
429 
430   // Run patterns to rewrite usages of non-canonical resources.
431   RewritePatternSet patterns(context);
432   patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
433                ConvertLoad, ConvertStore>(analysis, context);
434   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
435     return signalPassFailure();
436 
437   // Drop aliased attribute if we only have one single bound resource for a
438   // descriptor. We need to re-collect the map here given in the above the
439   // conversion is best effort; certain sets may not be converted.
440   AliasedResourceMap resourceMap =
441       collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
442   for (const auto &dr : resourceMap) {
443     const auto &resources = dr.second;
444     if (resources.size() == 1)
445       resources.front()->removeAttr("aliased");
446   }
447 }
448 
449 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
450 spirv::createUnifyAliasedResourcePass() {
451   return std::make_unique<UnifyAliasedResourcePass>();
452 }
453