1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "memref-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns the offset of the value in `targetBits` representation.
28 ///
29 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
30 /// It's assumed to be non-negative.
31 ///
32 /// When accessing an element in the array treating as having elements of
33 /// `targetBits`, multiple values are loaded in the same time. The method
34 /// returns the offset where the `srcIdx` locates in the value. For example, if
35 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
36 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
37 /// element has 8 bits.
38 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
39                                   int targetBits, OpBuilder &builder) {
40   assert(targetBits % sourceBits == 0);
41   IntegerType targetType = builder.getIntegerType(targetBits);
42   IntegerAttr idxAttr =
43       builder.getIntegerAttr(targetType, targetBits / sourceBits);
44   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
45   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
46   auto srcBitsValue =
47       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
48   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
49   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
50 }
51 
52 /// Returns an adjusted spirv::AccessChainOp. Based on the
53 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
54 /// supported. During conversion if a memref of an unsupported type is used,
55 /// load/stores to this memref need to be modified to use a supported higher
56 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
57 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
58 /// bits needed. The extraction of the actual bits needed are handled
59 /// separately. Note that this only works for a 1-D tensor.
60 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
61                                           spirv::AccessChainOp op,
62                                           int sourceBits, int targetBits,
63                                           OpBuilder &builder) {
64   assert(targetBits % sourceBits == 0);
65   const auto loc = op.getLoc();
66   IntegerType targetType = builder.getIntegerType(targetBits);
67   IntegerAttr attr =
68       builder.getIntegerAttr(targetType, targetBits / sourceBits);
69   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
70   auto lastDim = op->getOperand(op.getNumOperands() - 1);
71   auto indices = llvm::to_vector<4>(op.indices());
72   // There are two elements if this is a 1-D tensor.
73   assert(indices.size() == 2);
74   indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
75   Type t = typeConverter.convertType(op.component_ptr().getType());
76   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
77 }
78 
79 /// Returns the shifted `targetBits`-bit value with the given offset.
80 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
81                         int targetBits, OpBuilder &builder) {
82   Type targetType = builder.getIntegerType(targetBits);
83   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
84   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
85                                                    offset);
86 }
87 
88 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
89 static bool isAllocationSupported(MemRefType t) {
90   // Currently only support workgroup local memory allocations with static
91   // shape and int or float or vector of int or float element type.
92   if (!(t.hasStaticShape() &&
93         SPIRVTypeConverter::getMemorySpaceForStorageClass(
94             spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
95     return false;
96   Type elementType = t.getElementType();
97   if (auto vecType = elementType.dyn_cast<VectorType>())
98     elementType = vecType.getElementType();
99   return elementType.isIntOrFloat();
100 }
101 
102 /// Returns the scope to use for atomic operations use for emulating store
103 /// operations of unsupported integer bitwidths, based on the memref
104 /// type. Returns None on failure.
105 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
106   Optional<spirv::StorageClass> storageClass =
107       SPIRVTypeConverter::getStorageClassForMemorySpace(
108           t.getMemorySpaceAsInt());
109   if (!storageClass)
110     return {};
111   switch (*storageClass) {
112   case spirv::StorageClass::StorageBuffer:
113     return spirv::Scope::Device;
114   case spirv::StorageClass::Workgroup:
115     return spirv::Scope::Workgroup;
116   default: {
117   }
118   }
119   return {};
120 }
121 
122 /// Casts the given `srcInt` into a boolean value.
123 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
124   if (srcInt.getType().isInteger(1))
125     return srcInt;
126 
127   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
128   return builder.create<spirv::IEqualOp>(loc, srcInt, one);
129 }
130 
131 /// Casts the given `srcBool` into an integer of `dstType`.
132 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
133                             OpBuilder &builder) {
134   assert(srcBool.getType().isInteger(1));
135   if (dstType.isInteger(1))
136     return srcBool;
137   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
138   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
139   return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Operation conversion
144 //===----------------------------------------------------------------------===//
145 
146 // Note that DRR cannot be used for the patterns in this file: we may need to
147 // convert type along the way, which requires ConversionPattern. DRR generates
148 // normal RewritePattern.
149 
150 namespace {
151 
152 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
153 /// to Workgroup memory when the size is constant.  Note that this pattern needs
154 /// to be applied in a pass that runs at least at spv.module scope since it wil
155 /// ladd global variables into the spv.module.
156 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
157 public:
158   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
159 
160   LogicalResult
161   matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
162                   ConversionPatternRewriter &rewriter) const override;
163 };
164 
165 /// Removed a deallocation if it is a supported allocation. Currently only
166 /// removes deallocation if the memory space is workgroup memory.
167 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
168 public:
169   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
170 
171   LogicalResult
172   matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
173                   ConversionPatternRewriter &rewriter) const override;
174 };
175 
176 /// Converts memref.load to spv.Load.
177 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
178 public:
179   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
180 
181   LogicalResult
182   matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
183                   ConversionPatternRewriter &rewriter) const override;
184 };
185 
186 /// Converts memref.load to spv.Load.
187 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
188 public:
189   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
190 
191   LogicalResult
192   matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
193                   ConversionPatternRewriter &rewriter) const override;
194 };
195 
196 /// Converts memref.store to spv.Store on integers.
197 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
198 public:
199   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
200 
201   LogicalResult
202   matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
203                   ConversionPatternRewriter &rewriter) const override;
204 };
205 
206 /// Converts memref.store to spv.Store.
207 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
208 public:
209   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
210 
211   LogicalResult
212   matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
213                   ConversionPatternRewriter &rewriter) const override;
214 };
215 
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // AllocOp
220 //===----------------------------------------------------------------------===//
221 
222 LogicalResult
223 AllocOpPattern::matchAndRewrite(memref::AllocOp operation,
224                                 ArrayRef<Value> operands,
225                                 ConversionPatternRewriter &rewriter) const {
226   MemRefType allocType = operation.getType();
227   if (!isAllocationSupported(allocType))
228     return operation.emitError("unhandled allocation type");
229 
230   // Get the SPIR-V type for the allocation.
231   Type spirvType = getTypeConverter()->convertType(allocType);
232 
233   // Insert spv.GlobalVariable for this allocation.
234   Operation *parent =
235       SymbolTable::getNearestSymbolTable(operation->getParentOp());
236   if (!parent)
237     return failure();
238   Location loc = operation.getLoc();
239   spirv::GlobalVariableOp varOp;
240   {
241     OpBuilder::InsertionGuard guard(rewriter);
242     Block &entryBlock = *parent->getRegion(0).begin();
243     rewriter.setInsertionPointToStart(&entryBlock);
244     auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
245     std::string varName =
246         std::string("__workgroup_mem__") +
247         std::to_string(std::distance(varOps.begin(), varOps.end()));
248     varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
249                                                      /*initializer=*/nullptr);
250   }
251 
252   // Get pointer to global variable at the current scope.
253   rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
254   return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // DeallocOp
259 //===----------------------------------------------------------------------===//
260 
261 LogicalResult
262 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
263                                   ArrayRef<Value> operands,
264                                   ConversionPatternRewriter &rewriter) const {
265   MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
266   if (!isAllocationSupported(deallocType))
267     return operation.emitError("unhandled deallocation type");
268   rewriter.eraseOp(operation);
269   return success();
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // LoadOp
274 //===----------------------------------------------------------------------===//
275 
276 LogicalResult
277 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
278                                   ArrayRef<Value> operands,
279                                   ConversionPatternRewriter &rewriter) const {
280   memref::LoadOpAdaptor loadOperands(operands);
281   auto loc = loadOp.getLoc();
282   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
283   if (!memrefType.getElementType().isSignlessInteger())
284     return failure();
285 
286   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
287   spirv::AccessChainOp accessChainOp =
288       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
289                            loadOperands.indices(), loc, rewriter);
290 
291   if (!accessChainOp)
292     return failure();
293 
294   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
295   bool isBool = srcBits == 1;
296   if (isBool)
297     srcBits = typeConverter.getOptions().boolNumBits;
298   Type pointeeType = typeConverter.convertType(memrefType)
299                          .cast<spirv::PointerType>()
300                          .getPointeeType();
301   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
302   Type dstType;
303   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
304     dstType = arrayType.getElementType();
305   else
306     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
307 
308   int dstBits = dstType.getIntOrFloatBitWidth();
309   assert(dstBits % srcBits == 0);
310 
311   // If the rewrited load op has the same bit width, use the loading value
312   // directly.
313   if (srcBits == dstBits) {
314     Value loadVal =
315         rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
316     if (isBool)
317       loadVal = castIntNToBool(loc, loadVal, rewriter);
318     rewriter.replaceOp(loadOp, loadVal);
319     return success();
320   }
321 
322   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
323   // still returns a linearized accessing. If the accessing is not linearized,
324   // there will be offset issues.
325   assert(accessChainOp.indices().size() == 2);
326   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
327                                                    srcBits, dstBits, rewriter);
328   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
329       loc, dstType, adjustedPtr,
330       loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
331           spirv::attributeName<spirv::MemoryAccess>()),
332       loadOp->getAttrOfType<IntegerAttr>("alignment"));
333 
334   // Shift the bits to the rightmost.
335   // ____XXXX________ -> ____________XXXX
336   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
337   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
338   Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
339       loc, spvLoadOp.getType(), spvLoadOp, offset);
340 
341   // Apply the mask to extract corresponding bits.
342   Value mask = rewriter.create<spirv::ConstantOp>(
343       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
344   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
345 
346   // Apply sign extension on the loading value unconditionally. The signedness
347   // semantic is carried in the operator itself, we relies other pattern to
348   // handle the casting.
349   IntegerAttr shiftValueAttr =
350       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
351   Value shiftValue =
352       rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
353   result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
354                                                       shiftValue);
355   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
356                                                           shiftValue);
357 
358   if (isBool) {
359     dstType = typeConverter.convertType(loadOp.getType());
360     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
361     result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
362   } else if (result.getType().getIntOrFloatBitWidth() !=
363              static_cast<unsigned>(dstBits)) {
364     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
365   }
366   rewriter.replaceOp(loadOp, result);
367 
368   assert(accessChainOp.use_empty());
369   rewriter.eraseOp(accessChainOp);
370 
371   return success();
372 }
373 
374 LogicalResult
375 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
376                                ConversionPatternRewriter &rewriter) const {
377   memref::LoadOpAdaptor loadOperands(operands);
378   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
379   if (memrefType.getElementType().isSignlessInteger())
380     return failure();
381   auto loadPtr = spirv::getElementPtr(
382       *getTypeConverter<SPIRVTypeConverter>(), memrefType,
383       loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
384 
385   if (!loadPtr)
386     return failure();
387 
388   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
389   return success();
390 }
391 
392 LogicalResult
393 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
394                                    ArrayRef<Value> operands,
395                                    ConversionPatternRewriter &rewriter) const {
396   memref::StoreOpAdaptor storeOperands(operands);
397   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
398   if (!memrefType.getElementType().isSignlessInteger())
399     return failure();
400 
401   auto loc = storeOp.getLoc();
402   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
403   spirv::AccessChainOp accessChainOp =
404       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
405                            storeOperands.indices(), loc, rewriter);
406 
407   if (!accessChainOp)
408     return failure();
409 
410   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
411 
412   bool isBool = srcBits == 1;
413   if (isBool)
414     srcBits = typeConverter.getOptions().boolNumBits;
415 
416   Type pointeeType = typeConverter.convertType(memrefType)
417                          .cast<spirv::PointerType>()
418                          .getPointeeType();
419   Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
420   Type dstType;
421   if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
422     dstType = arrayType.getElementType();
423   else
424     dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
425 
426   int dstBits = dstType.getIntOrFloatBitWidth();
427   assert(dstBits % srcBits == 0);
428 
429   if (srcBits == dstBits) {
430     Value storeVal = storeOperands.value();
431     if (isBool)
432       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
433     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
434         storeOp, accessChainOp.getResult(), storeVal);
435     return success();
436   }
437 
438   // Since there are multi threads in the processing, the emulation will be done
439   // with atomic operations. E.g., if the storing value is i8, rewrite the
440   // StoreOp to
441   // 1) load a 32-bit integer
442   // 2) clear 8 bits in the loading value
443   // 3) store 32-bit value back
444   // 4) load a 32-bit integer
445   // 5) modify 8 bits in the loading value
446   // 6) store 32-bit value back
447   // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
448   // 4 to step 6 are done by AtomicOr as another atomic step.
449   assert(accessChainOp.indices().size() == 2);
450   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
451   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
452 
453   // Create a mask to clear the destination. E.g., if it is the second i8 in
454   // i32, 0xFFFF00FF is created.
455   Value mask = rewriter.create<spirv::ConstantOp>(
456       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
457   Value clearBitsMask =
458       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
459   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
460 
461   Value storeVal = storeOperands.value();
462   if (isBool)
463     storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
464   storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
465   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
466                                                    srcBits, dstBits, rewriter);
467   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
468   if (!scope)
469     return failure();
470   Value result = rewriter.create<spirv::AtomicAndOp>(
471       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
472       clearBitsMask);
473   result = rewriter.create<spirv::AtomicOrOp>(
474       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
475       storeVal);
476 
477   // The AtomicOrOp has no side effect. Since it is already inserted, we can
478   // just remove the original StoreOp. Note that rewriter.replaceOp()
479   // doesn't work because it only accepts that the numbers of result are the
480   // same.
481   rewriter.eraseOp(storeOp);
482 
483   assert(accessChainOp.use_empty());
484   rewriter.eraseOp(accessChainOp);
485 
486   return success();
487 }
488 
489 LogicalResult
490 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
491                                 ArrayRef<Value> operands,
492                                 ConversionPatternRewriter &rewriter) const {
493   memref::StoreOpAdaptor storeOperands(operands);
494   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
495   if (memrefType.getElementType().isSignlessInteger())
496     return failure();
497   auto storePtr =
498       spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
499                            storeOperands.memref(), storeOperands.indices(),
500                            storeOp.getLoc(), rewriter);
501 
502   if (!storePtr)
503     return failure();
504 
505   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
506                                               storeOperands.value());
507   return success();
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // Pattern population
512 //===----------------------------------------------------------------------===//
513 
514 namespace mlir {
515 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
516                                    RewritePatternSet &patterns) {
517   patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
518                IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
519       typeConverter, patterns.getContext());
520 }
521 } // namespace mlir
522