1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Pass/AnalysisManager.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 
30 #define DEBUG_TYPE "spirv-unify-aliased-resource"
31 
32 using namespace mlir;
33 
34 //===----------------------------------------------------------------------===//
35 // Utility functions
36 //===----------------------------------------------------------------------===//
37 
/// Identifies a resource binding point by its (descriptor set number, binding
/// number) pair, as read from a spv.GlobalVariable's descriptor_set/binding.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
/// Maps each descriptor to the aliased resource variables bound to it.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
41 
42 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
43 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
44   AliasedResourceMap aliasedResoruces;
45   moduleOp->walk([&aliasedResoruces](spirv::GlobalVariableOp varOp) {
46     if (varOp->getAttrOfType<UnitAttr>("aliased")) {
47       Optional<uint32_t> set = varOp.descriptor_set();
48       Optional<uint32_t> binding = varOp.binding();
49       if (set && binding)
50         aliasedResoruces[{*set, *binding}].push_back(varOp);
51     }
52   });
53   return aliasedResoruces;
54 }
55 
56 /// Returns the element type if the given `type` is a runtime array resource:
57 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
58 static Type getRuntimeArrayElementType(Type type) {
59   auto ptrType = type.dyn_cast<spirv::PointerType>();
60   if (!ptrType)
61     return {};
62 
63   auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
64   if (!structType || structType.getNumElements() != 1)
65     return {};
66 
67   auto rtArrayType =
68       structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
69   if (!rtArrayType)
70     return {};
71 
72   return rtArrayType.getElementType();
73 }
74 
75 /// Returns true if all `types`, which can either be scalar or vector types,
76 /// have the same bitwidth base scalar type.
77 static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
78   SmallVector<int64_t> scalarTypes;
79   scalarTypes.reserve(types.size());
80   for (spirv::SPIRVType type : types) {
81     assert(type.isScalarOrVector());
82     if (auto vectorType = type.dyn_cast<VectorType>())
83       scalarTypes.push_back(
84           vectorType.getElementType().getIntOrFloatBitWidth());
85     else
86       scalarTypes.push_back(type.getIntOrFloatBitWidth());
87   }
88   return llvm::is_splat(scalarTypes);
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Analysis
93 //===----------------------------------------------------------------------===//
94 
namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spv.GlobalVariable ops that have a descriptor
/// set and binding number. Such resources are of the type
/// `!spv.ptr<!spv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Builds the analysis for the given root op, which must be a
  /// spirv::ModuleOp (the constructor casts it).
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources. Only
  /// descriptors whose resources were found unifiable are present.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable, or a
  /// null op if none was recorded for it.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
145 
146 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
147   // Collect all aliased resources first and put them into different sets
148   // according to the descriptor.
149   AliasedResourceMap aliasedResoruces =
150       collectAliasedResources(cast<spirv::ModuleOp>(root));
151 
152   // For each resource set, analyze whether we can unify; if so, try to identify
153   // a canonical resource, whose element type has the largest bitwidth.
154   for (const auto &descriptorResoruce : aliasedResoruces) {
155     recordIfUnifiable(descriptorResoruce.first, descriptorResoruce.second);
156   }
157 }
158 
159 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
160   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
161     auto canonicalOp = getCanonicalResource(varOp);
162     return canonicalOp && varOp != canonicalOp;
163   }
164   if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
165     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
166     auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
167     return shouldUnify(varOp);
168   }
169 
170   if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
171     return shouldUnify(acOp.base_ptr().getDefiningOp());
172   if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
173     return shouldUnify(loadOp.ptr().getDefiningOp());
174   if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
175     return shouldUnify(storeOp.ptr().getDefiningOp());
176 
177   return false;
178 }
179 
180 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
181     const Descriptor &descriptor) const {
182   auto varIt = canonicalResourceMap.find(descriptor);
183   if (varIt == canonicalResourceMap.end())
184     return {};
185   return varIt->second;
186 }
187 
188 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
189     spirv::GlobalVariableOp varOp) const {
190   auto descriptorIt = descriptorMap.find(varOp);
191   if (descriptorIt == descriptorMap.end())
192     return {};
193   return getCanonicalResource(descriptorIt->second);
194 }
195 
196 spirv::SPIRVType
197 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
198   auto it = elementTypeMap.find(varOp);
199   if (it == elementTypeMap.end())
200     return {};
201   return it->second;
202 }
203 
204 void ResourceAliasAnalysis::recordIfUnifiable(
205     const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
206   // Collect the element types and byte counts for all resources in the
207   // current set.
208   SmallVector<spirv::SPIRVType> elementTypes;
209   SmallVector<int64_t> numBytes;
210 
211   for (spirv::GlobalVariableOp resource : resources) {
212     Type elementType = getRuntimeArrayElementType(resource.type());
213     if (!elementType)
214       return; // Unexpected resource variable type.
215 
216     auto type = elementType.cast<spirv::SPIRVType>();
217     if (!type.isScalarOrVector())
218       return; // Unexpected resource element type.
219 
220     if (auto vectorType = type.dyn_cast<VectorType>())
221       if (vectorType.getNumElements() % 2 != 0)
222         return; // Odd-sized vector has special layout requirements.
223 
224     Optional<int64_t> count = type.getSizeInBytes();
225     if (!count)
226       return;
227 
228     elementTypes.push_back(type);
229     numBytes.push_back(*count);
230   }
231 
232   // Make sure base scalar types have the same bitwdith, so that we don't need
233   // to handle extracting components for now.
234   if (!hasSameBitwidthScalarType(elementTypes))
235     return;
236 
237   // Make sure that the canonical resource's bitwidth is divisible by others.
238   // With out this, we cannot properly adjust the index later.
239   auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
240   if (llvm::any_of(numBytes, [maxCount](int64_t count) {
241         return *maxCount % count != 0;
242       }))
243     return;
244 
245   spirv::GlobalVariableOp canonicalResource =
246       resources[std::distance(numBytes.begin(), maxCount)];
247 
248   // Update internal data structures for later use.
249   resourceMap[descriptor].assign(resources.begin(), resources.end());
250   canonicalResourceMap[descriptor] = canonicalResource;
251   for (const auto &resource : llvm::enumerate(resources)) {
252     descriptorMap[resource.value()] = descriptor;
253     elementTypeMap[resource.value()] = elementTypes[resource.index()];
254   }
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // Patterns
259 //===----------------------------------------------------------------------===//
260 
261 template <typename OpTy>
262 class ConvertAliasResoruce : public OpConversionPattern<OpTy> {
263 public:
264   ConvertAliasResoruce(const ResourceAliasAnalysis &analysis,
265                        MLIRContext *context, PatternBenefit benefit = 1)
266       : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
267 
268 protected:
269   const ResourceAliasAnalysis &analysis;
270 };
271 
272 struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> {
273   using ConvertAliasResoruce::ConvertAliasResoruce;
274 
275   LogicalResult
276   matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
277                   ConversionPatternRewriter &rewriter) const override {
278     // Just remove the aliased resource. Users will be rewritten to use the
279     // canonical one.
280     rewriter.eraseOp(varOp);
281     return success();
282   }
283 };
284 
285 struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> {
286   using ConvertAliasResoruce::ConvertAliasResoruce;
287 
288   LogicalResult
289   matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
290                   ConversionPatternRewriter &rewriter) const override {
291     // Rewrite the AddressOf op to get the address of the canoncical resource.
292     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
293     auto srcVarOp = cast<spirv::GlobalVariableOp>(
294         SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
295     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
296     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
297     return success();
298   }
299 };
300 
301 struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> {
302   using ConvertAliasResoruce::ConvertAliasResoruce;
303 
304   LogicalResult
305   matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
306                   ConversionPatternRewriter &rewriter) const override {
307     auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
308     if (!addressOp)
309       return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
310 
311     auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
312     auto srcVarOp = cast<spirv::GlobalVariableOp>(
313         SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
314     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
315 
316     spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
317     spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
318 
319     if ((srcElemType == dstElemType) ||
320         (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
321       // We have the same bitwidth for source and destination element types.
322       // Thie indices keep the same.
323       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
324           acOp, adaptor.base_ptr(), adaptor.indices());
325       return success();
326     }
327 
328     Location loc = acOp.getLoc();
329     auto i32Type = rewriter.getI32Type();
330 
331     if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
332       // The source indices are for a buffer with scalar element types. Rewrite
333       // them into a buffer with vector element types. We need to scale the last
334       // index for the vector as a whole, then add one level of index for inside
335       // the vector.
336       int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
337       auto ratioValue = rewriter.create<spirv::ConstantOp>(
338           loc, i32Type, rewriter.getI32IntegerAttr(ratio));
339 
340       auto indices = llvm::to_vector<4>(acOp.indices());
341       Value oldIndex = indices.back();
342       indices.back() =
343           rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
344       indices.push_back(
345           rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
346 
347       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
348           acOp, adaptor.base_ptr(), indices);
349       return success();
350     }
351 
352     return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
353   }
354 };
355 
356 struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> {
357   using ConvertAliasResoruce::ConvertAliasResoruce;
358 
359   LogicalResult
360   matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
361                   ConversionPatternRewriter &rewriter) const override {
362     auto srcElemType =
363         loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
364     auto dstElemType =
365         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
366     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
367       return rewriter.notifyMatchFailure(loadOp, "not scalar type");
368 
369     Location loc = loadOp.getLoc();
370     auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
371     if (srcElemType == dstElemType) {
372       rewriter.replaceOp(loadOp, newLoadOp->getResults());
373     } else {
374       auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
375                                                       newLoadOp.value());
376       rewriter.replaceOp(loadOp, castOp->getResults());
377     }
378 
379     return success();
380   }
381 };
382 
383 struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> {
384   using ConvertAliasResoruce::ConvertAliasResoruce;
385 
386   LogicalResult
387   matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
388                   ConversionPatternRewriter &rewriter) const override {
389     auto srcElemType =
390         storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
391     auto dstElemType =
392         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
393     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
394       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
395 
396     Location loc = storeOp.getLoc();
397     Value value = adaptor.value();
398     if (srcElemType != dstElemType)
399       value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
400     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
401                                                 storeOp->getAttrs());
402     return success();
403   }
404 };
405 
406 //===----------------------------------------------------------------------===//
407 // Pass
408 //===----------------------------------------------------------------------===//
409 
namespace {
/// Pass that rewrites accesses to non-canonical aliased resources in a
/// spirv::ModuleOp to use the canonical resource chosen by
/// ResourceAliasAnalysis.
class UnifyAliasedResourcePass final
    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
public:
  void runOnOperation() override;
};
} // namespace
417 
418 void UnifyAliasedResourcePass::runOnOperation() {
419   spirv::ModuleOp moduleOp = getOperation();
420   MLIRContext *context = &getContext();
421 
422   // Analyze aliased resources first.
423   ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
424 
425   ConversionTarget target(*context);
426   target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
427                                spirv::AccessChainOp, spirv::LoadOp,
428                                spirv::StoreOp>(
429       [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
430   target.addLegalDialect<spirv::SPIRVDialect>();
431 
432   // Run patterns to rewrite usages of non-canonical resources.
433   RewritePatternSet patterns(context);
434   patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
435                ConvertLoad, ConvertStore>(analysis, context);
436   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
437     return signalPassFailure();
438 
439   // Drop aliased attribute if we only have one single bound resource for a
440   // descriptor. We need to re-collect the map here given in the above the
441   // conversion is best effort; certain sets may not be converted.
442   AliasedResourceMap resourceMap =
443       collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
444   for (const auto &dr : resourceMap) {
445     const auto &resources = dr.second;
446     if (resources.size() == 1)
447       resources.front()->removeAttr("aliased");
448   }
449 }
450 
/// Creates an instance of the pass that unifies aliased resources in a
/// spirv::ModuleOp.
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass() {
  return std::make_unique<UnifyAliasedResourcePass>();
}
455