1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Pass/AnalysisManager.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 #include <iterator>
30 
31 #define DEBUG_TYPE "spirv-unify-aliased-resource"
32 
33 using namespace mlir;
34 
35 //===----------------------------------------------------------------------===//
36 // Utility functions
37 //===----------------------------------------------------------------------===//
38 
39 using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
40 using AliasedResourceMap =
41     DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
42 
43 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
collectAliasedResources(spirv::ModuleOp moduleOp)44 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
45   AliasedResourceMap aliasedResources;
46   moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
47     if (varOp->getAttrOfType<UnitAttr>("aliased")) {
48       Optional<uint32_t> set = varOp.descriptor_set();
49       Optional<uint32_t> binding = varOp.binding();
50       if (set && binding)
51         aliasedResources[{*set, *binding}].push_back(varOp);
52     }
53   });
54   return aliasedResources;
55 }
56 
57 /// Returns the element type if the given `type` is a runtime array resource:
58 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
getRuntimeArrayElementType(Type type)59 static Type getRuntimeArrayElementType(Type type) {
60   auto ptrType = type.dyn_cast<spirv::PointerType>();
61   if (!ptrType)
62     return {};
63 
64   auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
65   if (!structType || structType.getNumElements() != 1)
66     return {};
67 
68   auto rtArrayType =
69       structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
70   if (!rtArrayType)
71     return {};
72 
73   return rtArrayType.getElementType();
74 }
75 
76 /// Given a list of resource element `types`, returns the index of the canonical
77 /// resource that all resources should be unified into. Returns llvm::None if
78 /// unable to unify.
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types)79 static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
80   SmallVector<int> scalarNumBits, totalNumBits;
81   scalarNumBits.reserve(types.size());
82   totalNumBits.reserve(types.size());
83   bool hasVector = false;
84 
85   for (spirv::SPIRVType type : types) {
86     assert(type.isScalarOrVector());
87     if (auto vectorType = type.dyn_cast<VectorType>()) {
88       if (vectorType.getNumElements() % 2 != 0)
89         return llvm::None; // Odd-sized vector has special layout requirements.
90 
91       Optional<int64_t> numBytes = type.getSizeInBytes();
92       if (!numBytes)
93         return llvm::None;
94 
95       scalarNumBits.push_back(
96           vectorType.getElementType().getIntOrFloatBitWidth());
97       totalNumBits.push_back(*numBytes * 8);
98       hasVector = true;
99     } else {
100       scalarNumBits.push_back(type.getIntOrFloatBitWidth());
101       totalNumBits.push_back(scalarNumBits.back());
102     }
103   }
104 
105   if (hasVector) {
106     // If there are vector types, require all element types to be the same for
107     // now to simplify the transformation.
108     if (!llvm::is_splat(scalarNumBits))
109       return llvm::None;
110 
111     // Choose the one with the largest bitwidth as the canonical resource, so
112     // that we can still keep vectorized load/store.
113     auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
114     // Make sure that the canonical resource's bitwidth is divisible by others.
115     // With out this, we cannot properly adjust the index later.
116     if (llvm::any_of(totalNumBits,
117                      [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
118       return llvm::None;
119 
120     return std::distance(totalNumBits.begin(), maxVal);
121   }
122 
123   // All element types are scalars. Then choose the smallest bitwidth as the
124   // cannonical resource to avoid subcomponent load/store.
125   auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
126   if (llvm::any_of(scalarNumBits,
127                    [minVal](int64_t bit) { return bit % *minVal != 0; }))
128     return llvm::None;
129   return std::distance(scalarNumBits.begin(), minVal);
130 }
131 
areSameBitwidthScalarType(Type a,Type b)132 static bool areSameBitwidthScalarType(Type a, Type b) {
133   return a.isIntOrFloat() && b.isIntOrFloat() &&
134          a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // Analysis
139 //===----------------------------------------------------------------------===//
140 
namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spv.GlobalVariable ops that have a descriptor
/// set and binding number. Such resources are of the type
/// `!spv.ptr<!spv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Analyzes the given root operation, which is expected to be a
  /// spv.module op.
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable.
  /// Returns a null op if the descriptor/variable was not recorded as
  /// unifiable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded as unifiable.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether
  /// we can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
191 
ResourceAliasAnalysis(Operation * root)192 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
193   // Collect all aliased resources first and put them into different sets
194   // according to the descriptor.
195   AliasedResourceMap aliasedResources =
196       collectAliasedResources(cast<spirv::ModuleOp>(root));
197 
198   // For each resource set, analyze whether we can unify; if so, try to identify
199   // a canonical resource, whose element type has the largest bitwidth.
200   for (const auto &descriptorResource : aliasedResources) {
201     recordIfUnifiable(descriptorResource.first, descriptorResource.second);
202   }
203 }
204 
shouldUnify(Operation * op) const205 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
206   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
207     auto canonicalOp = getCanonicalResource(varOp);
208     return canonicalOp && varOp != canonicalOp;
209   }
210   if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
211     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
212     auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
213     return shouldUnify(varOp);
214   }
215 
216   if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
217     return shouldUnify(acOp.base_ptr().getDefiningOp());
218   if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
219     return shouldUnify(loadOp.ptr().getDefiningOp());
220   if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
221     return shouldUnify(storeOp.ptr().getDefiningOp());
222 
223   return false;
224 }
225 
getCanonicalResource(const Descriptor & descriptor) const226 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
227     const Descriptor &descriptor) const {
228   auto varIt = canonicalResourceMap.find(descriptor);
229   if (varIt == canonicalResourceMap.end())
230     return {};
231   return varIt->second;
232 }
233 
getCanonicalResource(spirv::GlobalVariableOp varOp) const234 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
235     spirv::GlobalVariableOp varOp) const {
236   auto descriptorIt = descriptorMap.find(varOp);
237   if (descriptorIt == descriptorMap.end())
238     return {};
239   return getCanonicalResource(descriptorIt->second);
240 }
241 
242 spirv::SPIRVType
getElementType(spirv::GlobalVariableOp varOp) const243 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
244   auto it = elementTypeMap.find(varOp);
245   if (it == elementTypeMap.end())
246     return {};
247   return it->second;
248 }
249 
recordIfUnifiable(const Descriptor & descriptor,ArrayRef<spirv::GlobalVariableOp> resources)250 void ResourceAliasAnalysis::recordIfUnifiable(
251     const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
252   // Collect the element types for all resources in the current set.
253   SmallVector<spirv::SPIRVType> elementTypes;
254   for (spirv::GlobalVariableOp resource : resources) {
255     Type elementType = getRuntimeArrayElementType(resource.type());
256     if (!elementType)
257       return; // Unexpected resource variable type.
258 
259     auto type = elementType.cast<spirv::SPIRVType>();
260     if (!type.isScalarOrVector())
261       return; // Unexpected resource element type.
262 
263     elementTypes.push_back(type);
264   }
265 
266   Optional<int> index = deduceCanonicalResource(elementTypes);
267   if (!index)
268     return;
269 
270   // Update internal data structures for later use.
271   resourceMap[descriptor].assign(resources.begin(), resources.end());
272   canonicalResourceMap[descriptor] = resources[*index];
273   for (const auto &resource : llvm::enumerate(resources)) {
274     descriptorMap[resource.value()] = descriptor;
275     elementTypeMap[resource.value()] = elementTypes[resource.index()];
276   }
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Patterns
281 //===----------------------------------------------------------------------===//
282 
/// Base class for patterns rewriting ops that reference aliased resources.
/// Holds a reference to the resource alias analysis so subclasses can query
/// canonical resources and element types.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  // Owned by the pass; valid for the lifetime of the conversion.
  const ResourceAliasAnalysis &analysis;
};
293 
/// Erases aliased (non-canonical) resource variables.
struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Just remove the aliased resource. Users will be rewritten to use the
    // canonical one.
    rewriter.eraseOp(varOp);
    return success();
  }
};
306 
307 struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
308   using ConvertAliasResource::ConvertAliasResource;
309 
310   LogicalResult
matchAndRewriteConvertAddressOf311   matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
312                   ConversionPatternRewriter &rewriter) const override {
313     // Rewrite the AddressOf op to get the address of the canoncical resource.
314     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
315     auto srcVarOp = cast<spirv::GlobalVariableOp>(
316         SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
317     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
318     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
319     return success();
320   }
321 };
322 
/// Rewrites spv.AccessChain ops so their indices address the canonical
/// resource's element type instead of the original (aliased) one.
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    // Resolve the source (aliased) variable and its canonical counterpart so
    // we can compare their element types.
    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);

    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);

    if (srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // We have the same bitwidth for source and destination element types.
      // The indices stay the same.
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), adaptor.indices());
      return success();
    }

    Location loc = acOp.getLoc();
    auto i32Type = rewriter.getI32Type();

    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them into a buffer with vector element types. We need to scale the last
      // index for the vector as a whole, then add one level of index for inside
      // the vector.
      // NOTE: despite the "NumBits" names these are byte sizes from
      // getSizeInBytes(); only their ratio matters, which is the same either
      // way. Both sides are scalar/vector int-or-float, so getSizeInBytes()
      // always has a value here.
      int srcNumBits = *srcElemType.getSizeInBytes();
      int dstNumBits = *dstElemType.getSizeInBytes();
      assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
      int ratio = dstNumBits / srcNumBits;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      // lastIndex / ratio selects the vector; lastIndex % ratio selects the
      // component inside it.
      auto indices = llvm::to_vector<4>(acOp.indices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
      indices.push_back(
          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), indices);
      return success();
    }

    if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
      // The source indices are for a buffer with larger bitwidth scalar element
      // types. Rewrite them into a buffer with smaller bitwidth element types.
      // We only need to scale the last index.
      // NOTE: byte sizes again (see above); only the ratio is used.
      int srcNumBits = *srcElemType.getSizeInBytes();
      int dstNumBits = *dstElemType.getSizeInBytes();
      assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
      int ratio = srcNumBits / dstNumBits;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      auto indices = llvm::to_vector<4>(acOp.indices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
  }
};
401 
/// Rewrites spv.Load ops that now read through the canonical resource's
/// pointer. When the canonical element type is narrower, multiple loads are
/// issued and recombined into the original wider value.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // srcElemType: what the original IR loaded; dstElemType: what the
    // already-converted (canonical) pointer yields.
    auto srcElemType =
        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    auto dstElemType =
        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(loadOp, "not scalar type");

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
    if (srcElemType == dstElemType) {
      // Identical element types: the new load is a drop-in replacement.
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // Same bitwidth, different type (e.g. i32 vs f32): bitcast back to the
      // type the original users expect.
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.value());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    // The source and destination have scalar types of different bitwidths.
    // For such cases, we need to load multiple smaller bitwidth values and
    // construct a larger bitwidth one.

    int srcNumBits = srcElemType.getIntOrFloatBitWidth();
    int dstNumBits = dstElemType.getIntOrFloatBitWidth();
    assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
    int ratio = srcNumBits / dstNumBits;
    // SPIR-V vectors have at most 4 components, and we bitcast from a vector
    // below, so larger ratios are unsupported.
    if (ratio > 4)
      return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

    SmallVector<Value> components;
    components.reserve(ratio);
    // The load created above already covers the first component.
    components.push_back(newLoadOp);

    auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
    if (!acOp)
      return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");

    auto i32Type = rewriter.getI32Type();
    Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
    auto indices = llvm::to_vector<4>(acOp.indices());
    for (int i = 1; i < ratio; ++i) {
      // Load all subsequent components belonging to this element by bumping
      // the last index one slot at a time.
      indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
                                                      indices.back(), oneValue);
      auto componentAcOp =
          rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
      // Assuming little endian, this reads lower-ordered bits of the number to
      // lower-numbered components of the vector.
      components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
    }

    // Create a vector of the components and then cast back to the larger
    // bitwidth element type. For spv.bitcast, the lower-numbered components of
    // the vector map to lower-ordered bits of the larger bitwidth element type.
    auto vectorType = VectorType::get({ratio}, dstElemType);
    Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
        loc, vectorType, components);
    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
                                                  vectorValue);
    return success();
  }
};
474 
475 struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
476   using ConvertAliasResource::ConvertAliasResource;
477 
478   LogicalResult
matchAndRewriteConvertStore479   matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
480                   ConversionPatternRewriter &rewriter) const override {
481     auto srcElemType =
482         storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
483     auto dstElemType =
484         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
485     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
486       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
487     if (!areSameBitwidthScalarType(srcElemType, dstElemType))
488       return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
489 
490     Location loc = storeOp.getLoc();
491     Value value = adaptor.value();
492     if (srcElemType != dstElemType)
493       value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
494     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
495                                                 storeOp->getAttrs());
496     return success();
497   }
498 };
499 
500 //===----------------------------------------------------------------------===//
501 // Pass
502 //===----------------------------------------------------------------------===//
503 
namespace {
/// Pass that unifies multiple resources aliasing the same descriptor into a
/// single canonical resource and rewrites their users accordingly.
class UnifyAliasedResourcePass final
    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
public:
  void runOnOperation() override;
};
} // namespace
511 
runOnOperation()512 void UnifyAliasedResourcePass::runOnOperation() {
513   spirv::ModuleOp moduleOp = getOperation();
514   MLIRContext *context = &getContext();
515 
516   // Analyze aliased resources first.
517   ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
518 
519   ConversionTarget target(*context);
520   target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
521                                spirv::AccessChainOp, spirv::LoadOp,
522                                spirv::StoreOp>(
523       [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
524   target.addLegalDialect<spirv::SPIRVDialect>();
525 
526   // Run patterns to rewrite usages of non-canonical resources.
527   RewritePatternSet patterns(context);
528   patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
529                ConvertLoad, ConvertStore>(analysis, context);
530   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
531     return signalPassFailure();
532 
533   // Drop aliased attribute if we only have one single bound resource for a
534   // descriptor. We need to re-collect the map here given in the above the
535   // conversion is best effort; certain sets may not be converted.
536   AliasedResourceMap resourceMap =
537       collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
538   for (const auto &dr : resourceMap) {
539     const auto &resources = dr.second;
540     if (resources.size() == 1)
541       resources.front()->removeAttr("aliased");
542   }
543 }
544 
/// Creates an instance of the aliased-resource unification pass.
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass() {
  return std::make_unique<UnifyAliasedResourcePass>();
}
549