1 //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert MemRef dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "memref-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Utility functions 25 //===----------------------------------------------------------------------===// 26 27 /// Returns the offset of the value in `targetBits` representation. 28 /// 29 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. 30 /// It's assumed to be non-negative. 31 /// 32 /// When accessing an element in the array treating as having elements of 33 /// `targetBits`, multiple values are loaded in the same time. The method 34 /// returns the offset where the `srcIdx` locates in the value. For example, if 35 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is 36 /// located at (x % 4) * 8. Because there are four elements in one i32, and one 37 /// element has 8 bits. 38 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, 39 int targetBits, OpBuilder &builder) { 40 assert(targetBits % sourceBits == 0); 41 IntegerType targetType = builder.getIntegerType(targetBits); 42 IntegerAttr idxAttr = 43 builder.getIntegerAttr(targetType, targetBits / sourceBits); 44 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr); 45 IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); 46 auto srcBitsValue = 47 builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr); 48 auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx); 49 return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue); 50 } 51 52 /// Returns an adjusted spirv::AccessChainOp. Based on the 53 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be 54 /// supported. During conversion if a memref of an unsupported type is used, 55 /// load/stores to this memref need to be modified to use a supported higher 56 /// bitwidth `targetBits` and extracting the required bits. For an accessing a 57 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the 58 /// bits needed. The extraction of the actual bits needed are handled 59 /// separately. Note that this only works for a 1-D tensor. 60 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, 61 spirv::AccessChainOp op, 62 int sourceBits, int targetBits, 63 OpBuilder &builder) { 64 assert(targetBits % sourceBits == 0); 65 const auto loc = op.getLoc(); 66 IntegerType targetType = builder.getIntegerType(targetBits); 67 IntegerAttr attr = 68 builder.getIntegerAttr(targetType, targetBits / sourceBits); 69 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr); 70 auto lastDim = op->getOperand(op.getNumOperands() - 1); 71 auto indices = llvm::to_vector<4>(op.indices()); 72 // There are two elements if this is a 1-D tensor. 73 assert(indices.size() == 2); 74 indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx); 75 Type t = typeConverter.convertType(op.component_ptr().getType()); 76 return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices); 77 } 78 79 /// Returns the shifted `targetBits`-bit value with the given offset. 80 static Value shiftValue(Location loc, Value value, Value offset, Value mask, 81 int targetBits, OpBuilder &builder) { 82 Type targetType = builder.getIntegerType(targetBits); 83 Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask); 84 return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result, 85 offset); 86 } 87 88 /// Returns true if the allocations of type `t` can be lowered to SPIR-V. 89 static bool isAllocationSupported(MemRefType t) { 90 // Currently only support workgroup local memory allocations with static 91 // shape and int or float or vector of int or float element type. 92 if (!(t.hasStaticShape() && 93 SPIRVTypeConverter::getMemorySpaceForStorageClass( 94 spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) 95 return false; 96 Type elementType = t.getElementType(); 97 if (auto vecType = elementType.dyn_cast<VectorType>()) 98 elementType = vecType.getElementType(); 99 return elementType.isIntOrFloat(); 100 } 101 102 /// Returns the scope to use for atomic operations use for emulating store 103 /// operations of unsupported integer bitwidths, based on the memref 104 /// type. Returns None on failure. 105 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) { 106 Optional<spirv::StorageClass> storageClass = 107 SPIRVTypeConverter::getStorageClassForMemorySpace( 108 t.getMemorySpaceAsInt()); 109 if (!storageClass) 110 return {}; 111 switch (*storageClass) { 112 case spirv::StorageClass::StorageBuffer: 113 return spirv::Scope::Device; 114 case spirv::StorageClass::Workgroup: 115 return spirv::Scope::Workgroup; 116 default: { 117 } 118 } 119 return {}; 120 } 121 122 /// Casts the given `srcInt` into a boolean value. 123 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { 124 if (srcInt.getType().isInteger(1)) 125 return srcInt; 126 127 auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); 128 return builder.create<spirv::IEqualOp>(loc, srcInt, one); 129 } 130 131 /// Casts the given `srcBool` into an integer of `dstType`. 132 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, 133 OpBuilder &builder) { 134 assert(srcBool.getType().isInteger(1)); 135 if (dstType.isInteger(1)) 136 return srcBool; 137 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); 138 Value one = spirv::ConstantOp::getOne(dstType, loc, builder); 139 return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Operation conversion 144 //===----------------------------------------------------------------------===// 145 146 // Note that DRR cannot be used for the patterns in this file: we may need to 147 // convert type along the way, which requires ConversionPattern. DRR generates 148 // normal RewritePattern. 149 150 namespace { 151 152 /// Converts an allocation operation to SPIR-V. Currently only supports lowering 153 /// to Workgroup memory when the size is constant. Note that this pattern needs 154 /// to be applied in a pass that runs at least at spv.module scope since it wil 155 /// ladd global variables into the spv.module. 156 class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { 157 public: 158 using OpConversionPattern<memref::AllocOp>::OpConversionPattern; 159 160 LogicalResult 161 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 162 ConversionPatternRewriter &rewriter) const override; 163 }; 164 165 /// Removed a deallocation if it is a supported allocation. Currently only 166 /// removes deallocation if the memory space is workgroup memory. 167 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> { 168 public: 169 using OpConversionPattern<memref::DeallocOp>::OpConversionPattern; 170 171 LogicalResult 172 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, 173 ConversionPatternRewriter &rewriter) const override; 174 }; 175 176 /// Converts memref.load to spv.Load. 177 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 178 public: 179 using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 180 181 LogicalResult 182 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 183 ConversionPatternRewriter &rewriter) const override; 184 }; 185 186 /// Converts memref.load to spv.Load. 187 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 188 public: 189 using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 190 191 LogicalResult 192 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 193 ConversionPatternRewriter &rewriter) const override; 194 }; 195 196 /// Converts memref.store to spv.Store on integers. 197 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 198 public: 199 using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 200 201 LogicalResult 202 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 203 ConversionPatternRewriter &rewriter) const override; 204 }; 205 206 /// Converts memref.store to spv.Store. 207 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 208 public: 209 using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 210 211 LogicalResult 212 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 213 ConversionPatternRewriter &rewriter) const override; 214 }; 215 216 } // namespace 217 218 //===----------------------------------------------------------------------===// 219 // AllocOp 220 //===----------------------------------------------------------------------===// 221 222 LogicalResult 223 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 224 ConversionPatternRewriter &rewriter) const { 225 MemRefType allocType = operation.getType(); 226 if (!isAllocationSupported(allocType)) 227 return operation.emitError("unhandled allocation type"); 228 229 // Get the SPIR-V type for the allocation. 230 Type spirvType = getTypeConverter()->convertType(allocType); 231 232 // Insert spv.GlobalVariable for this allocation. 233 Operation *parent = 234 SymbolTable::getNearestSymbolTable(operation->getParentOp()); 235 if (!parent) 236 return failure(); 237 Location loc = operation.getLoc(); 238 spirv::GlobalVariableOp varOp; 239 { 240 OpBuilder::InsertionGuard guard(rewriter); 241 Block &entryBlock = *parent->getRegion(0).begin(); 242 rewriter.setInsertionPointToStart(&entryBlock); 243 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>(); 244 std::string varName = 245 std::string("__workgroup_mem__") + 246 std::to_string(std::distance(varOps.begin(), varOps.end())); 247 varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, 248 /*initializer=*/nullptr); 249 } 250 251 // Get pointer to global variable at the current scope. 252 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp); 253 return success(); 254 } 255 256 //===----------------------------------------------------------------------===// 257 // DeallocOp 258 //===----------------------------------------------------------------------===// 259 260 LogicalResult 261 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, 262 OpAdaptor adaptor, 263 ConversionPatternRewriter &rewriter) const { 264 MemRefType deallocType = operation.memref().getType().cast<MemRefType>(); 265 if (!isAllocationSupported(deallocType)) 266 return operation.emitError("unhandled deallocation type"); 267 rewriter.eraseOp(operation); 268 return success(); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // LoadOp 273 //===----------------------------------------------------------------------===// 274 275 LogicalResult 276 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 277 ConversionPatternRewriter &rewriter) const { 278 auto loc = loadOp.getLoc(); 279 auto memrefType = loadOp.memref().getType().cast<MemRefType>(); 280 if (!memrefType.getElementType().isSignlessInteger()) 281 return failure(); 282 283 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 284 spirv::AccessChainOp accessChainOp = 285 spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), 286 adaptor.indices(), loc, rewriter); 287 288 if (!accessChainOp) 289 return failure(); 290 291 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 292 bool isBool = srcBits == 1; 293 if (isBool) 294 srcBits = typeConverter.getOptions().boolNumBits; 295 Type pointeeType = typeConverter.convertType(memrefType) 296 .cast<spirv::PointerType>() 297 .getPointeeType(); 298 Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); 299 Type dstType; 300 if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) 301 dstType = arrayType.getElementType(); 302 else 303 dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); 304 305 int dstBits = dstType.getIntOrFloatBitWidth(); 306 assert(dstBits % srcBits == 0); 307 308 // If the rewrited load op has the same bit width, use the loading value 309 // directly. 310 if (srcBits == dstBits) { 311 Value loadVal = 312 rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult()); 313 if (isBool) 314 loadVal = castIntNToBool(loc, loadVal, rewriter); 315 rewriter.replaceOp(loadOp, loadVal); 316 return success(); 317 } 318 319 // Assume that getElementPtr() works linearizely. If it's a scalar, the method 320 // still returns a linearized accessing. If the accessing is not linearized, 321 // there will be offset issues. 322 assert(accessChainOp.indices().size() == 2); 323 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 324 srcBits, dstBits, rewriter); 325 Value spvLoadOp = rewriter.create<spirv::LoadOp>( 326 loc, dstType, adjustedPtr, 327 loadOp->getAttrOfType<spirv::MemoryAccessAttr>( 328 spirv::attributeName<spirv::MemoryAccess>()), 329 loadOp->getAttrOfType<IntegerAttr>("alignment")); 330 331 // Shift the bits to the rightmost. 332 // ____XXXX________ -> ____________XXXX 333 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 334 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 335 Value result = rewriter.create<spirv::ShiftRightArithmeticOp>( 336 loc, spvLoadOp.getType(), spvLoadOp, offset); 337 338 // Apply the mask to extract corresponding bits. 339 Value mask = rewriter.create<spirv::ConstantOp>( 340 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 341 result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask); 342 343 // Apply sign extension on the loading value unconditionally. The signedness 344 // semantic is carried in the operator itself, we relies other pattern to 345 // handle the casting. 346 IntegerAttr shiftValueAttr = 347 rewriter.getIntegerAttr(dstType, dstBits - srcBits); 348 Value shiftValue = 349 rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr); 350 result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result, 351 shiftValue); 352 result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result, 353 shiftValue); 354 355 if (isBool) { 356 dstType = typeConverter.convertType(loadOp.getType()); 357 mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); 358 result = rewriter.create<spirv::IEqualOp>(loc, result, mask); 359 } else if (result.getType().getIntOrFloatBitWidth() != 360 static_cast<unsigned>(dstBits)) { 361 result = rewriter.create<spirv::SConvertOp>(loc, dstType, result); 362 } 363 rewriter.replaceOp(loadOp, result); 364 365 assert(accessChainOp.use_empty()); 366 rewriter.eraseOp(accessChainOp); 367 368 return success(); 369 } 370 371 LogicalResult 372 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 373 ConversionPatternRewriter &rewriter) const { 374 auto memrefType = loadOp.memref().getType().cast<MemRefType>(); 375 if (memrefType.getElementType().isSignlessInteger()) 376 return failure(); 377 auto loadPtr = spirv::getElementPtr( 378 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), 379 adaptor.indices(), loadOp.getLoc(), rewriter); 380 381 if (!loadPtr) 382 return failure(); 383 384 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr); 385 return success(); 386 } 387 388 LogicalResult 389 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 390 ConversionPatternRewriter &rewriter) const { 391 auto memrefType = storeOp.memref().getType().cast<MemRefType>(); 392 if (!memrefType.getElementType().isSignlessInteger()) 393 return failure(); 394 395 auto loc = storeOp.getLoc(); 396 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 397 spirv::AccessChainOp accessChainOp = 398 spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), 399 adaptor.indices(), loc, rewriter); 400 401 if (!accessChainOp) 402 return failure(); 403 404 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 405 406 bool isBool = srcBits == 1; 407 if (isBool) 408 srcBits = typeConverter.getOptions().boolNumBits; 409 410 Type pointeeType = typeConverter.convertType(memrefType) 411 .cast<spirv::PointerType>() 412 .getPointeeType(); 413 Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); 414 Type dstType; 415 if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) 416 dstType = arrayType.getElementType(); 417 else 418 dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); 419 420 int dstBits = dstType.getIntOrFloatBitWidth(); 421 assert(dstBits % srcBits == 0); 422 423 if (srcBits == dstBits) { 424 Value storeVal = adaptor.value(); 425 if (isBool) 426 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); 427 rewriter.replaceOpWithNewOp<spirv::StoreOp>( 428 storeOp, accessChainOp.getResult(), storeVal); 429 return success(); 430 } 431 432 // Since there are multi threads in the processing, the emulation will be done 433 // with atomic operations. E.g., if the storing value is i8, rewrite the 434 // StoreOp to 435 // 1) load a 32-bit integer 436 // 2) clear 8 bits in the loading value 437 // 3) store 32-bit value back 438 // 4) load a 32-bit integer 439 // 5) modify 8 bits in the loading value 440 // 6) store 32-bit value back 441 // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step 442 // 4 to step 6 are done by AtomicOr as another atomic step. 443 assert(accessChainOp.indices().size() == 2); 444 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 445 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 446 447 // Create a mask to clear the destination. E.g., if it is the second i8 in 448 // i32, 0xFFFF00FF is created. 449 Value mask = rewriter.create<spirv::ConstantOp>( 450 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 451 Value clearBitsMask = 452 rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset); 453 clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask); 454 455 Value storeVal = adaptor.value(); 456 if (isBool) 457 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); 458 storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); 459 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 460 srcBits, dstBits, rewriter); 461 Optional<spirv::Scope> scope = getAtomicOpScope(memrefType); 462 if (!scope) 463 return failure(); 464 Value result = rewriter.create<spirv::AtomicAndOp>( 465 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 466 clearBitsMask); 467 result = rewriter.create<spirv::AtomicOrOp>( 468 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 469 storeVal); 470 471 // The AtomicOrOp has no side effect. Since it is already inserted, we can 472 // just remove the original StoreOp. Note that rewriter.replaceOp() 473 // doesn't work because it only accepts that the numbers of result are the 474 // same. 475 rewriter.eraseOp(storeOp); 476 477 assert(accessChainOp.use_empty()); 478 rewriter.eraseOp(accessChainOp); 479 480 return success(); 481 } 482 483 LogicalResult 484 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 485 ConversionPatternRewriter &rewriter) const { 486 auto memrefType = storeOp.memref().getType().cast<MemRefType>(); 487 if (memrefType.getElementType().isSignlessInteger()) 488 return failure(); 489 auto storePtr = spirv::getElementPtr( 490 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), 491 adaptor.indices(), storeOp.getLoc(), rewriter); 492 493 if (!storePtr) 494 return failure(); 495 496 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, 497 adaptor.value()); 498 return success(); 499 } 500 501 //===----------------------------------------------------------------------===// 502 // Pattern population 503 //===----------------------------------------------------------------------===// 504 505 namespace mlir { 506 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 507 RewritePatternSet &patterns) { 508 patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, 509 IntStoreOpPattern, LoadOpPattern, StoreOpPattern>( 510 typeConverter, patterns.getContext()); 511 } 512 } // namespace mlir 513