1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "llvm/Support/Debug.h" 19 20 #define DEBUG_TYPE "memref-to-spirv-pattern" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Utility functions 26 //===----------------------------------------------------------------------===// 27 28 /// Returns the offset of the value in `targetBits` representation. 29 /// 30 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. 31 /// It's assumed to be non-negative. 32 /// 33 /// When accessing an element in the array treating as having elements of 34 /// `targetBits`, multiple values are loaded in the same time. The method 35 /// returns the offset where the `srcIdx` locates in the value. For example, if 36 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is 37 /// located at (x % 4) * 8. Because there are four elements in one i32, and one 38 /// element has 8 bits. 39 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, 40 int targetBits, OpBuilder &builder) { 41 assert(targetBits % sourceBits == 0); 42 IntegerType targetType = builder.getIntegerType(targetBits); 43 IntegerAttr idxAttr = 44 builder.getIntegerAttr(targetType, targetBits / sourceBits); 45 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr); 46 IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); 47 auto srcBitsValue = 48 builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr); 49 auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx); 50 return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue); 51 } 52 53 /// Returns an adjusted spirv::AccessChainOp. Based on the 54 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be 55 /// supported. During conversion if a memref of an unsupported type is used, 56 /// load/stores to this memref need to be modified to use a supported higher 57 /// bitwidth `targetBits` and extracting the required bits. For an accessing a 58 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the 59 /// bits needed. The extraction of the actual bits needed are handled 60 /// separately. Note that this only works for a 1-D tensor. 61 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, 62 spirv::AccessChainOp op, 63 int sourceBits, int targetBits, 64 OpBuilder &builder) { 65 assert(targetBits % sourceBits == 0); 66 const auto loc = op.getLoc(); 67 IntegerType targetType = builder.getIntegerType(targetBits); 68 IntegerAttr attr = 69 builder.getIntegerAttr(targetType, targetBits / sourceBits); 70 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr); 71 auto lastDim = op->getOperand(op.getNumOperands() - 1); 72 auto indices = llvm::to_vector<4>(op.indices()); 73 // There are two elements if this is a 1-D tensor. 74 assert(indices.size() == 2); 75 indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx); 76 Type t = typeConverter.convertType(op.component_ptr().getType()); 77 return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices); 78 } 79 80 /// Returns the shifted `targetBits`-bit value with the given offset. 81 static Value shiftValue(Location loc, Value value, Value offset, Value mask, 82 int targetBits, OpBuilder &builder) { 83 Type targetType = builder.getIntegerType(targetBits); 84 Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask); 85 return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result, 86 offset); 87 } 88 89 /// Returns true if the allocations of memref `type` generated from `allocOp` 90 /// can be lowered to SPIR-V. 91 static bool isAllocationSupported(Operation *allocOp, MemRefType type) { 92 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) { 93 if (SPIRVTypeConverter::getMemorySpaceForStorageClass( 94 spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt()) 95 return false; 96 } else if (isa<memref::AllocaOp>(allocOp)) { 97 if (SPIRVTypeConverter::getMemorySpaceForStorageClass( 98 spirv::StorageClass::Function) != type.getMemorySpaceAsInt()) 99 return false; 100 } else { 101 return false; 102 } 103 104 // Currently only support static shape and int or float or vector of int or 105 // float element type. 106 if (!type.hasStaticShape()) 107 return false; 108 109 Type elementType = type.getElementType(); 110 if (auto vecType = elementType.dyn_cast<VectorType>()) 111 elementType = vecType.getElementType(); 112 return elementType.isIntOrFloat(); 113 } 114 115 /// Returns the scope to use for atomic operations use for emulating store 116 /// operations of unsupported integer bitwidths, based on the memref 117 /// type. Returns None on failure. 118 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) { 119 Optional<spirv::StorageClass> storageClass = 120 SPIRVTypeConverter::getStorageClassForMemorySpace( 121 type.getMemorySpaceAsInt()); 122 if (!storageClass) 123 return {}; 124 switch (*storageClass) { 125 case spirv::StorageClass::StorageBuffer: 126 return spirv::Scope::Device; 127 case spirv::StorageClass::Workgroup: 128 return spirv::Scope::Workgroup; 129 default: { 130 } 131 } 132 return {}; 133 } 134 135 /// Casts the given `srcInt` into a boolean value. 136 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { 137 if (srcInt.getType().isInteger(1)) 138 return srcInt; 139 140 auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); 141 return builder.create<spirv::IEqualOp>(loc, srcInt, one); 142 } 143 144 /// Casts the given `srcBool` into an integer of `dstType`. 145 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, 146 OpBuilder &builder) { 147 assert(srcBool.getType().isInteger(1)); 148 if (dstType.isInteger(1)) 149 return srcBool; 150 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); 151 Value one = spirv::ConstantOp::getOne(dstType, loc, builder); 152 return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero); 153 } 154 155 //===----------------------------------------------------------------------===// 156 // Operation conversion 157 //===----------------------------------------------------------------------===// 158 159 // Note that DRR cannot be used for the patterns in this file: we may need to 160 // convert type along the way, which requires ConversionPattern. DRR generates 161 // normal RewritePattern. 162 163 namespace { 164 165 /// Converts memref.alloca to SPIR-V Function variables. 166 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> { 167 public: 168 using OpConversionPattern<memref::AllocaOp>::OpConversionPattern; 169 170 LogicalResult 171 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, 172 ConversionPatternRewriter &rewriter) const override; 173 }; 174 175 /// Converts an allocation operation to SPIR-V. Currently only supports lowering 176 /// to Workgroup memory when the size is constant. Note that this pattern needs 177 /// to be applied in a pass that runs at least at spv.module scope since it wil 178 /// ladd global variables into the spv.module. 179 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { 180 public: 181 using OpConversionPattern<memref::AllocOp>::OpConversionPattern; 182 183 LogicalResult 184 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 185 ConversionPatternRewriter &rewriter) const override; 186 }; 187 188 /// Removed a deallocation if it is a supported allocation. Currently only 189 /// removes deallocation if the memory space is workgroup memory. 190 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> { 191 public: 192 using OpConversionPattern<memref::DeallocOp>::OpConversionPattern; 193 194 LogicalResult 195 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, 196 ConversionPatternRewriter &rewriter) const override; 197 }; 198 199 /// Converts memref.load to spv.Load. 200 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 201 public: 202 using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 203 204 LogicalResult 205 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 206 ConversionPatternRewriter &rewriter) const override; 207 }; 208 209 /// Converts memref.load to spv.Load. 210 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 211 public: 212 using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 213 214 LogicalResult 215 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 216 ConversionPatternRewriter &rewriter) const override; 217 }; 218 219 /// Converts memref.store to spv.Store on integers. 220 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 221 public: 222 using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 223 224 LogicalResult 225 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 226 ConversionPatternRewriter &rewriter) const override; 227 }; 228 229 /// Converts memref.store to spv.Store. 230 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 231 public: 232 using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 233 234 LogicalResult 235 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 236 ConversionPatternRewriter &rewriter) const override; 237 }; 238 239 } // namespace 240 241 //===----------------------------------------------------------------------===// 242 // AllocaOp 243 //===----------------------------------------------------------------------===// 244 245 LogicalResult 246 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, 247 ConversionPatternRewriter &rewriter) const { 248 MemRefType allocType = allocaOp.getType(); 249 if (!isAllocationSupported(allocaOp, allocType)) 250 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type"); 251 252 // Get the SPIR-V type for the allocation. 253 Type spirvType = getTypeConverter()->convertType(allocType); 254 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType, 255 spirv::StorageClass::Function, 256 /*initializer=*/nullptr); 257 return success(); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // AllocOp 262 //===----------------------------------------------------------------------===// 263 264 LogicalResult 265 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 266 ConversionPatternRewriter &rewriter) const { 267 MemRefType allocType = operation.getType(); 268 if (!isAllocationSupported(operation, allocType)) 269 return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); 270 271 // Get the SPIR-V type for the allocation. 272 Type spirvType = getTypeConverter()->convertType(allocType); 273 274 // Insert spv.GlobalVariable for this allocation. 275 Operation *parent = 276 SymbolTable::getNearestSymbolTable(operation->getParentOp()); 277 if (!parent) 278 return failure(); 279 Location loc = operation.getLoc(); 280 spirv::GlobalVariableOp varOp; 281 { 282 OpBuilder::InsertionGuard guard(rewriter); 283 Block &entryBlock = *parent->getRegion(0).begin(); 284 rewriter.setInsertionPointToStart(&entryBlock); 285 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>(); 286 std::string varName = 287 std::string("__workgroup_mem__") + 288 std::to_string(std::distance(varOps.begin(), varOps.end())); 289 varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, 290 /*initializer=*/nullptr); 291 } 292 293 // Get pointer to global variable at the current scope. 294 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp); 295 return success(); 296 } 297 298 //===----------------------------------------------------------------------===// 299 // DeallocOp 300 //===----------------------------------------------------------------------===// 301 302 LogicalResult 303 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, 304 OpAdaptor adaptor, 305 ConversionPatternRewriter &rewriter) const { 306 MemRefType deallocType = operation.memref().getType().cast<MemRefType>(); 307 if (!isAllocationSupported(operation, deallocType)) 308 return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); 309 rewriter.eraseOp(operation); 310 return success(); 311 } 312 313 //===----------------------------------------------------------------------===// 314 // LoadOp 315 //===----------------------------------------------------------------------===// 316 317 LogicalResult 318 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 319 ConversionPatternRewriter &rewriter) const { 320 auto loc = loadOp.getLoc(); 321 auto memrefType = loadOp.memref().getType().cast<MemRefType>(); 322 if (!memrefType.getElementType().isSignlessInteger()) 323 return failure(); 324 325 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 326 spirv::AccessChainOp accessChainOp = 327 spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), 328 adaptor.indices(), loc, rewriter); 329 330 if (!accessChainOp) 331 return failure(); 332 333 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 334 bool isBool = srcBits == 1; 335 if (isBool) 336 srcBits = typeConverter.getOptions().boolNumBits; 337 Type pointeeType = typeConverter.convertType(memrefType) 338 .cast<spirv::PointerType>() 339 .getPointeeType(); 340 Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); 341 Type dstType; 342 if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) 343 dstType = arrayType.getElementType(); 344 else 345 dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); 346 347 int dstBits = dstType.getIntOrFloatBitWidth(); 348 assert(dstBits % srcBits == 0); 349 350 // If the rewrited load op has the same bit width, use the loading value 351 // directly. 352 if (srcBits == dstBits) { 353 Value loadVal = 354 rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult()); 355 if (isBool) 356 loadVal = castIntNToBool(loc, loadVal, rewriter); 357 rewriter.replaceOp(loadOp, loadVal); 358 return success(); 359 } 360 361 // Assume that getElementPtr() works linearizely. If it's a scalar, the method 362 // still returns a linearized accessing. If the accessing is not linearized, 363 // there will be offset issues. 364 assert(accessChainOp.indices().size() == 2); 365 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 366 srcBits, dstBits, rewriter); 367 Value spvLoadOp = rewriter.create<spirv::LoadOp>( 368 loc, dstType, adjustedPtr, 369 loadOp->getAttrOfType<spirv::MemoryAccessAttr>( 370 spirv::attributeName<spirv::MemoryAccess>()), 371 loadOp->getAttrOfType<IntegerAttr>("alignment")); 372 373 // Shift the bits to the rightmost. 374 // ____XXXX________ -> ____________XXXX 375 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 376 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 377 Value result = rewriter.create<spirv::ShiftRightArithmeticOp>( 378 loc, spvLoadOp.getType(), spvLoadOp, offset); 379 380 // Apply the mask to extract corresponding bits. 381 Value mask = rewriter.create<spirv::ConstantOp>( 382 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 383 result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask); 384 385 // Apply sign extension on the loading value unconditionally. The signedness 386 // semantic is carried in the operator itself, we relies other pattern to 387 // handle the casting. 388 IntegerAttr shiftValueAttr = 389 rewriter.getIntegerAttr(dstType, dstBits - srcBits); 390 Value shiftValue = 391 rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr); 392 result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result, 393 shiftValue); 394 result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result, 395 shiftValue); 396 397 if (isBool) { 398 dstType = typeConverter.convertType(loadOp.getType()); 399 mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); 400 result = rewriter.create<spirv::IEqualOp>(loc, result, mask); 401 } else if (result.getType().getIntOrFloatBitWidth() != 402 static_cast<unsigned>(dstBits)) { 403 result = rewriter.create<spirv::SConvertOp>(loc, dstType, result); 404 } 405 rewriter.replaceOp(loadOp, result); 406 407 assert(accessChainOp.use_empty()); 408 rewriter.eraseOp(accessChainOp); 409 410 return success(); 411 } 412 413 LogicalResult 414 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 415 ConversionPatternRewriter &rewriter) const { 416 auto memrefType = loadOp.memref().getType().cast<MemRefType>(); 417 if (memrefType.getElementType().isSignlessInteger()) 418 return failure(); 419 auto loadPtr = spirv::getElementPtr( 420 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), 421 adaptor.indices(), loadOp.getLoc(), rewriter); 422 423 if (!loadPtr) 424 return failure(); 425 426 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr); 427 return success(); 428 } 429 430 LogicalResult 431 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 432 ConversionPatternRewriter &rewriter) const { 433 auto memrefType = storeOp.memref().getType().cast<MemRefType>(); 434 if (!memrefType.getElementType().isSignlessInteger()) 435 return failure(); 436 437 auto loc = storeOp.getLoc(); 438 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 439 spirv::AccessChainOp accessChainOp = 440 spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), 441 adaptor.indices(), loc, rewriter); 442 443 if (!accessChainOp) 444 return failure(); 445 446 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 447 448 bool isBool = srcBits == 1; 449 if (isBool) 450 srcBits = typeConverter.getOptions().boolNumBits; 451 452 Type pointeeType = typeConverter.convertType(memrefType) 453 .cast<spirv::PointerType>() 454 .getPointeeType(); 455 Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); 456 Type dstType; 457 if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) 458 dstType = arrayType.getElementType(); 459 else 460 dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); 461 462 int dstBits = dstType.getIntOrFloatBitWidth(); 463 assert(dstBits % srcBits == 0); 464 465 if (srcBits == dstBits) { 466 Value storeVal = adaptor.value(); 467 if (isBool) 468 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); 469 rewriter.replaceOpWithNewOp<spirv::StoreOp>( 470 storeOp, accessChainOp.getResult(), storeVal); 471 return success(); 472 } 473 474 // Since there are multi threads in the processing, the emulation will be done 475 // with atomic operations. E.g., if the storing value is i8, rewrite the 476 // StoreOp to 477 // 1) load a 32-bit integer 478 // 2) clear 8 bits in the loading value 479 // 3) store 32-bit value back 480 // 4) load a 32-bit integer 481 // 5) modify 8 bits in the loading value 482 // 6) store 32-bit value back 483 // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step 484 // 4 to step 6 are done by AtomicOr as another atomic step. 485 assert(accessChainOp.indices().size() == 2); 486 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 487 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 488 489 // Create a mask to clear the destination. E.g., if it is the second i8 in 490 // i32, 0xFFFF00FF is created. 491 Value mask = rewriter.create<spirv::ConstantOp>( 492 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 493 Value clearBitsMask = 494 rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset); 495 clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask); 496 497 Value storeVal = adaptor.value(); 498 if (isBool) 499 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); 500 storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); 501 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 502 srcBits, dstBits, rewriter); 503 Optional<spirv::Scope> scope = getAtomicOpScope(memrefType); 504 if (!scope) 505 return failure(); 506 Value result = rewriter.create<spirv::AtomicAndOp>( 507 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 508 clearBitsMask); 509 result = rewriter.create<spirv::AtomicOrOp>( 510 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 511 storeVal); 512 513 // The AtomicOrOp has no side effect. Since it is already inserted, we can 514 // just remove the original StoreOp. Note that rewriter.replaceOp() 515 // doesn't work because it only accepts that the numbers of result are the 516 // same. 517 rewriter.eraseOp(storeOp); 518 519 assert(accessChainOp.use_empty()); 520 rewriter.eraseOp(accessChainOp); 521 522 return success(); 523 } 524 525 LogicalResult 526 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 527 ConversionPatternRewriter &rewriter) const { 528 auto memrefType = storeOp.memref().getType().cast<MemRefType>(); 529 if (memrefType.getElementType().isSignlessInteger()) 530 return failure(); 531 auto storePtr = spirv::getElementPtr( 532 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), 533 adaptor.indices(), storeOp.getLoc(), rewriter); 534 535 if (!storePtr) 536 return failure(); 537 538 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, 539 adaptor.value()); 540 return success(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // Pattern population 545 //===----------------------------------------------------------------------===// 546 547 namespace mlir { 548 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 549 RewritePatternSet &patterns) { 550 patterns 551 .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, 552 IntStoreOpPattern, LoadOpPattern, StoreOpPattern>( 553 typeConverter, patterns.getContext()); 554 } 555 } // namespace mlir 556