//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource. Aliased resources are global variables
// marked with the `aliased` unit attribute that share the same descriptor set
// and binding number.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>

#define DEBUG_TYPE "spirv-unify-aliased-resource"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Identifies a resource binding point by its decorations.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
/// Maps each descriptor to all `aliased` global variables bound to it.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;

/// Collects all aliased resources in the given SPIR-V `moduleOp`: global
/// variables carrying the `aliased` unit attribute, grouped by their
/// (descriptor set, binding) pair. Variables that lack either the descriptor
/// set or the binding are skipped.
43 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) { 44 AliasedResourceMap aliasedResoruces; 45 moduleOp->walk([&aliasedResoruces](spirv::GlobalVariableOp varOp) { 46 if (varOp->getAttrOfType<UnitAttr>("aliased")) { 47 Optional<uint32_t> set = varOp.descriptor_set(); 48 Optional<uint32_t> binding = varOp.binding(); 49 if (set && binding) 50 aliasedResoruces[{*set, *binding}].push_back(varOp); 51 } 52 }); 53 return aliasedResoruces; 54 } 55 56 /// Returns the element type if the given `type` is a runtime array resource: 57 /// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise. 58 static Type getRuntimeArrayElementType(Type type) { 59 auto ptrType = type.dyn_cast<spirv::PointerType>(); 60 if (!ptrType) 61 return {}; 62 63 auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 64 if (!structType || structType.getNumElements() != 1) 65 return {}; 66 67 auto rtArrayType = 68 structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>(); 69 if (!rtArrayType) 70 return {}; 71 72 return rtArrayType.getElementType(); 73 } 74 75 /// Returns true if all `types`, which can either be scalar or vector types, 76 /// have the same bitwidth base scalar type. 77 static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) { 78 SmallVector<int64_t> scalarTypes; 79 scalarTypes.reserve(types.size()); 80 for (spirv::SPIRVType type : types) { 81 assert(type.isScalarOrVector()); 82 if (auto vectorType = type.dyn_cast<VectorType>()) 83 scalarTypes.push_back( 84 vectorType.getElementType().getIntOrFloatBitWidth()); 85 else 86 scalarTypes.push_back(type.getIntOrFloatBitWidth()); 87 } 88 return llvm::is_splat(scalarTypes); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Analysis 93 //===----------------------------------------------------------------------===// 94 95 namespace { 96 /// A class for analyzing aliased resources. 
97 /// 98 /// Resources are expected to be spv.GlobalVarible that has a descriptor set and 99 /// binding number. Such resources are of the type `!spv.ptr<!spv.struct<...>>` 100 /// per Vulkan requirements. 101 /// 102 /// Right now, we only support the case that there is a single runtime array 103 /// inside the struct. 104 class ResourceAliasAnalysis { 105 public: 106 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis) 107 108 explicit ResourceAliasAnalysis(Operation *); 109 110 /// Returns true if the given `op` can be rewritten to use a canonical 111 /// resource. 112 bool shouldUnify(Operation *op) const; 113 114 /// Returns all descriptors and their corresponding aliased resources. 115 const AliasedResourceMap &getResourceMap() const { return resourceMap; } 116 117 /// Returns the canonical resource for the given descriptor/variable. 118 spirv::GlobalVariableOp 119 getCanonicalResource(const Descriptor &descriptor) const; 120 spirv::GlobalVariableOp 121 getCanonicalResource(spirv::GlobalVariableOp varOp) const; 122 123 /// Returns the element type for the given variable. 124 spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const; 125 126 private: 127 /// Given the descriptor and aliased resources bound to it, analyze whether we 128 /// can unify them and record if so. 129 void recordIfUnifiable(const Descriptor &descriptor, 130 ArrayRef<spirv::GlobalVariableOp> resources); 131 132 /// Mapping from a descriptor to all aliased resources bound to it. 133 AliasedResourceMap resourceMap; 134 135 /// Mapping from a descriptor to the chosen canonical resource. 136 DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap; 137 138 /// Mapping from an aliased resource to its descriptor. 139 DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap; 140 141 /// Mapping from an aliased resource to its element (scalar/vector) type. 
142 DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap; 143 }; 144 } // namespace 145 146 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) { 147 // Collect all aliased resources first and put them into different sets 148 // according to the descriptor. 149 AliasedResourceMap aliasedResoruces = 150 collectAliasedResources(cast<spirv::ModuleOp>(root)); 151 152 // For each resource set, analyze whether we can unify; if so, try to identify 153 // a canonical resource, whose element type has the largest bitwidth. 154 for (const auto &descriptorResoruce : aliasedResoruces) { 155 recordIfUnifiable(descriptorResoruce.first, descriptorResoruce.second); 156 } 157 } 158 159 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const { 160 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { 161 auto canonicalOp = getCanonicalResource(varOp); 162 return canonicalOp && varOp != canonicalOp; 163 } 164 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) { 165 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); 166 auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()); 167 return shouldUnify(varOp); 168 } 169 170 if (auto acOp = dyn_cast<spirv::AccessChainOp>(op)) 171 return shouldUnify(acOp.base_ptr().getDefiningOp()); 172 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) 173 return shouldUnify(loadOp.ptr().getDefiningOp()); 174 if (auto storeOp = dyn_cast<spirv::StoreOp>(op)) 175 return shouldUnify(storeOp.ptr().getDefiningOp()); 176 177 return false; 178 } 179 180 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( 181 const Descriptor &descriptor) const { 182 auto varIt = canonicalResourceMap.find(descriptor); 183 if (varIt == canonicalResourceMap.end()) 184 return {}; 185 return varIt->second; 186 } 187 188 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( 189 spirv::GlobalVariableOp varOp) const { 190 auto descriptorIt = descriptorMap.find(varOp); 191 if (descriptorIt == 
descriptorMap.end()) 192 return {}; 193 return getCanonicalResource(descriptorIt->second); 194 } 195 196 spirv::SPIRVType 197 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const { 198 auto it = elementTypeMap.find(varOp); 199 if (it == elementTypeMap.end()) 200 return {}; 201 return it->second; 202 } 203 204 void ResourceAliasAnalysis::recordIfUnifiable( 205 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) { 206 // Collect the element types and byte counts for all resources in the 207 // current set. 208 SmallVector<spirv::SPIRVType> elementTypes; 209 SmallVector<int64_t> numBytes; 210 211 for (spirv::GlobalVariableOp resource : resources) { 212 Type elementType = getRuntimeArrayElementType(resource.type()); 213 if (!elementType) 214 return; // Unexpected resource variable type. 215 216 auto type = elementType.cast<spirv::SPIRVType>(); 217 if (!type.isScalarOrVector()) 218 return; // Unexpected resource element type. 219 220 if (auto vectorType = type.dyn_cast<VectorType>()) 221 if (vectorType.getNumElements() % 2 != 0) 222 return; // Odd-sized vector has special layout requirements. 223 224 Optional<int64_t> count = type.getSizeInBytes(); 225 if (!count) 226 return; 227 228 elementTypes.push_back(type); 229 numBytes.push_back(*count); 230 } 231 232 // Make sure base scalar types have the same bitwdith, so that we don't need 233 // to handle extracting components for now. 234 if (!hasSameBitwidthScalarType(elementTypes)) 235 return; 236 237 // Make sure that the canonical resource's bitwidth is divisible by others. 238 // With out this, we cannot properly adjust the index later. 
239 auto *maxCount = std::max_element(numBytes.begin(), numBytes.end()); 240 if (llvm::any_of(numBytes, [maxCount](int64_t count) { 241 return *maxCount % count != 0; 242 })) 243 return; 244 245 spirv::GlobalVariableOp canonicalResource = 246 resources[std::distance(numBytes.begin(), maxCount)]; 247 248 // Update internal data structures for later use. 249 resourceMap[descriptor].assign(resources.begin(), resources.end()); 250 canonicalResourceMap[descriptor] = canonicalResource; 251 for (const auto &resource : llvm::enumerate(resources)) { 252 descriptorMap[resource.value()] = descriptor; 253 elementTypeMap[resource.value()] = elementTypes[resource.index()]; 254 } 255 } 256 257 //===----------------------------------------------------------------------===// 258 // Patterns 259 //===----------------------------------------------------------------------===// 260 261 template <typename OpTy> 262 class ConvertAliasResoruce : public OpConversionPattern<OpTy> { 263 public: 264 ConvertAliasResoruce(const ResourceAliasAnalysis &analysis, 265 MLIRContext *context, PatternBenefit benefit = 1) 266 : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {} 267 268 protected: 269 const ResourceAliasAnalysis &analysis; 270 }; 271 272 struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> { 273 using ConvertAliasResoruce::ConvertAliasResoruce; 274 275 LogicalResult 276 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, 277 ConversionPatternRewriter &rewriter) const override { 278 // Just remove the aliased resource. Users will be rewritten to use the 279 // canonical one. 
280 rewriter.eraseOp(varOp); 281 return success(); 282 } 283 }; 284 285 struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> { 286 using ConvertAliasResoruce::ConvertAliasResoruce; 287 288 LogicalResult 289 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, 290 ConversionPatternRewriter &rewriter) const override { 291 // Rewrite the AddressOf op to get the address of the canoncical resource. 292 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); 293 auto srcVarOp = cast<spirv::GlobalVariableOp>( 294 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); 295 auto dstVarOp = analysis.getCanonicalResource(srcVarOp); 296 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp); 297 return success(); 298 } 299 }; 300 301 struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> { 302 using ConvertAliasResoruce::ConvertAliasResoruce; 303 304 LogicalResult 305 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, 306 ConversionPatternRewriter &rewriter) const override { 307 auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>(); 308 if (!addressOp) 309 return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op"); 310 311 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>(); 312 auto srcVarOp = cast<spirv::GlobalVariableOp>( 313 SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); 314 auto dstVarOp = analysis.getCanonicalResource(srcVarOp); 315 316 spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); 317 spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp); 318 319 if ((srcElemType == dstElemType) || 320 (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) { 321 // We have the same bitwidth for source and destination element types. 322 // Thie indices keep the same. 
323 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( 324 acOp, adaptor.base_ptr(), adaptor.indices()); 325 return success(); 326 } 327 328 Location loc = acOp.getLoc(); 329 auto i32Type = rewriter.getI32Type(); 330 331 if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) { 332 // The source indices are for a buffer with scalar element types. Rewrite 333 // them into a buffer with vector element types. We need to scale the last 334 // index for the vector as a whole, then add one level of index for inside 335 // the vector. 336 int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes(); 337 auto ratioValue = rewriter.create<spirv::ConstantOp>( 338 loc, i32Type, rewriter.getI32IntegerAttr(ratio)); 339 340 auto indices = llvm::to_vector<4>(acOp.indices()); 341 Value oldIndex = indices.back(); 342 indices.back() = 343 rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue); 344 indices.push_back( 345 rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue)); 346 347 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( 348 acOp, adaptor.base_ptr(), indices); 349 return success(); 350 } 351 352 return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types"); 353 } 354 }; 355 356 struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> { 357 using ConvertAliasResoruce::ConvertAliasResoruce; 358 359 LogicalResult 360 matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, 361 ConversionPatternRewriter &rewriter) const override { 362 auto srcElemType = 363 loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 364 auto dstElemType = 365 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 366 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) 367 return rewriter.notifyMatchFailure(loadOp, "not scalar type"); 368 369 Location loc = loadOp.getLoc(); 370 auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr()); 371 if (srcElemType == dstElemType) { 372 
rewriter.replaceOp(loadOp, newLoadOp->getResults()); 373 } else { 374 auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType, 375 newLoadOp.value()); 376 rewriter.replaceOp(loadOp, castOp->getResults()); 377 } 378 379 return success(); 380 } 381 }; 382 383 struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> { 384 using ConvertAliasResoruce::ConvertAliasResoruce; 385 386 LogicalResult 387 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, 388 ConversionPatternRewriter &rewriter) const override { 389 auto srcElemType = 390 storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 391 auto dstElemType = 392 adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType(); 393 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) 394 return rewriter.notifyMatchFailure(storeOp, "not scalar type"); 395 396 Location loc = storeOp.getLoc(); 397 Value value = adaptor.value(); 398 if (srcElemType != dstElemType) 399 value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value); 400 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value, 401 storeOp->getAttrs()); 402 return success(); 403 } 404 }; 405 406 //===----------------------------------------------------------------------===// 407 // Pass 408 //===----------------------------------------------------------------------===// 409 410 namespace { 411 class UnifyAliasedResourcePass final 412 : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> { 413 public: 414 void runOnOperation() override; 415 }; 416 } // namespace 417 418 void UnifyAliasedResourcePass::runOnOperation() { 419 spirv::ModuleOp moduleOp = getOperation(); 420 MLIRContext *context = &getContext(); 421 422 // Analyze aliased resources first. 
423 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>(); 424 425 ConversionTarget target(*context); 426 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp, 427 spirv::AccessChainOp, spirv::LoadOp, 428 spirv::StoreOp>( 429 [&analysis](Operation *op) { return !analysis.shouldUnify(op); }); 430 target.addLegalDialect<spirv::SPIRVDialect>(); 431 432 // Run patterns to rewrite usages of non-canonical resources. 433 RewritePatternSet patterns(context); 434 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain, 435 ConvertLoad, ConvertStore>(analysis, context); 436 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) 437 return signalPassFailure(); 438 439 // Drop aliased attribute if we only have one single bound resource for a 440 // descriptor. We need to re-collect the map here given in the above the 441 // conversion is best effort; certain sets may not be converted. 442 AliasedResourceMap resourceMap = 443 collectAliasedResources(cast<spirv::ModuleOp>(moduleOp)); 444 for (const auto &dr : resourceMap) { 445 const auto &resources = dr.second; 446 if (resources.size() == 1) 447 resources.front()->removeAttr("aliased"); 448 } 449 } 450 451 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> 452 spirv::createUnifyAliasedResourcePass() { 453 return std::make_unique<UnifyAliasedResourcePass>(); 454 } 455