//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>

#define DEBUG_TYPE "spirv-unify-aliased-resource"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Identifies a resource binding point by its descriptor set and binding
/// numbers.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
/// Groups the `aliased` global variables by the descriptor they are bound to.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;

/// Collects all aliased resources in the given SPIR-V `moduleOp`: every
/// `spv.GlobalVariable` carrying the `aliased` unit attribute that has both a
/// descriptor set and a binding, grouped by that (set, binding) pair.
43 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) { 44 AliasedResourceMap aliasedResoruces; 45 moduleOp->walk([&aliasedResoruces](spirv::GlobalVariableOp varOp) { 46 if (varOp->getAttrOfType<UnitAttr>("aliased")) { 47 Optional<uint32_t> set = varOp.descriptor_set(); 48 Optional<uint32_t> binding = varOp.binding(); 49 if (set && binding) 50 aliasedResoruces[{*set, *binding}].push_back(varOp); 51 } 52 }); 53 return aliasedResoruces; 54 } 55 56 /// Returns the element type if the given `type` is a runtime array resource: 57 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise. 58 static Type getRuntimeArrayElementType(Type type) { 59 auto ptrType = type.dyn_cast<spirv::PointerType>(); 60 if (!ptrType) 61 return {}; 62 63 auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 64 if (!structType || structType.getNumElements() != 1) 65 return {}; 66 67 auto rtArrayType = 68 structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>(); 69 if (!rtArrayType) 70 return {}; 71 72 return rtArrayType.getElementType(); 73 } 74 75 /// Returns true if all `types`, which can either be scalar or vector types, 76 /// have the same bitwidth base scalar type. 77 static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) { 78 SmallVector<int64_t> scalarTypes; 79 scalarTypes.reserve(types.size()); 80 for (spirv::SPIRVType type : types) { 81 assert(type.isScalarOrVector()); 82 if (auto vectorType = type.dyn_cast<VectorType>()) 83 scalarTypes.push_back( 84 vectorType.getElementType().getIntOrFloatBitWidth()); 85 else 86 scalarTypes.push_back(type.getIntOrFloatBitWidth()); 87 } 88 return llvm::is_splat(scalarTypes); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Analysis 93 //===----------------------------------------------------------------------===// 94 95 namespace { 96 /// A class for analyzing aliased resources. 
97 /// 98 /// Resources are expected to be spv.GlobalVarible that has a descriptor set and 99 /// binding number. Such resources are of the type `!spv.ptr<!spv.struct<...>>` 100 /// per Vulkan requirements. 101 /// 102 /// Right now, we only support the case that there is a single runtime array 103 /// inside the struct. 104 class ResourceAliasAnalysis { 105 public: 106 explicit ResourceAliasAnalysis(Operation *); 107 108 /// Returns true if the given `op` can be rewritten to use a canonical 109 /// resource. 110 bool shouldUnify(Operation *op) const; 111 112 /// Returns all descriptors and their corresponding aliased resources. 113 const AliasedResourceMap &getResourceMap() const { return resourceMap; } 114 115 /// Returns the canonical resource for the given descriptor/variable. 116 spirv::GlobalVariableOp 117 getCanonicalResource(const Descriptor &descriptor) const; 118 spirv::GlobalVariableOp 119 getCanonicalResource(spirv::GlobalVariableOp varOp) const; 120 121 /// Returns the element type for the given variable. 122 spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const; 123 124 private: 125 /// Given the descriptor and aliased resources bound to it, analyze whether we 126 /// can unify them and record if so. 127 void recordIfUnifiable(const Descriptor &descriptor, 128 ArrayRef<spirv::GlobalVariableOp> resources); 129 130 /// Mapping from a descriptor to all aliased resources bound to it. 131 AliasedResourceMap resourceMap; 132 133 /// Mapping from a descriptor to the chosen canonical resource. 134 DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap; 135 136 /// Mapping from an aliased resource to its descriptor. 137 DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap; 138 139 /// Mapping from an aliased resource to its element (scalar/vector) type. 
140 DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap; 141 }; 142 } // namespace 143 144 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) { 145 // Collect all aliased resources first and put them into different sets 146 // according to the descriptor. 147 AliasedResourceMap aliasedResoruces = 148 collectAliasedResources(cast<spirv::ModuleOp>(root)); 149 150 // For each resource set, analyze whether we can unify; if so, try to identify 151 // a canonical resource, whose element type has the largest bitwidth. 152 for (const auto &descriptorResoruce : aliasedResoruces) { 153 recordIfUnifiable(descriptorResoruce.first, descriptorResoruce.second); 154 } 155 } 156 157 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const { 158 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { 159 auto canonicalOp = getCanonicalResource(varOp); 160 return canonicalOp && varOp != canonicalOp; 161 } 162 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) { 163 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); 164 auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()); 165 return shouldUnify(varOp); 166 } 167 168 if (auto acOp = dyn_cast<spirv::AccessChainOp>(op)) 169 return shouldUnify(acOp.base_ptr().getDefiningOp()); 170 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) 171 return shouldUnify(loadOp.ptr().getDefiningOp()); 172 if (auto storeOp = dyn_cast<spirv::StoreOp>(op)) 173 return shouldUnify(storeOp.ptr().getDefiningOp()); 174 175 return false; 176 } 177 178 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( 179 const Descriptor &descriptor) const { 180 auto varIt = canonicalResourceMap.find(descriptor); 181 if (varIt == canonicalResourceMap.end()) 182 return {}; 183 return varIt->second; 184 } 185 186 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( 187 spirv::GlobalVariableOp varOp) const { 188 auto descriptorIt = descriptorMap.find(varOp); 189 if (descriptorIt == 
descriptorMap.end()) 190 return {}; 191 return getCanonicalResource(descriptorIt->second); 192 } 193 194 spirv::SPIRVType 195 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const { 196 auto it = elementTypeMap.find(varOp); 197 if (it == elementTypeMap.end()) 198 return {}; 199 return it->second; 200 } 201 202 void ResourceAliasAnalysis::recordIfUnifiable( 203 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) { 204 // Collect the element types and byte counts for all resources in the 205 // current set. 206 SmallVector<spirv::SPIRVType> elementTypes; 207 SmallVector<int64_t> numBytes; 208 209 for (spirv::GlobalVariableOp resource : resources) { 210 Type elementType = getRuntimeArrayElementType(resource.type()); 211 if (!elementType) 212 return; // Unexpected resource variable type. 213 214 auto type = elementType.cast<spirv::SPIRVType>(); 215 if (!type.isScalarOrVector()) 216 return; // Unexpected resource element type. 217 218 if (auto vectorType = type.dyn_cast<VectorType>()) 219 if (vectorType.getNumElements() % 2 != 0) 220 return; // Odd-sized vector has special layout requirements. 221 222 Optional<int64_t> count = type.getSizeInBytes(); 223 if (!count) 224 return; 225 226 elementTypes.push_back(type); 227 numBytes.push_back(*count); 228 } 229 230 // Make sure base scalar types have the same bitwdith, so that we don't need 231 // to handle extracting components for now. 232 if (!hasSameBitwidthScalarType(elementTypes)) 233 return; 234 235 // Make sure that the canonical resource's bitwidth is divisible by others. 236 // With out this, we cannot properly adjust the index later. 
237 auto *maxCount = std::max_element(numBytes.begin(), numBytes.end()); 238 if (llvm::any_of(numBytes, [maxCount](int64_t count) { 239 return *maxCount % count != 0; 240 })) 241 return; 242 243 spirv::GlobalVariableOp canonicalResource = 244 resources[std::distance(numBytes.begin(), maxCount)]; 245 246 // Update internal data structures for later use. 247 resourceMap[descriptor].assign(resources.begin(), resources.end()); 248 canonicalResourceMap[descriptor] = canonicalResource; 249 for (const auto &resource : llvm::enumerate(resources)) { 250 descriptorMap[resource.value()] = descriptor; 251 elementTypeMap[resource.value()] = elementTypes[resource.index()]; 252 } 253 } 254 255 //===----------------------------------------------------------------------===// 256 // Patterns 257 //===----------------------------------------------------------------------===// 258 259 template <typename OpTy> 260 class ConvertAliasResoruce : public OpConversionPattern<OpTy> { 261 public: 262 ConvertAliasResoruce(const ResourceAliasAnalysis &analysis, 263 MLIRContext *context, PatternBenefit benefit = 1) 264 : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {} 265 266 protected: 267 const ResourceAliasAnalysis &analysis; 268 }; 269 270 struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> { 271 using ConvertAliasResoruce::ConvertAliasResoruce; 272 273 LogicalResult 274 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, 275 ConversionPatternRewriter &rewriter) const override { 276 // Just remove the aliased resource. Users will be rewritten to use the 277 // canonical one. 
278 rewriter.eraseOp(varOp); 279 return success(); 280 } 281 }; 282 283 struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> { 284 using ConvertAliasResoruce::ConvertAliasResoruce; 285 286 LogicalResult 287 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, 288 ConversionPatternRewriter &rewriter) const override { 289 // Rewrite the AddressOf op to get the address of the canoncical resource. 290 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); 291 auto srcVarOp = cast<spirv::GlobalVariableOp>( 292 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); 293 auto dstVarOp = analysis.getCanonicalResource(srcVarOp); 294 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp); 295 return success(); 296 } 297 }; 298 299 struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> { 300 using ConvertAliasResoruce::ConvertAliasResoruce; 301 302 LogicalResult 303 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, 304 ConversionPatternRewriter &rewriter) const override { 305 auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>(); 306 if (!addressOp) 307 return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op"); 308 309 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>(); 310 auto srcVarOp = cast<spirv::GlobalVariableOp>( 311 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); 312 auto dstVarOp = analysis.getCanonicalResource(srcVarOp); 313 314 spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); 315 spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp); 316 317 if ((srcElemType == dstElemType) || 318 (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) { 319 // We have the same bitwidth for source and destination element types. 320 // Thie indices keep the same. 
321 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( 322 acOp, adaptor.base_ptr(), adaptor.indices()); 323 return success(); 324 } 325 326 Location loc = acOp.getLoc(); 327 auto i32Type = rewriter.getI32Type(); 328 329 if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) { 330 // The source indices are for a buffer with scalar element types. Rewrite 331 // them into a buffer with vector element types. We need to scale the last 332 // index for the vector as a whole, then add one level of index for inside 333 // the vector. 334 int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes(); 335 auto ratioValue = rewriter.create<spirv::ConstantOp>( 336 loc, i32Type, rewriter.getI32IntegerAttr(ratio)); 337 338 auto indices = llvm::to_vector<4>(acOp.indices()); 339 Value oldIndex = indices.back(); 340 indices.back() = 341 rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue); 342 indices.push_back( 343 rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue)); 344 345 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( 346 acOp, adaptor.base_ptr(), indices); 347 return success(); 348 } 349 350 return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types"); 351 } 352 }; 353 354 struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> { 355 using ConvertAliasResoruce::ConvertAliasResoruce; 356 357 LogicalResult 358 matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, 359 ConversionPatternRewriter &rewriter) const override { 360 auto srcElemType = 361 loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 362 auto dstElemType = 363 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 364 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) 365 return rewriter.notifyMatchFailure(loadOp, "not scalar type"); 366 367 Location loc = loadOp.getLoc(); 368 auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr()); 369 if (srcElemType == dstElemType) { 370 
rewriter.replaceOp(loadOp, newLoadOp->getResults()); 371 } else { 372 auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType, 373 newLoadOp.value()); 374 rewriter.replaceOp(loadOp, castOp->getResults()); 375 } 376 377 return success(); 378 } 379 }; 380 381 struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> { 382 using ConvertAliasResoruce::ConvertAliasResoruce; 383 384 LogicalResult 385 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, 386 ConversionPatternRewriter &rewriter) const override { 387 auto srcElemType = 388 storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 389 auto dstElemType = 390 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 391 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) 392 return rewriter.notifyMatchFailure(storeOp, "not scalar type"); 393 394 Location loc = storeOp.getLoc(); 395 Value value = adaptor.value(); 396 if (srcElemType != dstElemType) 397 value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value); 398 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value, 399 storeOp->getAttrs()); 400 return success(); 401 } 402 }; 403 404 //===----------------------------------------------------------------------===// 405 // Pass 406 //===----------------------------------------------------------------------===// 407 408 namespace { 409 class UnifyAliasedResourcePass final 410 : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> { 411 public: 412 void runOnOperation() override; 413 }; 414 } // namespace 415 416 void UnifyAliasedResourcePass::runOnOperation() { 417 spirv::ModuleOp moduleOp = getOperation(); 418 MLIRContext *context = &getContext(); 419 420 // Analyze aliased resources first. 
421 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>(); 422 423 ConversionTarget target(*context); 424 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp, 425 spirv::AccessChainOp, spirv::LoadOp, 426 spirv::StoreOp>( 427 [&analysis](Operation *op) { return !analysis.shouldUnify(op); }); 428 target.addLegalDialect<spirv::SPIRVDialect>(); 429 430 // Run patterns to rewrite usages of non-canonical resources. 431 RewritePatternSet patterns(context); 432 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain, 433 ConvertLoad, ConvertStore>(analysis, context); 434 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) 435 return signalPassFailure(); 436 437 // Drop aliased attribute if we only have one single bound resource for a 438 // descriptor. We need to re-collect the map here given in the above the 439 // conversion is best effort; certain sets may not be converted. 440 AliasedResourceMap resourceMap = 441 collectAliasedResources(cast<spirv::ModuleOp>(moduleOp)); 442 for (const auto &dr : resourceMap) { 443 const auto &resources = dr.second; 444 if (resources.size() == 1) 445 resources.front()->removeAttr("aliased"); 446 } 447 } 448 449 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> 450 spirv::createUnifyAliasedResourcePass() { 451 return std::make_unique<UnifyAliasedResourcePass>(); 452 } 453