//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
}

namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("Raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with a total size of at most 32 bits,
    // use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the
    // total load size is at least 32 bits, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast the result.
    Type llvmBufferValType = llvmWantedDataType;
    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32 bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    // Resource descriptor
    // bits 0-47: base address
    // bits 48-61: stride (0 for raw buffers)
    // bit 62: texture cache coherency (always 0)
    // bit 63: enable swizzles (always off for raw buffers)
    // bits 64-95 (word 2): Number of records, units of stride
    // bits 96-127 (word 3): See below

    Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
    MemRefDescriptor memrefDescriptor(memref);
    Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
    Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));

    Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
    Value ptrAsInts =
        rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
    for (int64_t i = 0; i < 2; ++i) {
      Value idxConst = this->createIndexConstant(rewriter, loc, i);
      Value part =
          rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
      resource = rewriter.create<LLVM::InsertElementOp>(
          loc, llvm4xI32, resource, part, idxConst);
    }

    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      // Dynamic shape: bound the accessible range by the largest byte extent
      // (size * stride * element width) over all dimensions, using an unsigned
      // integer max since these are index-typed values.
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, numRecords,
        this->createIndexConstant(rewriter, loc, 2));

    // Final word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7 = float)
    // bits 15-18: data format (ignored, must be nonzero, 4 = 32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //   none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t word3 = (7 << 12) | (4 << 15);
    if (chipset.majorVersion == 10) {
      word3 |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      word3 |= (oob << 28);
    }
    Value word3Const = createI32Constant(rewriter, loc, word3);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, word3Const,
        this->createIndexConstant(rewriter, loc, 3));
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset;
    for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset = voffset
                    ? rewriter.create<LLVM::AddOp>(loc, voffset,
                                                   extraOffsetConst)
                    : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamicStrideOrOffset(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                    LLVM::AsmDialect::AD_ATT);
    const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier";
    const char *constraints = "";
    rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
        op,
        /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
        /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
        /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
        /*operand_attrs=*/ArrayAttr());
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  patterns.add<LDSBarrierOpLowering>(converter);
  patterns.add<
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
      converter, chipset);
}

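/// Create a pass that lowers AMDGPU dialect operations to the ROCDL dialect,
/// using the pass's `chipset` option to select chip-specific lowerings.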
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}