//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
}

namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("Raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector with total size <= 32 bits, use a scalar
    // load and bitcast it. Similarly, if bitsize(T) < 32 and the total load
    // size is >= 32, use a vector load of N * bitsize(T) / 32 x i32 and
    // bitcast.
    Type llvmBufferValType = llvmWantedDataType;
    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
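      // Illustrative examples of the repacking below: a vector<4xi8> access
      // (32 bits total) becomes a single i32 access, while a vector<8xf16>
      // access (128 bits total) becomes a vector<4xi32> access; in both
      // cases the value is bitcast to/from the type the op asked for.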
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    // Resource descriptor
    // bits 0-47: base address
    // bits 48-61: stride (0 for raw buffers)
    // bit 62: texture cache coherency (always 0)
    // bit 63: enable swizzles (always off for raw buffers)
    // bits 64-95 (word 2): Number of records, units of stride
    // bits 96-127 (word 3): See below

    Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
    MemRefDescriptor memrefDescriptor(memref);
    Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
    Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));

    Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
    Value ptrAsInts =
        rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
    for (int64_t i = 0; i < 2; ++i) {
      Value idxConst = this->createIndexConstant(rewriter, loc, i);
      Value part =
          rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
      resource = rewriter.create<LLVM::InsertElementOp>(
          loc, llvm4xI32, resource, part, idxConst);
    }

    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, numRecords,
        this->createIndexConstant(rewriter, loc, 2));

    // Final word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
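    // Worked example: on a gfx10xx target with boundsCheck = true, the code
    // below produces word3 = (7 << 12) | (4 << 15) | (1 << 24) | (3 << 28)
    // = 0x31027000.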
    uint32_t word3 = (7 << 12) | (4 << 15);
    if (chipset.majorVersion == 10) {
      word3 |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      word3 |= (oob << 28);
    }
    Value word3Const = createI32Constant(rewriter, loc, word3);
    resource = rewriter.create<LLVM::InsertElementOp>(
        loc, llvm4xI32, resource, word3Const,
        this->createIndexConstant(rewriter, loc, 3));
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset;
    for (auto &pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamicStrideOrOffset(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, index) : index;
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamicStrideOrOffset(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                    LLVM::AsmDialect::AD_ATT);
    const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier";
    const char *constraints = "";
    rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
        op,
        /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
        /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
        /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
        /*operand_attrs=*/ArrayAttr());
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns,
                                            *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

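/// Populate the given pattern set with the AMDGPU-to-ROCDL lowerings. The raw
/// buffer patterns need the chipset so they can emit the correct descriptor
/// word3 bits; the LDS barrier lowering is chipset-independent.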
void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns.add<LDSBarrierOpLowering>(converter);
  patterns.add<
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
      converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
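
// Example invocation (a sketch; assumes this pass is registered under the
// name and "chipset" option from its tablegen definition):
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx908 input.mlir
// From C++, a pipeline would add it with
//   pm.addPass(createConvertAMDGPUToROCDLPass());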