//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It is assumed to be non-negative.
///
/// When accessing an element in the array treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. This method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// (x % 4) * 8, because one i32 holds four 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr idxAttr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
  auto srcBitsValue =
      builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For accessing a 1-D
/// array (spv.array or spv.rt_array), the last index is modified to load the
/// bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
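///
/// Illustrative sketch (assuming an i8 element type emulated with i32, i.e.
/// sourceBits = 8 and targetBits = 32; not verbatim output): an access chain
///   %ptr = spv.AccessChain %base[%zero, %idx] : ... i8 ...
/// has its last index rewritten to address the containing i32 word:
///   %four = spv.Constant 4 : i32
///   %word = spv.SDiv %idx, %four : i32
///   %ptr  = spv.AccessChain %base[%zero, %word] : ... i32 ...
/// The exact pointer/array types come from the type converter; only the index
/// adjustment is shown here.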
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
                                          spirv::AccessChainOp op,
                                          int sourceBits, int targetBits,
                                          OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  IntegerType targetType = builder.getIntegerType(targetBits);
  IntegerAttr attr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
  auto lastDim = op->getOperand(op.getNumOperands() - 1);
  auto indices = llvm::to_vector<4>(op.indices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.component_ptr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}

/// Returns the shifted `targetBits`-bit value with the given offset.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        int targetBits, OpBuilder &builder) {
  Type targetType = builder.getIntegerType(targetBits);
  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                   offset);
}

/// Returns true if an allocation of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
  // Currently only support workgroup local memory allocations with static
  // shape and int or float or vector of int or float element type.
  if (!(t.hasStaticShape() &&
        SPIRVTypeConverter::getMemorySpaceForStorageClass(
            spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
    return false;
  Type elementType = t.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>())
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for the atomic operations used when emulating
/// store operations of unsupported integer bitwidths, based on the memref
/// type. Returns None on failure.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(
          t.getMemorySpaceAsInt());
  if (!storageClass)
    return {};
  switch (*storageClass) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
}

/// Casts the given `srcBool` into an integer of `dstType`.
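///
/// Illustrative sketch (not verbatim output): casting a true i1 to i32 yields
///   %zero = spv.Constant 0 : i32
///   %one  = spv.Constant 1 : i32
///   %res  = spv.Select %src, %one, %zero : i1, i32
/// so a bool round-trips through castBoolToIntN / castIntNToBool unchanged.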
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it
/// will add global variables into the spv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is for a supported allocation. Currently only
/// removes deallocations whose memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store.
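///
/// Illustrative sketch of the expected rewrite (exact pointer types depend on
/// the type converter and the memref's layout/storage class, so this is not
/// verbatim output):
///   memref.store %val, %mem[%i] : memref<4xf32>
/// becomes roughly
///   %ptr = spv.AccessChain %mem[%zero, %i] : !spv.ptr<..., StorageBuffer>
///   spv.Store "StorageBuffer" %ptr, %val : f32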
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation,
                                ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(allocType))
    return operation.emitError("unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);

  // Insert spv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
  if (!isAllocationSupported(deallocType))
    return operation.emitError("unhandled deallocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;
  Type pointeeType =
      typeConverter.convertType(memrefType)
          .cast<spirv::PointerType>()
          .getPointeeType();
  Type structElemType =
      pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    Value loadVal =
        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Assume that getElementPtr() works linearly. If the element is a scalar,
  // the method still returns a linearized access chain. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried by the operations themselves; we rely on other
  // patterns to handle the casting.
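  // Illustrative sketch (assuming srcBits = 8 and dstBits = 32): a loaded
  // byte 0xAB sitting in the low 8 bits is shifted left by 24 and then
  // arithmetically right by 24, producing 0xFFFFFFAB, i.e. the i8 value
  // sign-extended to i32.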
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);

  if (isBool) {
    dstType = typeConverter.convertType(loadOp.getType());
    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
  } else if (result.getType().getIntOrFloatBitWidth() !=
             static_cast<unsigned>(dstBits)) {
    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
  }
  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType,
      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                   ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                           storeOperands.indices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType =
      pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    Value storeVal = storeOperands.value();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeVal);
    return success();
  }

  // Since multiple threads may be writing into the same wider word, the
  // emulation is done with atomic operations.
  // For example, if the stored value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // an i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = storeOperands.value();
  if (isBool)
    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The two atomic ops above fully emulate the store, so we can just remove
  // the original StoreOp. Note that rewriter.replaceOp() doesn't work here
  // because it requires the number of results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto storePtr =
      spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
                           storeOperands.memref(), storeOperands.indices(),
                           storeOp.getLoc(), rewriter);

  if (!storePtr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                              storeOperands.value());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
               IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
      typeConverter, patterns.getContext());
}
} // namespace mlir
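
// Illustrative usage sketch (hypothetical pass body, not part of this file):
// a conversion pass would typically combine these patterns with a SPIR-V
// conversion target and type converter along these lines:
//
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   auto target = SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(module.getContext());
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     signalPassFailure();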