//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`
/// bits. It's assumed to be non-negative.
///
/// When accessing an element in the array, treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset where `srcIdx` is located within the loaded value. For
/// example, if `sourceBits` is 8 and `targetBits` is 32, the x-th element is
/// located at (x % 4) * 8, because one i32 holds four elements and each
/// element has 8 bits.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr idxAttr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
  auto srcBitsValue =
      builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For accessing a
/// 1-D array (spv.array or spv.rt_array), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
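///
/// For example, with `sourceBits` = 8 and `targetBits` = 32, accessing element
/// i of the logical i8 array becomes accessing element i / 4 of the emulated
/// i32 array; the bit offset of the element inside that i32 word is computed
/// separately by getOffsetForBitwidth().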
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
                                          spirv::AccessChainOp op,
                                          int sourceBits, int targetBits,
                                          OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr attr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
  auto lastDim = op->getOperand(op.getNumOperands() - 1);
  auto indices = llvm::to_vector<4>(op.indices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.component_ptr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}

/// Returns the shifted `targetBits`-bit value with the given offset.
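///
/// For example, with `targetBits` = 32, a mask of 0xFF, and a bit offset of
/// 16, an input value of 0xAB yields (0xAB & 0xFF) << 16 = 0x00AB0000, i.e.
/// the value placed in the third byte of the emulated i32 word.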
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        int targetBits, OpBuilder &builder) {
  Type targetType = builder.getIntegerType(targetBits);
  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                   offset);
}

/// Returns true if the allocation of memref `type` generated by `allocOp` can
/// be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
      return false;
  } else {
    return false;
  }

  // Currently only supports static shapes and element types that are int,
  // float, or vector of int or float.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>())
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns None on failure.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(
          type.getMemorySpaceAsInt());
  if (!storageClass)
    return {};
  switch (*storageClass) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default: {
  }
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern. DRR
// generates normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spv.module
/// scope since it will add global variables into the spv.module.
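///
/// For example, a memref.alloc of a statically-shaped memref in workgroup
/// memory space is rewritten into a module-level spirv::GlobalVariableOp
/// (named "__workgroup_mem__<N>") plus a spirv::AddressOfOp at the original
/// allocation site.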
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if its corresponding allocation is supported.
/// Currently only removes the deallocation if the memory space is workgroup
/// memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);

  // Insert spv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType =
      pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    Value loadVal =
        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Assume that getElementPtr() linearizes the access. If it's a scalar, the
  // method still returns a linearized access. If the access is not linearized,
  // there will be offset issues.
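  // The converted access chain is expected to have exactly two indices: a
  // constant zero index into the wrapping struct and the linearized element
  // index, which adjustAccessChainForBitwidth() adjusts below.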
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
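  // For example, with srcBits = 8 and dstBits = 32, an extracted value of 0xFF
  // (i.e. i8 -1) is shifted left by 24 and then arithmetically shifted right
  // by 24, yielding 0xFFFFFFFF (i.e. i32 -1).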
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);

  if (isBool) {
    dstType = typeConverter.convertType(loadOp.getType());
    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
  } else if (result.getType().getIntOrFloatBitWidth() !=
             static_cast<unsigned>(dstBits)) {
    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
  }
  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType =
      pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeVal);
    return success();
  }

  // Since multiple threads may access the same word concurrently, the
  // emulation is done with atomic operations. E.g., if the stored value is i8,
  // rewrite the StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1-3 are done by AtomicAnd as one atomic step, and steps 4-6 are done
  // by AtomicOr as another atomic step.
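  // For example, storing i8 0xAB at element index 1 of a word currently
  // holding 0x11223344: AtomicAnd with the clear mask 0xFFFF00FF yields
  // 0x11220044, and AtomicOr with the shifted value 0x0000AB00 then yields
  // 0x1122AB44.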
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = adaptor.getValue();
  if (isBool)
    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp() doesn't
  // work here because it requires the number of results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                              adaptor.getValue());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
           IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir