//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMX/Transforms.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::amx;

namespace {

/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
/// dimension directly translates into the number of rows of the tile.
/// The second dimension needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     VectorType vType, Location loc) {
  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
  assert(llvm::isPowerOf2_64(width) && width >= 8);
  unsigned bytes = width >> 3;
  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
  return std::make_pair(
      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
}

/// Verifies if the stride matches proper tile access.
LogicalResult verifyStride(MemRefType mType) {
  if (mType.getRank() < 2)
    return failure();
  int64_t last = mType.getRank() - 1;
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(mType, strides, offset)) ||
      strides[last] != 1)
    return failure();
  return success();
}

/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
Value getStride(ConversionPatternRewriter &rewriter,
                LLVMTypeConverter &typeConverter, MemRefType mType, Value base,
                Location loc) {
  assert(mType.getRank() >= 2);
  int64_t last = mType.getRank() - 1;
  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
  assert(llvm::isPowerOf2_64(width) && width >= 8);
  unsigned bytes = width >> 3;
  if (mType.isDynamicDim(last)) {
    // Dynamic size needs code to compute the stride at runtime.
    MemRefDescriptor memrefDescriptor(base);
    auto attr = rewriter.getI64IntegerAttr(bytes);
    Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
    return rewriter.create<LLVM::MulOp>(
        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
  }
  // Use direct constant for static size.
  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
}
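// Illustrative example of the two mappings above (the types are hypothetical,
// chosen only for the arithmetic): a tile of type vector<16x32xbf16> yields
// the i16 constants 16 (rows) and 32 * 2 = 64 (column bytes) from
// getTileSizes, and a statically shaped memref<64x64xbf16> buffer yields the
// i64 constant 64 * 2 = 128 (row stride in bytes) from getStride.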

/// Cast any pointer to the !llvm.ptr<i8> pointer type.
Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
  auto i8Ptr =
      LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8));
  return rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, ptr);
}

struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                       tsz.second);
    return success();
  }
};

struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
    // Replace operation with intrinsic.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
        op, resType, tsz.first, tsz.second, ptr, stride);
    return success();
  }
};

struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
    // Replace operation with intrinsic.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
        op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
    return success();
  }
};

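// Note on the operand order used by the tile multiplication patterns below:
// the replaced "tdp" intrinsics receive the sizes (tsza.first, tszb.second,
// tsza.second), i.e. m = rows of the LHS tile, n = column bytes of the RHS
// tile, and k = column bytes of the LHS tile.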
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes.
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(cType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
        adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};

struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes.
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(cType);
    bool zexta = op.getIsZextLhs();
    bool zextb = op.getIsZextRhs();
    if (zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
          adaptor.getLhs(), adaptor.getRhs());
    else if (zexta && !zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
          adaptor.getLhs(), adaptor.getRhs());
    else if (!zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
          adaptor.getLhs(), adaptor.getRhs());
    else
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
          adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};

} // namespace

void mlir::populateAMXLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
               TileMulFConversion, TileMulIConversion>(converter);
}

void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                      TileMulFOp>();
}
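
// Minimal usage sketch for the two entry points above (assumption: this
// mirrors how a typical lowering pass wires them up; the surrounding pass
// boilerplate is not part of this file):
//
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   LLVMConversionTarget target(getContext());
//   configureAMXLegalizeForExportTarget(target);
//   populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();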