1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/AMX/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::amx; 21 22 namespace { 23 24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first 25 /// dimension directly translates into the number of rows of the tiles. 26 /// The second dimensions needs to be scaled by the number of bytes. 27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, 28 LLVMTypeConverter &typeConverter, 29 VectorType vType, Location loc) { 30 Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); 31 unsigned width = vType.getElementType().getIntOrFloatBitWidth(); 32 assert(llvm::isPowerOf2_64(width) && width >= 8); 33 unsigned bytes = width >> 3; 34 auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0)); 35 auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes); 36 return std::make_pair( 37 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), 38 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); 39 } 40 41 /// Verifies if the stride matches proper tile access. 
42 LogicalResult verifyStride(MemRefType mType) { 43 if (mType.getRank() < 2) 44 return failure(); 45 int64_t last = mType.getRank() - 1; 46 int64_t offset; 47 SmallVector<int64_t, 4> strides; 48 if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1) 49 return failure(); 50 return success(); 51 } 52 53 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer 54 /// shape may "envelop" the actual tile shape, and may be dynamically sized. 55 Value getStride(ConversionPatternRewriter &rewriter, 56 LLVMTypeConverter &typeConverter, MemRefType mType, Value base, 57 Location loc) { 58 assert(mType.getRank() >= 2); 59 int64_t last = mType.getRank() - 1; 60 Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); 61 unsigned width = mType.getElementType().getIntOrFloatBitWidth(); 62 assert(llvm::isPowerOf2_64(width) && width >= 8); 63 unsigned bytes = width >> 3; 64 if (mType.isDynamicDim(last)) { 65 // Dynamic size needs code to compute the stride at runtime. 66 MemRefDescriptor memrefDescriptor(base); 67 auto attr = rewriter.getI64IntegerAttr(bytes); 68 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 69 return rewriter.create<LLVM::MulOp>( 70 loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last)); 71 } 72 // Use direct constant for static size. 73 auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes); 74 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 75 } 76 77 /// Cast any pointer to the !llvm.ptr<i8> pointer type. 
Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
  auto i8Ptr =
      LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8));
  // The x86_amx load/store intrinsics below take a raw i8* base address.
  return rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, ptr);
}

/// Lowers amx.tile_zero to the x86_amx_tilezero intrinsic, forwarding the
/// (rows, column-bytes) tile sizes derived from the result vector shape.
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes (rows, column size in bytes).
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Replace operation with intrinsic; result uses the converted tile type.
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                       tsz.second);
    return success();
  }
};

/// Lowers amx.tile_load to the x86_amx_tileloadd64 intrinsic. Fails to
/// match when the source memref does not have a unit innermost stride.
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes (rows, column size in bytes).
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride (in bytes) between consecutive tile rows.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.base(), op.getLoc());
    // Replace operation with intrinsic: compute the element address at the
    // given indices, then reinterpret it as the i8* the intrinsic expects.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
        op, resType, tsz.first, tsz.second, ptr, stride);
    return success();
  }
};

/// Lowers amx.tile_store to the x86_amx_tilestored64 intrinsic. Mirrors
/// TileLoadConversion but passes the stored tile value and has no result.
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes (rows, column size in bytes).
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride (in bytes) between consecutive tile rows.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.base(), op.getLoc());
    // Replace operation with intrinsic: address computation as in tile_load.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
        op, tsz.first, tsz.second, ptr, stride, adaptor.val());
    return success();
  }
};

/// Lowers amx.tile_mulf to the x86_amx_tdpbf16ps intrinsic.
/// NOTE(review): only the bf16 intrinsic is emitted here; presumably the
/// op verifier restricts float operands accordingly — confirm in AMXDialect.
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes: for lhs (m x k), tsza = (m, k-bytes);
    // for rhs (k x n), tszb = (k, n-bytes). The intrinsic takes
    // (m, n-bytes, k-bytes) = (tsza.first, tszb.second, tsza.second).
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(cType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
        op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
        adaptor.lhs(), adaptor.rhs());
    return success();
  }
};

/// Lowers amx.tile_muli to one of the four x86_amx_tdpb**d intrinsics,
/// selected by the signedness (zext flags) of the lhs/rhs operands:
/// uu, us, su, or ss.
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes, laid out as in TileMulFConversion.
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic, dispatching on operand signedness.
    Type resType = typeConverter->convertType(cType);
    bool zexta = op.isZextLhs();
    bool zextb = op.isZextRhs();
    if (zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else if (zexta && !zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else if (!zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    return success();
  }
};

} // namespace

/// Registers the AMX-op-to-intrinsic rewrite patterns declared above.
void mlir::populateAMXLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
               TileMulFConversion, TileMulIConversion>(converter);
}

/// Marks the raw intrinsic ops as legal and the high-level AMX ops as
/// illegal, so the conversion driver must rewrite every AMX op.
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                      TileMulFOp>();
}