1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "memref-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Utility functions
26 //===----------------------------------------------------------------------===//
27 
28 /// Returns the offset of the value in `targetBits` representation.
29 ///
30 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
31 /// It's assumed to be non-negative.
32 ///
33 /// When accessing an element in the array treating as having elements of
34 /// `targetBits`, multiple values are loaded in the same time. The method
35 /// returns the offset where the `srcIdx` locates in the value. For example, if
36 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
37 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
38 /// element has 8 bits.
getOffsetForBitwidth(Location loc,Value srcIdx,int sourceBits,int targetBits,OpBuilder & builder)39 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
40                                   int targetBits, OpBuilder &builder) {
41   assert(targetBits % sourceBits == 0);
42   IntegerType targetType = builder.getIntegerType(targetBits);
43   IntegerAttr idxAttr =
44       builder.getIntegerAttr(targetType, targetBits / sourceBits);
45   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
46   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
47   auto srcBitsValue =
48       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
49   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
50   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
51 }
52 
53 /// Returns an adjusted spirv::AccessChainOp. Based on the
54 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
55 /// supported. During conversion if a memref of an unsupported type is used,
56 /// load/stores to this memref need to be modified to use a supported higher
57 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
58 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
59 /// bits needed. The extraction of the actual bits needed are handled
60 /// separately. Note that this only works for a 1-D tensor.
adjustAccessChainForBitwidth(SPIRVTypeConverter & typeConverter,spirv::AccessChainOp op,int sourceBits,int targetBits,OpBuilder & builder)61 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
62                                           spirv::AccessChainOp op,
63                                           int sourceBits, int targetBits,
64                                           OpBuilder &builder) {
65   assert(targetBits % sourceBits == 0);
66   const auto loc = op.getLoc();
67   IntegerType targetType = builder.getIntegerType(targetBits);
68   IntegerAttr attr =
69       builder.getIntegerAttr(targetType, targetBits / sourceBits);
70   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
71   auto lastDim = op->getOperand(op.getNumOperands() - 1);
72   auto indices = llvm::to_vector<4>(op.indices());
73   // There are two elements if this is a 1-D tensor.
74   assert(indices.size() == 2);
75   indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
76   Type t = typeConverter.convertType(op.component_ptr().getType());
77   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
78 }
79 
80 /// Returns the shifted `targetBits`-bit value with the given offset.
shiftValue(Location loc,Value value,Value offset,Value mask,int targetBits,OpBuilder & builder)81 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
82                         int targetBits, OpBuilder &builder) {
83   Type targetType = builder.getIntegerType(targetBits);
84   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
85   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
86                                                    offset);
87 }
88 
89 /// Returns true if the allocations of memref `type` generated from `allocOp`
90 /// can be lowered to SPIR-V.
isAllocationSupported(Operation * allocOp,MemRefType type)91 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
92   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
93     if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
94             spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
95       return false;
96   } else if (isa<memref::AllocaOp>(allocOp)) {
97     if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
98             spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
99       return false;
100   } else {
101     return false;
102   }
103 
104   // Currently only support static shape and int or float or vector of int or
105   // float element type.
106   if (!type.hasStaticShape())
107     return false;
108 
109   Type elementType = type.getElementType();
110   if (auto vecType = elementType.dyn_cast<VectorType>())
111     elementType = vecType.getElementType();
112   return elementType.isIntOrFloat();
113 }
114 
115 /// Returns the scope to use for atomic operations use for emulating store
116 /// operations of unsupported integer bitwidths, based on the memref
117 /// type. Returns None on failure.
getAtomicOpScope(MemRefType type)118 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
119   Optional<spirv::StorageClass> storageClass =
120       SPIRVTypeConverter::getStorageClassForMemorySpace(
121           type.getMemorySpaceAsInt());
122   if (!storageClass)
123     return {};
124   switch (*storageClass) {
125   case spirv::StorageClass::StorageBuffer:
126     return spirv::Scope::Device;
127   case spirv::StorageClass::Workgroup:
128     return spirv::Scope::Workgroup;
129   default: {
130   }
131   }
132   return {};
133 }
134 
135 /// Casts the given `srcInt` into a boolean value.
castIntNToBool(Location loc,Value srcInt,OpBuilder & builder)136 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
137   if (srcInt.getType().isInteger(1))
138     return srcInt;
139 
140   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
141   return builder.create<spirv::IEqualOp>(loc, srcInt, one);
142 }
143 
144 /// Casts the given `srcBool` into an integer of `dstType`.
castBoolToIntN(Location loc,Value srcBool,Type dstType,OpBuilder & builder)145 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
146                             OpBuilder &builder) {
147   assert(srcBool.getType().isInteger(1));
148   if (dstType.isInteger(1))
149     return srcBool;
150   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
151   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
152   return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // Operation conversion
157 //===----------------------------------------------------------------------===//
158 
159 // Note that DRR cannot be used for the patterns in this file: we may need to
160 // convert type along the way, which requires ConversionPattern. DRR generates
161 // normal RewritePattern.
162 
163 namespace {
164 
165 /// Converts memref.alloca to SPIR-V Function variables.
166 class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
167 public:
168   using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
169 
170   LogicalResult
171   matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
172                   ConversionPatternRewriter &rewriter) const override;
173 };
174 
175 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
176 /// to Workgroup memory when the size is constant.  Note that this pattern needs
177 /// to be applied in a pass that runs at least at spv.module scope since it wil
178 /// ladd global variables into the spv.module.
179 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
180 public:
181   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
182 
183   LogicalResult
184   matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override;
186 };
187 
188 /// Removed a deallocation if it is a supported allocation. Currently only
189 /// removes deallocation if the memory space is workgroup memory.
190 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
191 public:
192   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
193 
194   LogicalResult
195   matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
196                   ConversionPatternRewriter &rewriter) const override;
197 };
198 
199 /// Converts memref.load to spv.Load.
200 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
201 public:
202   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
203 
204   LogicalResult
205   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
206                   ConversionPatternRewriter &rewriter) const override;
207 };
208 
209 /// Converts memref.load to spv.Load.
210 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
211 public:
212   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
213 
214   LogicalResult
215   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
216                   ConversionPatternRewriter &rewriter) const override;
217 };
218 
219 /// Converts memref.store to spv.Store on integers.
220 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
221 public:
222   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
223 
224   LogicalResult
225   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
226                   ConversionPatternRewriter &rewriter) const override;
227 };
228 
229 /// Converts memref.store to spv.Store.
230 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
231 public:
232   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
233 
234   LogicalResult
235   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
236                   ConversionPatternRewriter &rewriter) const override;
237 };
238 
239 } // namespace
240 
241 //===----------------------------------------------------------------------===//
242 // AllocaOp
243 //===----------------------------------------------------------------------===//
244 
245 LogicalResult
matchAndRewrite(memref::AllocaOp allocaOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const246 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
247                                  ConversionPatternRewriter &rewriter) const {
248   MemRefType allocType = allocaOp.getType();
249   if (!isAllocationSupported(allocaOp, allocType))
250     return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
251 
252   // Get the SPIR-V type for the allocation.
253   Type spirvType = getTypeConverter()->convertType(allocType);
254   rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
255                                                  spirv::StorageClass::Function,
256                                                  /*initializer=*/nullptr);
257   return success();
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // AllocOp
262 //===----------------------------------------------------------------------===//
263 
264 LogicalResult
matchAndRewrite(memref::AllocOp operation,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const265 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
266                                 ConversionPatternRewriter &rewriter) const {
267   MemRefType allocType = operation.getType();
268   if (!isAllocationSupported(operation, allocType))
269     return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
270 
271   // Get the SPIR-V type for the allocation.
272   Type spirvType = getTypeConverter()->convertType(allocType);
273 
274   // Insert spv.GlobalVariable for this allocation.
275   Operation *parent =
276       SymbolTable::getNearestSymbolTable(operation->getParentOp());
277   if (!parent)
278     return failure();
279   Location loc = operation.getLoc();
280   spirv::GlobalVariableOp varOp;
281   {
282     OpBuilder::InsertionGuard guard(rewriter);
283     Block &entryBlock = *parent->getRegion(0).begin();
284     rewriter.setInsertionPointToStart(&entryBlock);
285     auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
286     std::string varName =
287         std::string("__workgroup_mem__") +
288         std::to_string(std::distance(varOps.begin(), varOps.end()));
289     varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
290                                                      /*initializer=*/nullptr);
291   }
292 
293   // Get pointer to global variable at the current scope.
294   rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
295   return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // DeallocOp
300 //===----------------------------------------------------------------------===//
301 
302 LogicalResult
matchAndRewrite(memref::DeallocOp operation,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const303 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
304                                   OpAdaptor adaptor,
305                                   ConversionPatternRewriter &rewriter) const {
306   MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
307   if (!isAllocationSupported(operation, deallocType))
308     return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
309   rewriter.eraseOp(operation);
310   return success();
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // LoadOp
315 //===----------------------------------------------------------------------===//
316 
317 LogicalResult
matchAndRewrite(memref::LoadOp loadOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const318 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
319                                   ConversionPatternRewriter &rewriter) const {
320   auto loc = loadOp.getLoc();
321   auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
322   if (!memrefType.getElementType().isSignlessInteger())
323     return failure();
324 
325   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
326   spirv::AccessChainOp accessChainOp =
327       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
328                            adaptor.getIndices(), loc, rewriter);
329 
330   if (!accessChainOp)
331     return failure();
332 
333   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
334   bool isBool = srcBits == 1;
335   if (isBool)
336     srcBits = typeConverter.getOptions().boolNumBits;
337   Type pointeeType = typeConverter.convertType(memrefType)
338                          .cast<spirv::PointerType>()
339                          .getPointeeType();
340   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
341   Type dstType;
342   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
343     dstType = arrayType.getElementType();
344   else
345     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
346 
347   int dstBits = dstType.getIntOrFloatBitWidth();
348   assert(dstBits % srcBits == 0);
349 
350   // If the rewrited load op has the same bit width, use the loading value
351   // directly.
352   if (srcBits == dstBits) {
353     Value loadVal =
354         rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
355     if (isBool)
356       loadVal = castIntNToBool(loc, loadVal, rewriter);
357     rewriter.replaceOp(loadOp, loadVal);
358     return success();
359   }
360 
361   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
362   // still returns a linearized accessing. If the accessing is not linearized,
363   // there will be offset issues.
364   assert(accessChainOp.indices().size() == 2);
365   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
366                                                    srcBits, dstBits, rewriter);
367   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
368       loc, dstType, adjustedPtr,
369       loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
370           spirv::attributeName<spirv::MemoryAccess>()),
371       loadOp->getAttrOfType<IntegerAttr>("alignment"));
372 
373   // Shift the bits to the rightmost.
374   // ____XXXX________ -> ____________XXXX
375   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
376   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
377   Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
378       loc, spvLoadOp.getType(), spvLoadOp, offset);
379 
380   // Apply the mask to extract corresponding bits.
381   Value mask = rewriter.create<spirv::ConstantOp>(
382       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
383   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
384 
385   // Apply sign extension on the loading value unconditionally. The signedness
386   // semantic is carried in the operator itself, we relies other pattern to
387   // handle the casting.
388   IntegerAttr shiftValueAttr =
389       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
390   Value shiftValue =
391       rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
392   result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
393                                                       shiftValue);
394   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
395                                                           shiftValue);
396 
397   if (isBool) {
398     dstType = typeConverter.convertType(loadOp.getType());
399     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
400     result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
401   } else if (result.getType().getIntOrFloatBitWidth() !=
402              static_cast<unsigned>(dstBits)) {
403     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
404   }
405   rewriter.replaceOp(loadOp, result);
406 
407   assert(accessChainOp.use_empty());
408   rewriter.eraseOp(accessChainOp);
409 
410   return success();
411 }
412 
413 LogicalResult
matchAndRewrite(memref::LoadOp loadOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const414 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
415                                ConversionPatternRewriter &rewriter) const {
416   auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
417   if (memrefType.getElementType().isSignlessInteger())
418     return failure();
419   auto loadPtr = spirv::getElementPtr(
420       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
421       adaptor.getIndices(), loadOp.getLoc(), rewriter);
422 
423   if (!loadPtr)
424     return failure();
425 
426   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
427   return success();
428 }
429 
430 LogicalResult
matchAndRewrite(memref::StoreOp storeOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const431 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
432                                    ConversionPatternRewriter &rewriter) const {
433   auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
434   if (!memrefType.getElementType().isSignlessInteger())
435     return failure();
436 
437   auto loc = storeOp.getLoc();
438   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
439   spirv::AccessChainOp accessChainOp =
440       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
441                            adaptor.getIndices(), loc, rewriter);
442 
443   if (!accessChainOp)
444     return failure();
445 
446   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
447 
448   bool isBool = srcBits == 1;
449   if (isBool)
450     srcBits = typeConverter.getOptions().boolNumBits;
451 
452   Type pointeeType = typeConverter.convertType(memrefType)
453                          .cast<spirv::PointerType>()
454                          .getPointeeType();
455   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
456   Type dstType;
457   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
458     dstType = arrayType.getElementType();
459   else
460     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
461 
462   int dstBits = dstType.getIntOrFloatBitWidth();
463   assert(dstBits % srcBits == 0);
464 
465   if (srcBits == dstBits) {
466     Value storeVal = adaptor.getValue();
467     if (isBool)
468       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
469     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
470         storeOp, accessChainOp.getResult(), storeVal);
471     return success();
472   }
473 
474   // Since there are multi threads in the processing, the emulation will be done
475   // with atomic operations. E.g., if the storing value is i8, rewrite the
476   // StoreOp to
477   // 1) load a 32-bit integer
478   // 2) clear 8 bits in the loading value
479   // 3) store 32-bit value back
480   // 4) load a 32-bit integer
481   // 5) modify 8 bits in the loading value
482   // 6) store 32-bit value back
483   // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
484   // 4 to step 6 are done by AtomicOr as another atomic step.
485   assert(accessChainOp.indices().size() == 2);
486   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
487   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
488 
489   // Create a mask to clear the destination. E.g., if it is the second i8 in
490   // i32, 0xFFFF00FF is created.
491   Value mask = rewriter.create<spirv::ConstantOp>(
492       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
493   Value clearBitsMask =
494       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
495   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
496 
497   Value storeVal = adaptor.getValue();
498   if (isBool)
499     storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
500   storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
501   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
502                                                    srcBits, dstBits, rewriter);
503   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
504   if (!scope)
505     return failure();
506   Value result = rewriter.create<spirv::AtomicAndOp>(
507       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
508       clearBitsMask);
509   result = rewriter.create<spirv::AtomicOrOp>(
510       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
511       storeVal);
512 
513   // The AtomicOrOp has no side effect. Since it is already inserted, we can
514   // just remove the original StoreOp. Note that rewriter.replaceOp()
515   // doesn't work because it only accepts that the numbers of result are the
516   // same.
517   rewriter.eraseOp(storeOp);
518 
519   assert(accessChainOp.use_empty());
520   rewriter.eraseOp(accessChainOp);
521 
522   return success();
523 }
524 
525 LogicalResult
matchAndRewrite(memref::StoreOp storeOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const526 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
527                                 ConversionPatternRewriter &rewriter) const {
528   auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
529   if (memrefType.getElementType().isSignlessInteger())
530     return failure();
531   auto storePtr = spirv::getElementPtr(
532       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
533       adaptor.getIndices(), storeOp.getLoc(), rewriter);
534 
535   if (!storePtr)
536     return failure();
537 
538   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
539                                               adaptor.getValue());
540   return success();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // Pattern population
545 //===----------------------------------------------------------------------===//
546 
547 namespace mlir {
populateMemRefToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)548 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
549                                    RewritePatternSet &patterns) {
550   patterns
551       .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
552            IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
553           typeConverter, patterns.getContext());
554 }
555 } // namespace mlir
556