//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies accesses to multiple aliased
// resources into accesses to a single canonical resource.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <iterator>

#define DEBUG_TYPE "spirv-unify-aliased-resource"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;

/// Collects all aliased resources in the given SPIR-V `moduleOp`.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
  AliasedResourceMap aliasedResources;
  moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
    if (varOp->getAttrOfType<UnitAttr>("aliased")) {
      Optional<uint32_t> set = varOp.descriptor_set();
      Optional<uint32_t> binding = varOp.binding();
      if (set && binding)
        aliasedResources[{*set, *binding}].push_back(varOp);
    }
  });
  return aliasedResources;
}

/// Returns the element type if the given `type` is a runtime array resource:
/// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
static Type getRuntimeArrayElementType(Type type) {
  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return {};

  auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
  if (!structType || structType.getNumElements() != 1)
    return {};

  auto rtArrayType =
      structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
  if (!rtArrayType)
    return {};

  return rtArrayType.getElementType();
}

/// Given a list of resource element `types`, returns the index of the
/// canonical resource that all resources should be unified into. Returns
/// llvm::None if unable to unify.
static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
  SmallVector<int> scalarNumBits, totalNumBits;
  scalarNumBits.reserve(types.size());
  totalNumBits.reserve(types.size());
  bool hasVector = false;

  for (spirv::SPIRVType type : types) {
    assert(type.isScalarOrVector());
    if (auto vectorType = type.dyn_cast<VectorType>()) {
      if (vectorType.getNumElements() % 2 != 0)
        return llvm::None; // Odd-sized vector has special layout requirements.
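      // For example, under std140/std430-style explicit layout rules a
      // vector<3xf32> occupies 12 bytes but has a 16-byte array stride, which
      // would break the simple index arithmetic used below; hence the check
      // above.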

      Optional<int64_t> numBytes = type.getSizeInBytes();
      if (!numBytes)
        return llvm::None;

      scalarNumBits.push_back(
          vectorType.getElementType().getIntOrFloatBitWidth());
      totalNumBits.push_back(*numBytes * 8);
      hasVector = true;
    } else {
      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
      totalNumBits.push_back(scalarNumBits.back());
    }
  }

  if (hasVector) {
    // If there are vector types, require all element types to be the same for
    // now to simplify the transformation.
    if (!llvm::is_splat(scalarNumBits))
      return llvm::None;

    // Choose the one with the largest bitwidth as the canonical resource, so
    // that we can still keep vectorized load/store.
    auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
    // Make sure that the canonical resource's bitwidth is divisible by others.
    // Without this, we cannot properly adjust the index later.
    if (llvm::any_of(totalNumBits,
                     [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
      return llvm::None;

    return std::distance(totalNumBits.begin(), maxVal);
  }

  // All element types are scalars. Choose the smallest bitwidth as the
  // canonical resource to avoid subcomponent load/store.
  auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
  if (llvm::any_of(scalarNumBits,
                   [minVal](int64_t bit) { return bit % *minVal != 0; }))
    return llvm::None;
  return std::distance(scalarNumBits.begin(), minVal);
}

static bool areSameBitwidthScalarType(Type a, Type b) {
  return a.isIntOrFloat() && b.isIntOrFloat() &&
         a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}

//===----------------------------------------------------------------------===//
// Analysis
//===----------------------------------------------------------------------===//

namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spv.GlobalVariable ops that have a descriptor
/// set and binding number. Such resources are of the type
/// `!spv.ptr<!spv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether
  /// we can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace

ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
  // Collect all aliased resources first and put them into different sets
  // according to the descriptor.
  AliasedResourceMap aliasedResources =
      collectAliasedResources(cast<spirv::ModuleOp>(root));

  // For each resource set, analyze whether we can unify; if so, record the
  // canonical resource to unify the others into.
  for (const auto &descriptorResource : aliasedResources) {
    recordIfUnifiable(descriptorResource.first, descriptorResource.second);
  }
}

bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
    auto canonicalOp = getCanonicalResource(varOp);
    return canonicalOp && varOp != canonicalOp;
  }
  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
    auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
    return shouldUnify(varOp);
  }

  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
    return shouldUnify(acOp.base_ptr().getDefiningOp());
  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
    return shouldUnify(loadOp.ptr().getDefiningOp());
  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
    return shouldUnify(storeOp.ptr().getDefiningOp());

  return false;
}

spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    const Descriptor &descriptor) const {
  auto varIt = canonicalResourceMap.find(descriptor);
  if (varIt == canonicalResourceMap.end())
    return {};
  return varIt->second;
}

spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    spirv::GlobalVariableOp varOp) const {
  auto descriptorIt = descriptorMap.find(varOp);
  if (descriptorIt == descriptorMap.end())
    return {};
  return getCanonicalResource(descriptorIt->second);
}

spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
  auto it = elementTypeMap.find(varOp);
  if (it == elementTypeMap.end())
    return {};
  return it->second;
}

void ResourceAliasAnalysis::recordIfUnifiable(
    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
  // Collect the element types for all resources in the current set.
  SmallVector<spirv::SPIRVType> elementTypes;
  for (spirv::GlobalVariableOp resource : resources) {
    Type elementType = getRuntimeArrayElementType(resource.type());
    if (!elementType)
      return; // Unexpected resource variable type.

    auto type = elementType.cast<spirv::SPIRVType>();
    if (!type.isScalarOrVector())
      return; // Unexpected resource element type.

    elementTypes.push_back(type);
  }

  Optional<int> index = deduceCanonicalResource(elementTypes);
  if (!index)
    return;

  // Update internal data structures for later use.
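  // Note that the canonical resource is itself one of the aliased resources;
  // the conversion patterns later erase the other resources bound to this
  // descriptor and redirect their users to the canonical one.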
  resourceMap[descriptor].assign(resources.begin(), resources.end());
  canonicalResourceMap[descriptor] = resources[*index];
  for (const auto &resource : llvm::enumerate(resources)) {
    descriptorMap[resource.value()] = descriptor;
    elementTypeMap[resource.value()] = elementTypes[resource.index()];
  }
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Base class for conversion patterns; holds a reference to the shared
/// resource alias analysis.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  const ResourceAliasAnalysis &analysis;
};

struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Just remove the aliased resource. Users will be rewritten to use the
    // canonical one.
    rewriter.eraseOp(varOp);
    return success();
  }
};

struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rewrite the AddressOf op to get the address of the canonical resource.
    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
    return success();
  }
};

struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);

    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);

    if (srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // We have the same bitwidth for source and destination element types.
      // The indices stay the same.
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), adaptor.indices());
      return success();
    }

    Location loc = acOp.getLoc();
    auto i32Type = rewriter.getI32Type();

    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them into a buffer with vector element types. We need to scale the
      // last index for the vector as a whole, then add another level of
      // indexing into the vector.
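      // For example, if an i32 resource aliases the canonical resource with
      // vector<4xi32> elements, an access at array index i becomes an access
      // at array index i / 4 plus a new innermost index i % 4 into the
      // vector, which is exactly what the SDiv/SMod below compute.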
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(dstNumBytes > srcNumBytes && dstNumBytes % srcNumBytes == 0);
      int ratio = dstNumBytes / srcNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      auto indices = llvm::to_vector<4>(acOp.indices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
      indices.push_back(
          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), indices);
      return success();
    }

    if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
      // The source indices are for a buffer with larger bitwidth scalar
      // element types. Rewrite them into a buffer with smaller bitwidth
      // element types. We only need to scale the last index.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      auto indices = llvm::to_vector<4>(acOp.indices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.base_ptr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
  }
};

struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcElemType =
        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    auto dstElemType =
        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(loadOp, "not scalar type");

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
    if (srcElemType == dstElemType) {
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.value());
      rewriter.replaceOp(loadOp, castOp->getResults());
      return success();
    }

    // The source and destination have scalar types of different bitwidths.
    // For such cases, we need to load multiple smaller bitwidth values and
    // construct a larger bitwidth one.
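    // For example, loading an i64 value from a canonical buffer with i32
    // elements turns into two adjacent i32 loads, composed into a
    // vector<2xi32> and then bitcast back to i64.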

    int srcNumBits = srcElemType.getIntOrFloatBitWidth();
    int dstNumBits = dstElemType.getIntOrFloatBitWidth();
    assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
    int ratio = srcNumBits / dstNumBits;
    if (ratio > 4)
      return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

    SmallVector<Value> components;
    components.reserve(ratio);
    components.push_back(newLoadOp);

    auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
    if (!acOp)
      return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");

    auto i32Type = rewriter.getI32Type();
    Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
    auto indices = llvm::to_vector<4>(acOp.indices());
    for (int i = 1; i < ratio; ++i) {
      // Load all subsequent components belonging to this element.
      indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
                                                      indices.back(), oneValue);
      auto componentAcOp =
          rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
      components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
    }

    // Create a vector of the components and then cast back to the larger
    // bitwidth element type. Assuming little-endian storage, the first loaded
    // word holds the lowest-order bits and spv.Bitcast maps vector component 0
    // to the lowest-order bits of the result, so the components are already in
    // the right order.
    auto vectorType = VectorType::get({ratio}, dstElemType);
    Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
        loc, vectorType, components);
    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
                                                  vectorValue);
    return success();
  }
};

struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcElemType =
        storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    auto dstElemType =
        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");

    Location loc = storeOp.getLoc();
    Value value = adaptor.value();
    if (srcElemType != dstElemType)
      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
                                                storeOp->getAttrs());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

namespace {
class UnifyAliasedResourcePass final
    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
public:
  void runOnOperation() override;
};
} // namespace

void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();
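
  // This pass proceeds in three steps: analyze which aliased resources can be
  // unified, rewrite users of the non-canonical resources via dialect
  // conversion, and finally drop the `aliased` attribute wherever only one
  // resource remains bound to a descriptor.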

  // Analyze aliased resources first.
  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();

  ConversionTarget target(*context);
  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
                               spirv::AccessChainOp, spirv::LoadOp,
                               spirv::StoreOp>(
      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
  target.addLegalDialect<spirv::SPIRVDialect>();

  // Run patterns to rewrite usages of non-canonical resources.
  RewritePatternSet patterns(context);
  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
               ConvertLoad, ConvertStore>(analysis, context);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
    return signalPassFailure();

  // Drop the `aliased` attribute if only a single resource remains bound to a
  // descriptor. We need to re-collect the map here given that the above
  // conversion is best-effort; certain sets may not have been converted.
  AliasedResourceMap resourceMap = collectAliasedResources(moduleOp);
  for (const auto &dr : resourceMap) {
    const auto &resources = dr.second;
    if (resources.size() == 1)
      resources.front()->removeAttr("aliased");
  }
}

std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass() {
  return std::make_unique<UnifyAliasedResourcePass>();
}