//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;
using namespace mlir::amdgpu;

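/// Convenience helper that materializes a 32-bit LLVM integer constant.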
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
}

namespace {
/// Define lowering patterns for raw buffer ops
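/// These patterns rewrite amdgpu.raw_buffer_{load,store,atomic_fadd} into
/// the corresponding rocdl.raw.buffer.* intrinsics, building the 4xi32
/// buffer resource descriptor from the memref and the op's attributes.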
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
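  // Buffer instructions move at most 128 bits (one dwordx4) per lane.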
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("Raw buffer ops require gfx9 or higher");

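    // Operand group 0 is the stored value for store and atomic ops, but
    // resolves to the memref itself for loads, which carry no data operand.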
    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the
    // total load size is >= 32 bits, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast.
    Type llvmBufferValType = llvmWantedDataType;
    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but this op calls for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32 bits that "
                                     "doesn't fit into words. Can't happen");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    // Resource descriptor
    // bits 0-47: base address
    // bits 48-61: stride (0 for raw buffers)
    // bit 62: texture cache coherency (always 0)
    // bit 63: enable swizzles (always off for raw buffers)
    // bits 64-95 (word 2): Number of records, in units of stride
    //   (bytes for raw buffers, since stride is 0)
    // bits 96-127 (word 3): See below

    Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
    MemRefDescriptor memrefDescriptor(memref);
    Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
    Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));

    Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

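    // Words 0-1: split the 64-bit aligned base pointer into two i32 halves
    // and insert them into the descriptor.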
    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
    Value ptrAsInts =
        rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
    for (int64_t i = 0; i < 2; ++i) {
      Value idxConst = this->createIndexConstant(rewriter, loc, i);
      Value part =
          rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
      resource = rewriter.create<LLVM::InsertElementOp>(
          loc, llvm4xI32, resource, part, idxConst);
    }

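    // Word 2: number of records. For statically shaped memrefs this is the
    // total size in bytes; otherwise, compute the byte extent of the largest
    // dimension (size * stride) at runtime and truncate it to 32 bits.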
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, numRecords,
        this->createIndexConstant(rewriter, loc, 2));

    // Final word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out-of-bounds select (0 = structured, 1 = check index,
    //   2 = none, 3 = either swizzles or testing against offset field),
    //   RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t word3 = (7 << 12) | (4 << 15);
    if (chipset.majorVersion == 10) {
      word3 |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      word3 |= (oob << 28);
    }
    Value word3Const = createI32Constant(rewriter, loc, word3);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, word3Const,
        this->createIndexConstant(rewriter, loc, 3));
    args.push_back(resource);

    // Indexing (voffset): linearize the indices into a single byte offset,
    // scaling each index by its stride in bytes (a constant when the stride
    // is statically known, otherwise read from the descriptor at runtime).
    Value voffset;
    for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

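    // Scalar (SGPR) offset: defaults to 0 when not provided; the memref's own
    // offset is folded in, read from the descriptor when it is dynamic.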
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamicStrideOrOffset(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

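    // Emit the ROCDL intrinsic. Loads have one result, which is bitcast back
    // to the requested type if needed; stores and value-discarding atomics
    // have none, so the original op is simply erased.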
    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

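/// Lower amdgpu.lds_barrier to inline assembly: wait for all outstanding LDS
/// operations to complete (lgkmcnt reaches 0), then execute a workgroup
/// barrier.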
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                    LLVM::AsmDialect::AD_ATT);
    const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier";
    const char *constraints = "";
    rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
        op,
        /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
        /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
        /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
        /*operand_attrs=*/ArrayAttr());
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  patterns.add<LDSBarrierOpLowering>(converter);
  patterns.add<
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
      converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}