1f1f05a91SKrzysztof Drewniak //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2f1f05a91SKrzysztof Drewniak //
3f1f05a91SKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f1f05a91SKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
5f1f05a91SKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f1f05a91SKrzysztof Drewniak //
7f1f05a91SKrzysztof Drewniak //===----------------------------------------------------------------------===//
8f1f05a91SKrzysztof Drewniak 
9f1f05a91SKrzysztof Drewniak #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
10f1f05a91SKrzysztof Drewniak #include "../PassDetail.h"
11f1f05a91SKrzysztof Drewniak #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12f1f05a91SKrzysztof Drewniak #include "mlir/Conversion/LLVMCommon/Pattern.h"
13f1f05a91SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
14f1f05a91SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15f1f05a91SKrzysztof Drewniak 
16f1f05a91SKrzysztof Drewniak using namespace mlir;
17cab44c51SKrzysztof Drewniak using namespace mlir::amdgpu;
18f1f05a91SKrzysztof Drewniak 
createI32Constant(ConversionPatternRewriter & rewriter,Location loc,int32_t value)19f1f05a91SKrzysztof Drewniak static Value createI32Constant(ConversionPatternRewriter &rewriter,
20f1f05a91SKrzysztof Drewniak                                Location loc, int32_t value) {
21f1f05a91SKrzysztof Drewniak   IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
22f1f05a91SKrzysztof Drewniak   Type llvmI32 = rewriter.getI32Type();
23f1f05a91SKrzysztof Drewniak   return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
24f1f05a91SKrzysztof Drewniak }
25f1f05a91SKrzysztof Drewniak 
26f1f05a91SKrzysztof Drewniak namespace {
27f1f05a91SKrzysztof Drewniak /// Define lowering patterns for raw buffer ops
28f1f05a91SKrzysztof Drewniak template <typename GpuOp, typename Intrinsic>
29f1f05a91SKrzysztof Drewniak struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
RawBufferOpLowering__anon3714f0670111::RawBufferOpLowering30cab44c51SKrzysztof Drewniak   RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
31cab44c51SKrzysztof Drewniak       : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
32f1f05a91SKrzysztof Drewniak 
33cab44c51SKrzysztof Drewniak   Chipset chipset;
34f1f05a91SKrzysztof Drewniak   static constexpr uint32_t maxVectorOpWidth = 128;
35f1f05a91SKrzysztof Drewniak 
36f1f05a91SKrzysztof Drewniak   LogicalResult
matchAndRewrite__anon3714f0670111::RawBufferOpLowering37f1f05a91SKrzysztof Drewniak   matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
38f1f05a91SKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override {
39f1f05a91SKrzysztof Drewniak     Location loc = gpuOp.getLoc();
408df54a6aSJacques Pienaar     Value memref = adaptor.getMemref();
418df54a6aSJacques Pienaar     Value unconvertedMemref = gpuOp.getMemref();
42f1f05a91SKrzysztof Drewniak     MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
43f1f05a91SKrzysztof Drewniak 
44cab44c51SKrzysztof Drewniak     if (chipset.majorVersion < 9)
45cab44c51SKrzysztof Drewniak       return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
46cab44c51SKrzysztof Drewniak 
47f1f05a91SKrzysztof Drewniak     Value storeData = adaptor.getODSOperands(0)[0];
48f1f05a91SKrzysztof Drewniak     if (storeData == memref) // no write component to this op
49f1f05a91SKrzysztof Drewniak       storeData = Value();
50f1f05a91SKrzysztof Drewniak     Type wantedDataType;
51f1f05a91SKrzysztof Drewniak     if (storeData)
52f1f05a91SKrzysztof Drewniak       wantedDataType = storeData.getType();
53f1f05a91SKrzysztof Drewniak     else
54f1f05a91SKrzysztof Drewniak       wantedDataType = gpuOp.getODSResults(0)[0].getType();
55f1f05a91SKrzysztof Drewniak 
56f1f05a91SKrzysztof Drewniak     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
57f1f05a91SKrzysztof Drewniak 
58f1f05a91SKrzysztof Drewniak     Type i32 = rewriter.getI32Type();
59f1f05a91SKrzysztof Drewniak     Type llvmI32 = this->typeConverter->convertType(i32);
60f1f05a91SKrzysztof Drewniak 
61f1f05a91SKrzysztof Drewniak     int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
62f1f05a91SKrzysztof Drewniak     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
63f1f05a91SKrzysztof Drewniak 
64f1f05a91SKrzysztof Drewniak     // If we want to load a vector<NxT> with total size <= 32
65f1f05a91SKrzysztof Drewniak     // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
66cab44c51SKrzysztof Drewniak     // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
67cab44c51SKrzysztof Drewniak     // 32) x i32 and bitcast.
68f1f05a91SKrzysztof Drewniak     Type llvmBufferValType = llvmWantedDataType;
69f1f05a91SKrzysztof Drewniak     if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
70f1f05a91SKrzysztof Drewniak       uint32_t elemBits = dataVector.getElementTypeBitWidth();
71f1f05a91SKrzysztof Drewniak       uint32_t totalBits = elemBits * dataVector.getNumElements();
72f1f05a91SKrzysztof Drewniak       if (totalBits > maxVectorOpWidth)
73f1f05a91SKrzysztof Drewniak         return gpuOp.emitOpError(
74f1f05a91SKrzysztof Drewniak             "Total width of loads or stores must be no more than " +
75f1f05a91SKrzysztof Drewniak             Twine(maxVectorOpWidth) + " bits, but we call for " +
76f1f05a91SKrzysztof Drewniak             Twine(totalBits) +
77f1f05a91SKrzysztof Drewniak             " bits. This should've been caught in validation");
78f1f05a91SKrzysztof Drewniak       if (elemBits < 32) {
79f1f05a91SKrzysztof Drewniak         if (totalBits > 32) {
80f1f05a91SKrzysztof Drewniak           if (totalBits % 32 != 0)
81f1f05a91SKrzysztof Drewniak             return gpuOp.emitOpError("Load or store of more than 32-bits that "
82f1f05a91SKrzysztof Drewniak                                      "doesn't fit into words. Can't happen\n");
83f1f05a91SKrzysztof Drewniak           llvmBufferValType = this->typeConverter->convertType(
84f1f05a91SKrzysztof Drewniak               VectorType::get(totalBits / 32, i32));
85f1f05a91SKrzysztof Drewniak         } else {
86f1f05a91SKrzysztof Drewniak           llvmBufferValType = this->typeConverter->convertType(
87f1f05a91SKrzysztof Drewniak               rewriter.getIntegerType(totalBits));
88f1f05a91SKrzysztof Drewniak         }
89f1f05a91SKrzysztof Drewniak       }
90f1f05a91SKrzysztof Drewniak     }
91f1f05a91SKrzysztof Drewniak 
92f1f05a91SKrzysztof Drewniak     SmallVector<Value, 6> args;
93f1f05a91SKrzysztof Drewniak     if (storeData) {
94f1f05a91SKrzysztof Drewniak       if (llvmBufferValType != llvmWantedDataType) {
95f1f05a91SKrzysztof Drewniak         Value castForStore =
96f1f05a91SKrzysztof Drewniak             rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
97f1f05a91SKrzysztof Drewniak         args.push_back(castForStore);
98f1f05a91SKrzysztof Drewniak       } else {
99f1f05a91SKrzysztof Drewniak         args.push_back(storeData);
100f1f05a91SKrzysztof Drewniak       }
101f1f05a91SKrzysztof Drewniak     }
102f1f05a91SKrzysztof Drewniak 
103f1f05a91SKrzysztof Drewniak     // Construct buffer descriptor from memref, attributes
104f1f05a91SKrzysztof Drewniak     int64_t offset = 0;
105f1f05a91SKrzysztof Drewniak     SmallVector<int64_t, 5> strides;
106f1f05a91SKrzysztof Drewniak     if (failed(getStridesAndOffset(memrefType, strides, offset)))
107f1f05a91SKrzysztof Drewniak       return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
108f1f05a91SKrzysztof Drewniak 
109f1f05a91SKrzysztof Drewniak     // Resource descriptor
110f1f05a91SKrzysztof Drewniak     // bits 0-47: base address
111f1f05a91SKrzysztof Drewniak     // bits 48-61: stride (0 for raw buffers)
112f1f05a91SKrzysztof Drewniak     // bit 62: texture cache coherency (always 0)
113f1f05a91SKrzysztof Drewniak     // bit 63: enable swizzles (always off for raw buffers)
114f1f05a91SKrzysztof Drewniak     // bits 64-95 (word 2): Number of records, units of stride
115f1f05a91SKrzysztof Drewniak     // bits 96-127 (word 3): See below
116f1f05a91SKrzysztof Drewniak 
117f1f05a91SKrzysztof Drewniak     Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
118f1f05a91SKrzysztof Drewniak     MemRefDescriptor memrefDescriptor(memref);
119f1f05a91SKrzysztof Drewniak     Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
120f1f05a91SKrzysztof Drewniak     Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
121f1f05a91SKrzysztof Drewniak 
122f1f05a91SKrzysztof Drewniak     Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
123f1f05a91SKrzysztof Drewniak 
124f1f05a91SKrzysztof Drewniak     Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
125f1f05a91SKrzysztof Drewniak     Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
126f1f05a91SKrzysztof Drewniak     Value ptrAsInts =
127f1f05a91SKrzysztof Drewniak         rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
128f1f05a91SKrzysztof Drewniak     for (int64_t i = 0; i < 2; ++i) {
129f1f05a91SKrzysztof Drewniak       Value idxConst = this->createIndexConstant(rewriter, loc, i);
130f1f05a91SKrzysztof Drewniak       Value part =
131f1f05a91SKrzysztof Drewniak           rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
132f1f05a91SKrzysztof Drewniak       resource = rewriter.create<LLVM::InsertElementOp>(
133f1f05a91SKrzysztof Drewniak           loc, llvm4xI32, resource, part, idxConst);
134f1f05a91SKrzysztof Drewniak     }
135f1f05a91SKrzysztof Drewniak 
136f1f05a91SKrzysztof Drewniak     Value numRecords;
137f1f05a91SKrzysztof Drewniak     if (memrefType.hasStaticShape()) {
138f1f05a91SKrzysztof Drewniak       numRecords = createI32Constant(
139f1f05a91SKrzysztof Drewniak           rewriter, loc,
140f1f05a91SKrzysztof Drewniak           static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
141f1f05a91SKrzysztof Drewniak     } else {
142f1f05a91SKrzysztof Drewniak       Value maxIndex;
143f1f05a91SKrzysztof Drewniak       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
144f1f05a91SKrzysztof Drewniak         Value size = memrefDescriptor.size(rewriter, loc, i);
145f1f05a91SKrzysztof Drewniak         Value stride = memrefDescriptor.stride(rewriter, loc, i);
146f1f05a91SKrzysztof Drewniak         stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
147f1f05a91SKrzysztof Drewniak         Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
148f1f05a91SKrzysztof Drewniak         maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
149f1f05a91SKrzysztof Drewniak                                                                maxThisDim)
150f1f05a91SKrzysztof Drewniak                             : maxThisDim;
151f1f05a91SKrzysztof Drewniak       }
152f1f05a91SKrzysztof Drewniak       numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
153f1f05a91SKrzysztof Drewniak     }
154f1f05a91SKrzysztof Drewniak     resource = rewriter.create<LLVM::InsertElementOp>(
155f1f05a91SKrzysztof Drewniak         loc, llvm4xI32, resource, numRecords,
156f1f05a91SKrzysztof Drewniak         this->createIndexConstant(rewriter, loc, 2));
157f1f05a91SKrzysztof Drewniak 
158f1f05a91SKrzysztof Drewniak     // Final word:
159f1f05a91SKrzysztof Drewniak     // bits 0-11: dst sel, ignored by these intrinsics
160f1f05a91SKrzysztof Drewniak     // bits 12-14: data format (ignored, must be nonzero, 7=float)
161f1f05a91SKrzysztof Drewniak     // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
162f1f05a91SKrzysztof Drewniak     // bit 19: In nested heap (0 here)
163f1f05a91SKrzysztof Drewniak     // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
164f1f05a91SKrzysztof Drewniak     // bits 21-22: Index stride for swizzles (N/A)
165f1f05a91SKrzysztof Drewniak     // bit 23: Add thread ID (0)
166f1f05a91SKrzysztof Drewniak     // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
167f1f05a91SKrzysztof Drewniak     // bits 25-26: Reserved (0)
168f1f05a91SKrzysztof Drewniak     // bit 27: Buffer is non-volatile (CDNA only)
169db590549SKrzysztof Drewniak     // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
170db590549SKrzysztof Drewniak     //  none, 3 = either swizzles or testing against offset field) RDNA only
171f1f05a91SKrzysztof Drewniak     // bits 30-31: Type (must be 0)
172f1f05a91SKrzysztof Drewniak     uint32_t word3 = (7 << 12) | (4 << 15);
173cab44c51SKrzysztof Drewniak     if (chipset.majorVersion == 10) {
174f1f05a91SKrzysztof Drewniak       word3 |= (1 << 24);
175db590549SKrzysztof Drewniak       uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
176f1f05a91SKrzysztof Drewniak       word3 |= (oob << 28);
177f1f05a91SKrzysztof Drewniak     }
178f1f05a91SKrzysztof Drewniak     Value word3Const = createI32Constant(rewriter, loc, word3);
179f1f05a91SKrzysztof Drewniak     resource = rewriter.create<LLVM::InsertElementOp>(
180f1f05a91SKrzysztof Drewniak         loc, llvm4xI32, resource, word3Const,
181f1f05a91SKrzysztof Drewniak         this->createIndexConstant(rewriter, loc, 3));
182f1f05a91SKrzysztof Drewniak     args.push_back(resource);
183f1f05a91SKrzysztof Drewniak 
184f1f05a91SKrzysztof Drewniak     // Indexing (voffset)
185f1f05a91SKrzysztof Drewniak     Value voffset;
1868df54a6aSJacques Pienaar     for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
187f1f05a91SKrzysztof Drewniak       size_t i = pair.index();
188f1f05a91SKrzysztof Drewniak       Value index = pair.value();
189f1f05a91SKrzysztof Drewniak       Value strideOp;
190f1f05a91SKrzysztof Drewniak       if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
191f1f05a91SKrzysztof Drewniak         strideOp = rewriter.create<LLVM::MulOp>(
192f1f05a91SKrzysztof Drewniak             loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
193f1f05a91SKrzysztof Drewniak       } else {
194f1f05a91SKrzysztof Drewniak         strideOp =
195f1f05a91SKrzysztof Drewniak             createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
196f1f05a91SKrzysztof Drewniak       }
197f1f05a91SKrzysztof Drewniak       index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
198f1f05a91SKrzysztof Drewniak       voffset =
199f1f05a91SKrzysztof Drewniak           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
200f1f05a91SKrzysztof Drewniak     }
201037f0995SKazu Hirata     if (adaptor.getIndexOffset()) {
2028df54a6aSJacques Pienaar       int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
203f1f05a91SKrzysztof Drewniak       Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
204f1f05a91SKrzysztof Drewniak       voffset =
205f1f05a91SKrzysztof Drewniak           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
206f1f05a91SKrzysztof Drewniak                   : extraOffsetConst;
207f1f05a91SKrzysztof Drewniak     }
208f1f05a91SKrzysztof Drewniak     args.push_back(voffset);
209f1f05a91SKrzysztof Drewniak 
2108df54a6aSJacques Pienaar     Value sgprOffset = adaptor.getSgprOffset();
211f1f05a91SKrzysztof Drewniak     if (!sgprOffset)
212f1f05a91SKrzysztof Drewniak       sgprOffset = createI32Constant(rewriter, loc, 0);
213f1f05a91SKrzysztof Drewniak     if (ShapedType::isDynamicStrideOrOffset(offset))
214f1f05a91SKrzysztof Drewniak       sgprOffset = rewriter.create<LLVM::AddOp>(
215f1f05a91SKrzysztof Drewniak           loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
216f1f05a91SKrzysztof Drewniak     else if (offset > 0)
217f1f05a91SKrzysztof Drewniak       sgprOffset = rewriter.create<LLVM::AddOp>(
218f1f05a91SKrzysztof Drewniak           loc, sgprOffset, createI32Constant(rewriter, loc, offset));
219f1f05a91SKrzysztof Drewniak     args.push_back(sgprOffset);
220f1f05a91SKrzysztof Drewniak 
221f1f05a91SKrzysztof Drewniak     // bit 0: GLC = 0 (atomics drop value, less coherency)
222f1f05a91SKrzysztof Drewniak     // bits 1-2: SLC, DLC = 0 (similarly)
223f1f05a91SKrzysztof Drewniak     // bit 3: swizzled (0 for raw)
224f1f05a91SKrzysztof Drewniak     args.push_back(createI32Constant(rewriter, loc, 0));
225f1f05a91SKrzysztof Drewniak 
226f1f05a91SKrzysztof Drewniak     llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
227f1f05a91SKrzysztof Drewniak                                            llvmBufferValType);
228f1f05a91SKrzysztof Drewniak     Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
229f1f05a91SKrzysztof Drewniak                                                     ArrayRef<NamedAttribute>());
230f1f05a91SKrzysztof Drewniak     if (lowered->getNumResults() == 1) {
231f1f05a91SKrzysztof Drewniak       Value replacement = lowered->getResult(0);
232f1f05a91SKrzysztof Drewniak       if (llvmBufferValType != llvmWantedDataType) {
233f1f05a91SKrzysztof Drewniak         replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
234f1f05a91SKrzysztof Drewniak                                                        replacement);
235f1f05a91SKrzysztof Drewniak       }
236f1f05a91SKrzysztof Drewniak       rewriter.replaceOp(gpuOp, replacement);
237f1f05a91SKrzysztof Drewniak     } else {
238f1f05a91SKrzysztof Drewniak       rewriter.eraseOp(gpuOp);
239f1f05a91SKrzysztof Drewniak     }
240f1f05a91SKrzysztof Drewniak     return success();
241f1f05a91SKrzysztof Drewniak   }
242f1f05a91SKrzysztof Drewniak };
243f1f05a91SKrzysztof Drewniak 
244*bc61cc9aSKrzysztof Drewniak struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
245*bc61cc9aSKrzysztof Drewniak   using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;
246*bc61cc9aSKrzysztof Drewniak 
247*bc61cc9aSKrzysztof Drewniak   LogicalResult
matchAndRewrite__anon3714f0670111::LDSBarrierOpLowering248*bc61cc9aSKrzysztof Drewniak   matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
249*bc61cc9aSKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override {
250*bc61cc9aSKrzysztof Drewniak     auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
251*bc61cc9aSKrzysztof Drewniak                                                     LLVM::AsmDialect::AD_ATT);
252*bc61cc9aSKrzysztof Drewniak     const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier";
253*bc61cc9aSKrzysztof Drewniak     const char *constraints = "";
254*bc61cc9aSKrzysztof Drewniak     rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
255*bc61cc9aSKrzysztof Drewniak         op,
256*bc61cc9aSKrzysztof Drewniak         /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
257*bc61cc9aSKrzysztof Drewniak         /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
258*bc61cc9aSKrzysztof Drewniak         /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
259*bc61cc9aSKrzysztof Drewniak         /*operand_attrs=*/ArrayAttr());
260*bc61cc9aSKrzysztof Drewniak     return success();
261*bc61cc9aSKrzysztof Drewniak   }
262*bc61cc9aSKrzysztof Drewniak };
263*bc61cc9aSKrzysztof Drewniak 
264f1f05a91SKrzysztof Drewniak struct ConvertAMDGPUToROCDLPass
265f1f05a91SKrzysztof Drewniak     : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
266f1f05a91SKrzysztof Drewniak   ConvertAMDGPUToROCDLPass() = default;
267f1f05a91SKrzysztof Drewniak 
runOnOperation__anon3714f0670111::ConvertAMDGPUToROCDLPass268f1f05a91SKrzysztof Drewniak   void runOnOperation() override {
269cab44c51SKrzysztof Drewniak     MLIRContext *ctx = &getContext();
270cab44c51SKrzysztof Drewniak     FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
271cab44c51SKrzysztof Drewniak     if (failed(maybeChipset)) {
272cab44c51SKrzysztof Drewniak       emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
273cab44c51SKrzysztof Drewniak       return signalPassFailure();
274cab44c51SKrzysztof Drewniak     }
275cab44c51SKrzysztof Drewniak 
276cab44c51SKrzysztof Drewniak     RewritePatternSet patterns(ctx);
277cab44c51SKrzysztof Drewniak     LLVMTypeConverter converter(ctx);
278cab44c51SKrzysztof Drewniak     populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
279f1f05a91SKrzysztof Drewniak     LLVMConversionTarget target(getContext());
280f1f05a91SKrzysztof Drewniak     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
281f1f05a91SKrzysztof Drewniak     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
282f1f05a91SKrzysztof Drewniak     if (failed(applyPartialConversion(getOperation(), target,
283f1f05a91SKrzysztof Drewniak                                       std::move(patterns))))
284f1f05a91SKrzysztof Drewniak       signalPassFailure();
285f1f05a91SKrzysztof Drewniak   }
286f1f05a91SKrzysztof Drewniak };
287f1f05a91SKrzysztof Drewniak } // namespace
288f1f05a91SKrzysztof Drewniak 
populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns,Chipset chipset)289cab44c51SKrzysztof Drewniak void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
290cab44c51SKrzysztof Drewniak                                                    RewritePatternSet &patterns,
291cab44c51SKrzysztof Drewniak                                                    Chipset chipset) {
292*bc61cc9aSKrzysztof Drewniak   patterns.add<LDSBarrierOpLowering>(converter);
293f1f05a91SKrzysztof Drewniak   patterns.add<
294cab44c51SKrzysztof Drewniak       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
295cab44c51SKrzysztof Drewniak       RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
296cab44c51SKrzysztof Drewniak       RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
297cab44c51SKrzysztof Drewniak       converter, chipset);
298f1f05a91SKrzysztof Drewniak }
299f1f05a91SKrzysztof Drewniak 
createConvertAMDGPUToROCDLPass()300f1f05a91SKrzysztof Drewniak std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
301f1f05a91SKrzysztof Drewniak   return std::make_unique<ConvertAMDGPUToROCDLPass>();
302f1f05a91SKrzysztof Drewniak }
303