1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Pass/AnalysisManager.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 #include <iterator>
30
31 #define DEBUG_TYPE "spirv-unify-aliased-resource"
32
33 using namespace mlir;
34
35 //===----------------------------------------------------------------------===//
36 // Utility functions
37 //===----------------------------------------------------------------------===//
38
39 using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
40 using AliasedResourceMap =
41 DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
42
43 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
collectAliasedResources(spirv::ModuleOp moduleOp)44 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
45 AliasedResourceMap aliasedResources;
46 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
47 if (varOp->getAttrOfType<UnitAttr>("aliased")) {
48 Optional<uint32_t> set = varOp.descriptor_set();
49 Optional<uint32_t> binding = varOp.binding();
50 if (set && binding)
51 aliasedResources[{*set, *binding}].push_back(varOp);
52 }
53 });
54 return aliasedResources;
55 }
56
57 /// Returns the element type if the given `type` is a runtime array resource:
58 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
getRuntimeArrayElementType(Type type)59 static Type getRuntimeArrayElementType(Type type) {
60 auto ptrType = type.dyn_cast<spirv::PointerType>();
61 if (!ptrType)
62 return {};
63
64 auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
65 if (!structType || structType.getNumElements() != 1)
66 return {};
67
68 auto rtArrayType =
69 structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
70 if (!rtArrayType)
71 return {};
72
73 return rtArrayType.getElementType();
74 }
75
76 /// Given a list of resource element `types`, returns the index of the canonical
77 /// resource that all resources should be unified into. Returns llvm::None if
78 /// unable to unify.
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types)79 static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
80 SmallVector<int> scalarNumBits, totalNumBits;
81 scalarNumBits.reserve(types.size());
82 totalNumBits.reserve(types.size());
83 bool hasVector = false;
84
85 for (spirv::SPIRVType type : types) {
86 assert(type.isScalarOrVector());
87 if (auto vectorType = type.dyn_cast<VectorType>()) {
88 if (vectorType.getNumElements() % 2 != 0)
89 return llvm::None; // Odd-sized vector has special layout requirements.
90
91 Optional<int64_t> numBytes = type.getSizeInBytes();
92 if (!numBytes)
93 return llvm::None;
94
95 scalarNumBits.push_back(
96 vectorType.getElementType().getIntOrFloatBitWidth());
97 totalNumBits.push_back(*numBytes * 8);
98 hasVector = true;
99 } else {
100 scalarNumBits.push_back(type.getIntOrFloatBitWidth());
101 totalNumBits.push_back(scalarNumBits.back());
102 }
103 }
104
105 if (hasVector) {
106 // If there are vector types, require all element types to be the same for
107 // now to simplify the transformation.
108 if (!llvm::is_splat(scalarNumBits))
109 return llvm::None;
110
111 // Choose the one with the largest bitwidth as the canonical resource, so
112 // that we can still keep vectorized load/store.
113 auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
114 // Make sure that the canonical resource's bitwidth is divisible by others.
115 // With out this, we cannot properly adjust the index later.
116 if (llvm::any_of(totalNumBits,
117 [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
118 return llvm::None;
119
120 return std::distance(totalNumBits.begin(), maxVal);
121 }
122
123 // All element types are scalars. Then choose the smallest bitwidth as the
124 // cannonical resource to avoid subcomponent load/store.
125 auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
126 if (llvm::any_of(scalarNumBits,
127 [minVal](int64_t bit) { return bit % *minVal != 0; }))
128 return llvm::None;
129 return std::distance(scalarNumBits.begin(), minVal);
130 }
131
areSameBitwidthScalarType(Type a,Type b)132 static bool areSameBitwidthScalarType(Type a, Type b) {
133 return a.isIntOrFloat() && b.isIntOrFloat() &&
134 a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
135 }
136
137 //===----------------------------------------------------------------------===//
138 // Analysis
139 //===----------------------------------------------------------------------===//
140
141 namespace {
142 /// A class for analyzing aliased resources.
143 ///
144 /// Resources are expected to be spv.GlobalVarible that has a descriptor set and
145 /// binding number. Such resources are of the type `!spv.ptr<!spv.struct<...>>`
146 /// per Vulkan requirements.
147 ///
148 /// Right now, we only support the case that there is a single runtime array
149 /// inside the struct.
150 class ResourceAliasAnalysis {
151 public:
152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)
153
154 explicit ResourceAliasAnalysis(Operation *);
155
156 /// Returns true if the given `op` can be rewritten to use a canonical
157 /// resource.
158 bool shouldUnify(Operation *op) const;
159
160 /// Returns all descriptors and their corresponding aliased resources.
getResourceMap() const161 const AliasedResourceMap &getResourceMap() const { return resourceMap; }
162
163 /// Returns the canonical resource for the given descriptor/variable.
164 spirv::GlobalVariableOp
165 getCanonicalResource(const Descriptor &descriptor) const;
166 spirv::GlobalVariableOp
167 getCanonicalResource(spirv::GlobalVariableOp varOp) const;
168
169 /// Returns the element type for the given variable.
170 spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
171
172 private:
173 /// Given the descriptor and aliased resources bound to it, analyze whether we
174 /// can unify them and record if so.
175 void recordIfUnifiable(const Descriptor &descriptor,
176 ArrayRef<spirv::GlobalVariableOp> resources);
177
178 /// Mapping from a descriptor to all aliased resources bound to it.
179 AliasedResourceMap resourceMap;
180
181 /// Mapping from a descriptor to the chosen canonical resource.
182 DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
183
184 /// Mapping from an aliased resource to its descriptor.
185 DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
186
187 /// Mapping from an aliased resource to its element (scalar/vector) type.
188 DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
189 };
190 } // namespace
191
ResourceAliasAnalysis(Operation * root)192 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
193 // Collect all aliased resources first and put them into different sets
194 // according to the descriptor.
195 AliasedResourceMap aliasedResources =
196 collectAliasedResources(cast<spirv::ModuleOp>(root));
197
198 // For each resource set, analyze whether we can unify; if so, try to identify
199 // a canonical resource, whose element type has the largest bitwidth.
200 for (const auto &descriptorResource : aliasedResources) {
201 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
202 }
203 }
204
shouldUnify(Operation * op) const205 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
206 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
207 auto canonicalOp = getCanonicalResource(varOp);
208 return canonicalOp && varOp != canonicalOp;
209 }
210 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
211 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
212 auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
213 return shouldUnify(varOp);
214 }
215
216 if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
217 return shouldUnify(acOp.base_ptr().getDefiningOp());
218 if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
219 return shouldUnify(loadOp.ptr().getDefiningOp());
220 if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
221 return shouldUnify(storeOp.ptr().getDefiningOp());
222
223 return false;
224 }
225
getCanonicalResource(const Descriptor & descriptor) const226 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
227 const Descriptor &descriptor) const {
228 auto varIt = canonicalResourceMap.find(descriptor);
229 if (varIt == canonicalResourceMap.end())
230 return {};
231 return varIt->second;
232 }
233
getCanonicalResource(spirv::GlobalVariableOp varOp) const234 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
235 spirv::GlobalVariableOp varOp) const {
236 auto descriptorIt = descriptorMap.find(varOp);
237 if (descriptorIt == descriptorMap.end())
238 return {};
239 return getCanonicalResource(descriptorIt->second);
240 }
241
242 spirv::SPIRVType
getElementType(spirv::GlobalVariableOp varOp) const243 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
244 auto it = elementTypeMap.find(varOp);
245 if (it == elementTypeMap.end())
246 return {};
247 return it->second;
248 }
249
recordIfUnifiable(const Descriptor & descriptor,ArrayRef<spirv::GlobalVariableOp> resources)250 void ResourceAliasAnalysis::recordIfUnifiable(
251 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
252 // Collect the element types for all resources in the current set.
253 SmallVector<spirv::SPIRVType> elementTypes;
254 for (spirv::GlobalVariableOp resource : resources) {
255 Type elementType = getRuntimeArrayElementType(resource.type());
256 if (!elementType)
257 return; // Unexpected resource variable type.
258
259 auto type = elementType.cast<spirv::SPIRVType>();
260 if (!type.isScalarOrVector())
261 return; // Unexpected resource element type.
262
263 elementTypes.push_back(type);
264 }
265
266 Optional<int> index = deduceCanonicalResource(elementTypes);
267 if (!index)
268 return;
269
270 // Update internal data structures for later use.
271 resourceMap[descriptor].assign(resources.begin(), resources.end());
272 canonicalResourceMap[descriptor] = resources[*index];
273 for (const auto &resource : llvm::enumerate(resources)) {
274 descriptorMap[resource.value()] = descriptor;
275 elementTypeMap[resource.value()] = elementTypes[resource.index()];
276 }
277 }
278
279 //===----------------------------------------------------------------------===//
280 // Patterns
281 //===----------------------------------------------------------------------===//
282
283 template <typename OpTy>
284 class ConvertAliasResource : public OpConversionPattern<OpTy> {
285 public:
ConvertAliasResource(const ResourceAliasAnalysis & analysis,MLIRContext * context,PatternBenefit benefit=1)286 ConvertAliasResource(const ResourceAliasAnalysis &analysis,
287 MLIRContext *context, PatternBenefit benefit = 1)
288 : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
289
290 protected:
291 const ResourceAliasAnalysis &analysis;
292 };
293
294 struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
295 using ConvertAliasResource::ConvertAliasResource;
296
297 LogicalResult
matchAndRewriteConvertVariable298 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override {
300 // Just remove the aliased resource. Users will be rewritten to use the
301 // canonical one.
302 rewriter.eraseOp(varOp);
303 return success();
304 }
305 };
306
307 struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
308 using ConvertAliasResource::ConvertAliasResource;
309
310 LogicalResult
matchAndRewriteConvertAddressOf311 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
312 ConversionPatternRewriter &rewriter) const override {
313 // Rewrite the AddressOf op to get the address of the canoncical resource.
314 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
315 auto srcVarOp = cast<spirv::GlobalVariableOp>(
316 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
317 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
318 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
319 return success();
320 }
321 };
322
323 struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
324 using ConvertAliasResource::ConvertAliasResource;
325
326 LogicalResult
matchAndRewriteConvertAccessChain327 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter) const override {
329 auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
330 if (!addressOp)
331 return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
332
333 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
334 auto srcVarOp = cast<spirv::GlobalVariableOp>(
335 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
336 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
337
338 spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
339 spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
340
341 if (srcElemType == dstElemType ||
342 areSameBitwidthScalarType(srcElemType, dstElemType)) {
343 // We have the same bitwidth for source and destination element types.
344 // Thie indices keep the same.
345 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
346 acOp, adaptor.base_ptr(), adaptor.indices());
347 return success();
348 }
349
350 Location loc = acOp.getLoc();
351 auto i32Type = rewriter.getI32Type();
352
353 if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
354 // The source indices are for a buffer with scalar element types. Rewrite
355 // them into a buffer with vector element types. We need to scale the last
356 // index for the vector as a whole, then add one level of index for inside
357 // the vector.
358 int srcNumBits = *srcElemType.getSizeInBytes();
359 int dstNumBits = *dstElemType.getSizeInBytes();
360 assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
361 int ratio = dstNumBits / srcNumBits;
362 auto ratioValue = rewriter.create<spirv::ConstantOp>(
363 loc, i32Type, rewriter.getI32IntegerAttr(ratio));
364
365 auto indices = llvm::to_vector<4>(acOp.indices());
366 Value oldIndex = indices.back();
367 indices.back() =
368 rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
369 indices.push_back(
370 rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
371
372 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
373 acOp, adaptor.base_ptr(), indices);
374 return success();
375 }
376
377 if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
378 // The source indices are for a buffer with larger bitwidth scalar element
379 // types. Rewrite them into a buffer with smaller bitwidth element types.
380 // We only need to scale the last index.
381 int srcNumBits = *srcElemType.getSizeInBytes();
382 int dstNumBits = *dstElemType.getSizeInBytes();
383 assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
384 int ratio = srcNumBits / dstNumBits;
385 auto ratioValue = rewriter.create<spirv::ConstantOp>(
386 loc, i32Type, rewriter.getI32IntegerAttr(ratio));
387
388 auto indices = llvm::to_vector<4>(acOp.indices());
389 Value oldIndex = indices.back();
390 indices.back() =
391 rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
392
393 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
394 acOp, adaptor.base_ptr(), indices);
395 return success();
396 }
397
398 return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
399 }
400 };
401
402 struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
403 using ConvertAliasResource::ConvertAliasResource;
404
405 LogicalResult
matchAndRewriteConvertLoad406 matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter) const override {
408 auto srcElemType =
409 loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
410 auto dstElemType =
411 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
412 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
413 return rewriter.notifyMatchFailure(loadOp, "not scalar type");
414
415 Location loc = loadOp.getLoc();
416 auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
417 if (srcElemType == dstElemType) {
418 rewriter.replaceOp(loadOp, newLoadOp->getResults());
419 return success();
420 }
421
422 if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
423 auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
424 newLoadOp.value());
425 rewriter.replaceOp(loadOp, castOp->getResults());
426
427 return success();
428 }
429
430 // The source and destination have scalar types of different bitwidths.
431 // For such cases, we need to load multiple smaller bitwidth values and
432 // construct a larger bitwidth one.
433
434 int srcNumBits = srcElemType.getIntOrFloatBitWidth();
435 int dstNumBits = dstElemType.getIntOrFloatBitWidth();
436 assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
437 int ratio = srcNumBits / dstNumBits;
438 if (ratio > 4)
439 return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
440
441 SmallVector<Value> components;
442 components.reserve(ratio);
443 components.push_back(newLoadOp);
444
445 auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
446 if (!acOp)
447 return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
448
449 auto i32Type = rewriter.getI32Type();
450 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
451 auto indices = llvm::to_vector<4>(acOp.indices());
452 for (int i = 1; i < ratio; ++i) {
453 // Load all subsequent components belonging to this element.
454 indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
455 indices.back(), oneValue);
456 auto componentAcOp =
457 rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
458 // Assuming little endian, this reads lower-ordered bits of the number to
459 // lower-numbered components of the vector.
460 components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
461 }
462
463 // Create a vector of the components and then cast back to the larger
464 // bitwidth element type. For spv.bitcast, the lower-numbered components of
465 // the vector map to lower-ordered bits of the larger bitwidth element type.
466 auto vectorType = VectorType::get({ratio}, dstElemType);
467 Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
468 loc, vectorType, components);
469 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
470 vectorValue);
471 return success();
472 }
473 };
474
475 struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
476 using ConvertAliasResource::ConvertAliasResource;
477
478 LogicalResult
matchAndRewriteConvertStore479 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
480 ConversionPatternRewriter &rewriter) const override {
481 auto srcElemType =
482 storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
483 auto dstElemType =
484 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
485 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
486 return rewriter.notifyMatchFailure(storeOp, "not scalar type");
487 if (!areSameBitwidthScalarType(srcElemType, dstElemType))
488 return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
489
490 Location loc = storeOp.getLoc();
491 Value value = adaptor.value();
492 if (srcElemType != dstElemType)
493 value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
494 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
495 storeOp->getAttrs());
496 return success();
497 }
498 };
499
500 //===----------------------------------------------------------------------===//
501 // Pass
502 //===----------------------------------------------------------------------===//
503
504 namespace {
505 class UnifyAliasedResourcePass final
506 : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
507 public:
508 void runOnOperation() override;
509 };
510 } // namespace
511
runOnOperation()512 void UnifyAliasedResourcePass::runOnOperation() {
513 spirv::ModuleOp moduleOp = getOperation();
514 MLIRContext *context = &getContext();
515
516 // Analyze aliased resources first.
517 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
518
519 ConversionTarget target(*context);
520 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
521 spirv::AccessChainOp, spirv::LoadOp,
522 spirv::StoreOp>(
523 [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
524 target.addLegalDialect<spirv::SPIRVDialect>();
525
526 // Run patterns to rewrite usages of non-canonical resources.
527 RewritePatternSet patterns(context);
528 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
529 ConvertLoad, ConvertStore>(analysis, context);
530 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
531 return signalPassFailure();
532
533 // Drop aliased attribute if we only have one single bound resource for a
534 // descriptor. We need to re-collect the map here given in the above the
535 // conversion is best effort; certain sets may not be converted.
536 AliasedResourceMap resourceMap =
537 collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
538 for (const auto &dr : resourceMap) {
539 const auto &resources = dr.second;
540 if (resources.size() == 1)
541 resources.front()->removeAttr("aliased");
542 }
543 }
544
545 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
createUnifyAliasedResourcePass()546 spirv::createUnifyAliasedResourcePass() {
547 return std::make_unique<UnifyAliasedResourcePass>();
548 }
549