//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
}

namespace {
/// Define lowering patterns for raw buffer ops
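/// (raw_buffer_load, raw_buffer_store, raw_buffer_atomic_fadd), which become
/// the corresponding rocdl.raw.buffer.* intrinsic wrappers plus an inline
/// construction of the buffer resource descriptor. As a rough sketch of the
/// transformation (operand types and attributes abbreviated, not verbatim IR):
///
///   %v = amdgpu.raw_buffer_load %buf[%idx] : memref<64xf32>, i32 -> f32
///
/// becomes, approximately,
///
///   %rsrc = ...  // <4 x i32> resource built from the memref descriptor
///   %v = rocdl.raw.buffer.load %rsrc, %voffset, %sgpr_offset, %aux : f32
///
/// where %voffset is %idx scaled to a byte offset.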
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  using ConvertOpToLLVMPattern<GpuOp>::ConvertOpToLLVMPattern;

  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total size exceeds 32 bits, load the data as a
    // vector<(totalBits / 32) x i32> and bitcast back, since such loads must
    // come in whole words.
    Type llvmBufferValType = llvmWantedDataType;
    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    // Resource descriptor
    // bits 0-47: base address
    // bits 48-61: stride (0 for raw buffers)
    // bit 62: texture cache coherency (always 0)
    // bit 63: enable swizzles (always off for raw buffers)
    // bits 64-95 (word 2): Number of records, units of stride
    // bits 96-127 (word 3): See below

    Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
    MemRefDescriptor memrefDescriptor(memref);
    Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
    Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));

    Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
    Value ptrAsInts =
        rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
    for (int64_t i = 0; i < 2; ++i) {
      Value idxConst = this->createIndexConstant(rewriter, loc, i);
      Value part =
          rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
      resource = rewriter.create<LLVM::InsertElementOp>(
          loc, llvm4xI32, resource, part, idxConst);
    }

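    // Word 2 of the descriptor holds the extent of the buffer ("number of
    // records"; for raw buffers, bytes). With a static shape this is simply
    // numElements * elementByteWidth. With a dynamic shape, conservatively
    // use the largest size[i] * stride[i] product (scaled to bytes) across
    // all dimensions as the extent.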
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        // These are integers, so use the unsigned integer maximum intrinsic
        // (llvm.umax), not the floating-point llvm.maximum.
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, numRecords,
        this->createIndexConstant(rewriter, loc, 2));

    // Final word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = raw, 2 = none,
    //   3 = swizzles), RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t word3 = (7 << 12) | (4 << 15);
    if (adaptor.getTargetIsRDNA()) {
      word3 |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 1 : 2;
      word3 |= (oob << 28);
    }
    Value word3Const = createI32Constant(rewriter, loc, word3);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, word3Const,
        this->createIndexConstant(rewriter, loc, 3));
    args.push_back(resource);

    // Indexing (voffset)
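    // The intrinsics address the buffer in bytes, not elements, so voffset is
    // the linearized byte offset
    // sum_i(indices[i] * strides[i] * elementByteWidth), with statically
    // known strides folded into i32 constants.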
    Value voffset;
    for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset = voffset
                    ? rewriter.create<LLVM::AddOp>(loc, voffset,
                                                   extraOffsetConst)
                    : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamicStrideOrOffset(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    populateAMDGPUToROCDLConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      RawBufferOpLowering<amdgpu::RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
      RawBufferOpLowering<amdgpu::RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
      RawBufferOpLowering<amdgpu::RawBufferAtomicFaddOp,
                          ROCDL::RawBufferAtomicFAddOp>>(converter);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}