1 //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
14 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15 
16 using namespace mlir;
17 
18 static Value createI32Constant(ConversionPatternRewriter &rewriter,
19                                Location loc, int32_t value) {
20   IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
21   Type llvmI32 = rewriter.getI32Type();
22   return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
23 }
24 
25 namespace {
26 /// Define lowering patterns for raw buffer ops
27 template <typename GpuOp, typename Intrinsic>
28 struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
29   using ConvertOpToLLVMPattern<GpuOp>::ConvertOpToLLVMPattern;
30 
31   static constexpr uint32_t maxVectorOpWidth = 128;
32 
33   LogicalResult
34   matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
35                   ConversionPatternRewriter &rewriter) const override {
36     Location loc = gpuOp.getLoc();
37     Value memref = adaptor.getMemref();
38     Value unconvertedMemref = gpuOp.getMemref();
39     MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
40 
41     Value storeData = adaptor.getODSOperands(0)[0];
42     if (storeData == memref) // no write component to this op
43       storeData = Value();
44     Type wantedDataType;
45     if (storeData)
46       wantedDataType = storeData.getType();
47     else
48       wantedDataType = gpuOp.getODSResults(0)[0].getType();
49 
50     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
51 
52     Type i32 = rewriter.getI32Type();
53     Type llvmI32 = this->typeConverter->convertType(i32);
54 
55     int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
56     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
57 
58     // If we want to load a vector<NxT> with total size <= 32
59     // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
60     // and the
61     Type llvmBufferValType = llvmWantedDataType;
62     if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
63       uint32_t elemBits = dataVector.getElementTypeBitWidth();
64       uint32_t totalBits = elemBits * dataVector.getNumElements();
65       if (totalBits > maxVectorOpWidth)
66         return gpuOp.emitOpError(
67             "Total width of loads or stores must be no more than " +
68             Twine(maxVectorOpWidth) + " bits, but we call for " +
69             Twine(totalBits) +
70             " bits. This should've been caught in validation");
71       if (elemBits < 32) {
72         if (totalBits > 32) {
73           if (totalBits % 32 != 0)
74             return gpuOp.emitOpError("Load or store of more than 32-bits that "
75                                      "doesn't fit into words. Can't happen\n");
76           llvmBufferValType = this->typeConverter->convertType(
77               VectorType::get(totalBits / 32, i32));
78         } else {
79           llvmBufferValType = this->typeConverter->convertType(
80               rewriter.getIntegerType(totalBits));
81         }
82       }
83     }
84 
85     SmallVector<Value, 6> args;
86     if (storeData) {
87       if (llvmBufferValType != llvmWantedDataType) {
88         Value castForStore =
89             rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
90         args.push_back(castForStore);
91       } else {
92         args.push_back(storeData);
93       }
94     }
95 
96     // Construct buffer descriptor from memref, attributes
97     int64_t offset = 0;
98     SmallVector<int64_t, 5> strides;
99     if (failed(getStridesAndOffset(memrefType, strides, offset)))
100       return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
101 
102     // Resource descriptor
103     // bits 0-47: base address
104     // bits 48-61: stride (0 for raw buffers)
105     // bit 62: texture cache coherency (always 0)
106     // bit 63: enable swizzles (always off for raw buffers)
107     // bits 64-95 (word 2): Number of records, units of stride
108     // bits 96-127 (word 3): See below
109 
110     Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
111     MemRefDescriptor memrefDescriptor(memref);
112     Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
113     Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
114 
115     Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
116 
117     Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
118     Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
119     Value ptrAsInts =
120         rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
121     for (int64_t i = 0; i < 2; ++i) {
122       Value idxConst = this->createIndexConstant(rewriter, loc, i);
123       Value part =
124           rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
125       resource = rewriter.create<LLVM::InsertElementOp>(
126           loc, llvm4xI32, resource, part, idxConst);
127     }
128 
129     Value numRecords;
130     if (memrefType.hasStaticShape()) {
131       numRecords = createI32Constant(
132           rewriter, loc,
133           static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
134     } else {
135       Value maxIndex;
136       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
137         Value size = memrefDescriptor.size(rewriter, loc, i);
138         Value stride = memrefDescriptor.stride(rewriter, loc, i);
139         stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
140         Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
141         maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
142                                                                maxThisDim)
143                             : maxThisDim;
144       }
145       numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
146     }
147     resource = rewriter.create<LLVM::InsertElementOp>(
148         loc, llvm4xI32, resource, numRecords,
149         this->createIndexConstant(rewriter, loc, 2));
150 
151     // Final word:
152     // bits 0-11: dst sel, ignored by these intrinsics
153     // bits 12-14: data format (ignored, must be nonzero, 7=float)
154     // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
155     // bit 19: In nested heap (0 here)
156     // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
157     // bits 21-22: Index stride for swizzles (N/A)
158     // bit 23: Add thread ID (0)
159     // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
160     // bits 25-26: Reserved (0)
161     // bit 27: Buffer is non-volatile (CDNA only)
162     // bits 28-29: Out of bounds select (0 = structured, 1 = raw, 2 = none, 3 =
163     // swizzles) RDNA only
164     // bits 30-31: Type (must be 0)
165     uint32_t word3 = (7 << 12) | (4 << 15);
166     if (adaptor.getTargetIsRDNA()) {
167       word3 |= (1 << 24);
168       uint32_t oob = adaptor.getBoundsCheck() ? 1 : 2;
169       word3 |= (oob << 28);
170     }
171     Value word3Const = createI32Constant(rewriter, loc, word3);
172     resource = rewriter.create<LLVM::InsertElementOp>(
173         loc, llvm4xI32, resource, word3Const,
174         this->createIndexConstant(rewriter, loc, 3));
175     args.push_back(resource);
176 
177     // Indexing (voffset)
178     Value voffset;
179     for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
180       size_t i = pair.index();
181       Value index = pair.value();
182       Value strideOp;
183       if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
184         strideOp = rewriter.create<LLVM::MulOp>(
185             loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
186       } else {
187         strideOp =
188             createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
189       }
190       index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
191       voffset =
192           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
193     }
194     if (adaptor.getIndexOffset()) {
195       int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
196       Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
197       voffset =
198           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
199                   : extraOffsetConst;
200     }
201     args.push_back(voffset);
202 
203     Value sgprOffset = adaptor.getSgprOffset();
204     if (!sgprOffset)
205       sgprOffset = createI32Constant(rewriter, loc, 0);
206     if (ShapedType::isDynamicStrideOrOffset(offset))
207       sgprOffset = rewriter.create<LLVM::AddOp>(
208           loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
209     else if (offset > 0)
210       sgprOffset = rewriter.create<LLVM::AddOp>(
211           loc, sgprOffset, createI32Constant(rewriter, loc, offset));
212     args.push_back(sgprOffset);
213 
214     // bit 0: GLC = 0 (atomics drop value, less coherency)
215     // bits 1-2: SLC, DLC = 0 (similarly)
216     // bit 3: swizzled (0 for raw)
217     args.push_back(createI32Constant(rewriter, loc, 0));
218 
219     llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
220                                            llvmBufferValType);
221     Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
222                                                     ArrayRef<NamedAttribute>());
223     if (lowered->getNumResults() == 1) {
224       Value replacement = lowered->getResult(0);
225       if (llvmBufferValType != llvmWantedDataType) {
226         replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
227                                                        replacement);
228       }
229       rewriter.replaceOp(gpuOp, replacement);
230     } else {
231       rewriter.eraseOp(gpuOp);
232     }
233     return success();
234   }
235 };
236 
237 struct ConvertAMDGPUToROCDLPass
238     : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
239   ConvertAMDGPUToROCDLPass() = default;
240 
241   void runOnOperation() override {
242     RewritePatternSet patterns(&getContext());
243     LLVMTypeConverter converter(&getContext());
244     populateAMDGPUToROCDLConversionPatterns(converter, patterns);
245     LLVMConversionTarget target(getContext());
246     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
247     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
248     if (failed(applyPartialConversion(getOperation(), target,
249                                       std::move(patterns))))
250       signalPassFailure();
251   }
252 };
253 } // namespace
254 
255 void mlir::populateAMDGPUToROCDLConversionPatterns(
256     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
257   patterns.add<
258       RawBufferOpLowering<amdgpu::RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
259       RawBufferOpLowering<amdgpu::RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
260       RawBufferOpLowering<amdgpu::RawBufferAtomicFaddOp,
261                           ROCDL::RawBufferAtomicFAddOp>>(converter);
262 }
263 
264 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
265   return std::make_unique<ConvertAMDGPUToROCDLPass>();
266 }
267