1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "memref-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns the offset of the value in `targetBits` representation.
28 ///
29 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
30 /// It's assumed to be non-negative.
31 ///
32 /// When accessing an element in the array treating as having elements of
33 /// `targetBits`, multiple values are loaded in the same time. The method
34 /// returns the offset where the `srcIdx` locates in the value. For example, if
35 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
36 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
37 /// element has 8 bits.
38 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
39                                   int targetBits, OpBuilder &builder) {
40   assert(targetBits % sourceBits == 0);
41   IntegerType targetType = builder.getIntegerType(targetBits);
42   IntegerAttr idxAttr =
43       builder.getIntegerAttr(targetType, targetBits / sourceBits);
44   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
45   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
46   auto srcBitsValue =
47       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
48   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
49   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
50 }
51 
52 /// Returns an adjusted spirv::AccessChainOp. Based on the
53 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
54 /// supported. During conversion if a memref of an unsupported type is used,
55 /// load/stores to this memref need to be modified to use a supported higher
56 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
57 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
58 /// bits needed. The extraction of the actual bits needed are handled
59 /// separately. Note that this only works for a 1-D tensor.
60 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
61                                           spirv::AccessChainOp op,
62                                           int sourceBits, int targetBits,
63                                           OpBuilder &builder) {
64   assert(targetBits % sourceBits == 0);
65   const auto loc = op.getLoc();
66   IntegerType targetType = builder.getIntegerType(targetBits);
67   IntegerAttr attr =
68       builder.getIntegerAttr(targetType, targetBits / sourceBits);
69   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
70   auto lastDim = op->getOperand(op.getNumOperands() - 1);
71   auto indices = llvm::to_vector<4>(op.indices());
72   // There are two elements if this is a 1-D tensor.
73   assert(indices.size() == 2);
74   indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
75   Type t = typeConverter.convertType(op.component_ptr().getType());
76   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
77 }
78 
79 /// Returns the shifted `targetBits`-bit value with the given offset.
80 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
81                         int targetBits, OpBuilder &builder) {
82   Type targetType = builder.getIntegerType(targetBits);
83   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
84   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
85                                                    offset);
86 }
87 
88 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
89 static bool isAllocationSupported(MemRefType t) {
90   // Currently only support workgroup local memory allocations with static
91   // shape and int or float or vector of int or float element type.
92   if (!(t.hasStaticShape() &&
93         SPIRVTypeConverter::getMemorySpaceForStorageClass(
94             spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
95     return false;
96   Type elementType = t.getElementType();
97   if (auto vecType = elementType.dyn_cast<VectorType>())
98     elementType = vecType.getElementType();
99   return elementType.isIntOrFloat();
100 }
101 
102 /// Returns the scope to use for atomic operations use for emulating store
103 /// operations of unsupported integer bitwidths, based on the memref
104 /// type. Returns None on failure.
105 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
106   Optional<spirv::StorageClass> storageClass =
107       SPIRVTypeConverter::getStorageClassForMemorySpace(
108           t.getMemorySpaceAsInt());
109   if (!storageClass)
110     return {};
111   switch (*storageClass) {
112   case spirv::StorageClass::StorageBuffer:
113     return spirv::Scope::Device;
114   case spirv::StorageClass::Workgroup:
115     return spirv::Scope::Workgroup;
116   default: {
117   }
118   }
119   return {};
120 }
121 
122 /// Casts the given `srcInt` into a boolean value.
123 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
124   if (srcInt.getType().isInteger(1))
125     return srcInt;
126 
127   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
128   return builder.create<spirv::IEqualOp>(loc, srcInt, one);
129 }
130 
131 /// Casts the given `srcBool` into an integer of `dstType`.
132 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
133                             OpBuilder &builder) {
134   assert(srcBool.getType().isInteger(1));
135   if (dstType.isInteger(1))
136     return srcBool;
137   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
138   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
139   return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Operation conversion
144 //===----------------------------------------------------------------------===//
145 
146 // Note that DRR cannot be used for the patterns in this file: we may need to
147 // convert type along the way, which requires ConversionPattern. DRR generates
148 // normal RewritePattern.
149 
150 namespace {
151 
152 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
153 /// to Workgroup memory when the size is constant.  Note that this pattern needs
154 /// to be applied in a pass that runs at least at spv.module scope since it wil
155 /// ladd global variables into the spv.module.
156 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
157 public:
158   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
159 
160   LogicalResult
161   matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
162                   ConversionPatternRewriter &rewriter) const override;
163 };
164 
165 /// Removed a deallocation if it is a supported allocation. Currently only
166 /// removes deallocation if the memory space is workgroup memory.
167 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
168 public:
169   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
170 
171   LogicalResult
172   matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
173                   ConversionPatternRewriter &rewriter) const override;
174 };
175 
176 /// Converts memref.load to spv.Load.
177 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
178 public:
179   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
180 
181   LogicalResult
182   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
183                   ConversionPatternRewriter &rewriter) const override;
184 };
185 
186 /// Converts memref.load to spv.Load.
187 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
188 public:
189   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
190 
191   LogicalResult
192   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
193                   ConversionPatternRewriter &rewriter) const override;
194 };
195 
196 /// Converts memref.store to spv.Store on integers.
197 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
198 public:
199   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
200 
201   LogicalResult
202   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
203                   ConversionPatternRewriter &rewriter) const override;
204 };
205 
206 /// Converts memref.store to spv.Store.
207 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
208 public:
209   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
210 
211   LogicalResult
212   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
213                   ConversionPatternRewriter &rewriter) const override;
214 };
215 
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // AllocOp
220 //===----------------------------------------------------------------------===//
221 
222 LogicalResult
223 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
224                                 ConversionPatternRewriter &rewriter) const {
225   MemRefType allocType = operation.getType();
226   if (!isAllocationSupported(allocType))
227     return operation.emitError("unhandled allocation type");
228 
229   // Get the SPIR-V type for the allocation.
230   Type spirvType = getTypeConverter()->convertType(allocType);
231 
232   // Insert spv.GlobalVariable for this allocation.
233   Operation *parent =
234       SymbolTable::getNearestSymbolTable(operation->getParentOp());
235   if (!parent)
236     return failure();
237   Location loc = operation.getLoc();
238   spirv::GlobalVariableOp varOp;
239   {
240     OpBuilder::InsertionGuard guard(rewriter);
241     Block &entryBlock = *parent->getRegion(0).begin();
242     rewriter.setInsertionPointToStart(&entryBlock);
243     auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
244     std::string varName =
245         std::string("__workgroup_mem__") +
246         std::to_string(std::distance(varOps.begin(), varOps.end()));
247     varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
248                                                      /*initializer=*/nullptr);
249   }
250 
251   // Get pointer to global variable at the current scope.
252   rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
253   return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DeallocOp
258 //===----------------------------------------------------------------------===//
259 
260 LogicalResult
261 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
262                                   OpAdaptor adaptor,
263                                   ConversionPatternRewriter &rewriter) const {
264   MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
265   if (!isAllocationSupported(deallocType))
266     return operation.emitError("unhandled deallocation type");
267   rewriter.eraseOp(operation);
268   return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // LoadOp
273 //===----------------------------------------------------------------------===//
274 
275 LogicalResult
276 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
277                                   ConversionPatternRewriter &rewriter) const {
278   auto loc = loadOp.getLoc();
279   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
280   if (!memrefType.getElementType().isSignlessInteger())
281     return failure();
282 
283   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
284   spirv::AccessChainOp accessChainOp =
285       spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
286                            adaptor.indices(), loc, rewriter);
287 
288   if (!accessChainOp)
289     return failure();
290 
291   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
292   bool isBool = srcBits == 1;
293   if (isBool)
294     srcBits = typeConverter.getOptions().boolNumBits;
295   Type pointeeType = typeConverter.convertType(memrefType)
296                          .cast<spirv::PointerType>()
297                          .getPointeeType();
298   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
299   Type dstType;
300   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
301     dstType = arrayType.getElementType();
302   else
303     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
304 
305   int dstBits = dstType.getIntOrFloatBitWidth();
306   assert(dstBits % srcBits == 0);
307 
308   // If the rewrited load op has the same bit width, use the loading value
309   // directly.
310   if (srcBits == dstBits) {
311     Value loadVal =
312         rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
313     if (isBool)
314       loadVal = castIntNToBool(loc, loadVal, rewriter);
315     rewriter.replaceOp(loadOp, loadVal);
316     return success();
317   }
318 
319   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
320   // still returns a linearized accessing. If the accessing is not linearized,
321   // there will be offset issues.
322   assert(accessChainOp.indices().size() == 2);
323   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
324                                                    srcBits, dstBits, rewriter);
325   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
326       loc, dstType, adjustedPtr,
327       loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
328           spirv::attributeName<spirv::MemoryAccess>()),
329       loadOp->getAttrOfType<IntegerAttr>("alignment"));
330 
331   // Shift the bits to the rightmost.
332   // ____XXXX________ -> ____________XXXX
333   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
334   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
335   Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
336       loc, spvLoadOp.getType(), spvLoadOp, offset);
337 
338   // Apply the mask to extract corresponding bits.
339   Value mask = rewriter.create<spirv::ConstantOp>(
340       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
341   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
342 
343   // Apply sign extension on the loading value unconditionally. The signedness
344   // semantic is carried in the operator itself, we relies other pattern to
345   // handle the casting.
346   IntegerAttr shiftValueAttr =
347       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
348   Value shiftValue =
349       rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
350   result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
351                                                       shiftValue);
352   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
353                                                           shiftValue);
354 
355   if (isBool) {
356     dstType = typeConverter.convertType(loadOp.getType());
357     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
358     result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
359   } else if (result.getType().getIntOrFloatBitWidth() !=
360              static_cast<unsigned>(dstBits)) {
361     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
362   }
363   rewriter.replaceOp(loadOp, result);
364 
365   assert(accessChainOp.use_empty());
366   rewriter.eraseOp(accessChainOp);
367 
368   return success();
369 }
370 
371 LogicalResult
372 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
373                                ConversionPatternRewriter &rewriter) const {
374   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
375   if (memrefType.getElementType().isSignlessInteger())
376     return failure();
377   auto loadPtr = spirv::getElementPtr(
378       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
379       adaptor.indices(), loadOp.getLoc(), rewriter);
380 
381   if (!loadPtr)
382     return failure();
383 
384   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
385   return success();
386 }
387 
388 LogicalResult
389 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
390                                    ConversionPatternRewriter &rewriter) const {
391   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
392   if (!memrefType.getElementType().isSignlessInteger())
393     return failure();
394 
395   auto loc = storeOp.getLoc();
396   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
397   spirv::AccessChainOp accessChainOp =
398       spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
399                            adaptor.indices(), loc, rewriter);
400 
401   if (!accessChainOp)
402     return failure();
403 
404   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
405 
406   bool isBool = srcBits == 1;
407   if (isBool)
408     srcBits = typeConverter.getOptions().boolNumBits;
409 
410   Type pointeeType = typeConverter.convertType(memrefType)
411                          .cast<spirv::PointerType>()
412                          .getPointeeType();
413   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
414   Type dstType;
415   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
416     dstType = arrayType.getElementType();
417   else
418     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
419 
420   int dstBits = dstType.getIntOrFloatBitWidth();
421   assert(dstBits % srcBits == 0);
422 
423   if (srcBits == dstBits) {
424     Value storeVal = adaptor.value();
425     if (isBool)
426       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
427     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
428         storeOp, accessChainOp.getResult(), storeVal);
429     return success();
430   }
431 
432   // Since there are multi threads in the processing, the emulation will be done
433   // with atomic operations. E.g., if the storing value is i8, rewrite the
434   // StoreOp to
435   // 1) load a 32-bit integer
436   // 2) clear 8 bits in the loading value
437   // 3) store 32-bit value back
438   // 4) load a 32-bit integer
439   // 5) modify 8 bits in the loading value
440   // 6) store 32-bit value back
441   // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
442   // 4 to step 6 are done by AtomicOr as another atomic step.
443   assert(accessChainOp.indices().size() == 2);
444   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
445   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
446 
447   // Create a mask to clear the destination. E.g., if it is the second i8 in
448   // i32, 0xFFFF00FF is created.
449   Value mask = rewriter.create<spirv::ConstantOp>(
450       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
451   Value clearBitsMask =
452       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
453   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
454 
455   Value storeVal = adaptor.value();
456   if (isBool)
457     storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
458   storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
459   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
460                                                    srcBits, dstBits, rewriter);
461   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
462   if (!scope)
463     return failure();
464   Value result = rewriter.create<spirv::AtomicAndOp>(
465       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
466       clearBitsMask);
467   result = rewriter.create<spirv::AtomicOrOp>(
468       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
469       storeVal);
470 
471   // The AtomicOrOp has no side effect. Since it is already inserted, we can
472   // just remove the original StoreOp. Note that rewriter.replaceOp()
473   // doesn't work because it only accepts that the numbers of result are the
474   // same.
475   rewriter.eraseOp(storeOp);
476 
477   assert(accessChainOp.use_empty());
478   rewriter.eraseOp(accessChainOp);
479 
480   return success();
481 }
482 
483 LogicalResult
484 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
485                                 ConversionPatternRewriter &rewriter) const {
486   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
487   if (memrefType.getElementType().isSignlessInteger())
488     return failure();
489   auto storePtr = spirv::getElementPtr(
490       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
491       adaptor.indices(), storeOp.getLoc(), rewriter);
492 
493   if (!storePtr)
494     return failure();
495 
496   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
497                                               adaptor.value());
498   return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // Pattern population
503 //===----------------------------------------------------------------------===//
504 
505 namespace mlir {
506 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
507                                    RewritePatternSet &patterns) {
508   patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
509                IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
510       typeConverter, patterns.getContext());
511 }
512 } // namespace mlir
513