1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/AMX/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::amx; 21 22 namespace { 23 24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first 25 /// dimension directly translates into the number of rows of the tiles. 26 /// The second dimensions needs to be scaled by the number of bytes. 27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, 28 LLVMTypeConverter &typeConverter, 29 VectorType vType, Location loc) { 30 Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); 31 unsigned width = vType.getElementType().getIntOrFloatBitWidth(); 32 assert(llvm::isPowerOf2_64(width) && width >= 8); 33 unsigned bytes = width >> 3; 34 auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0)); 35 auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes); 36 return std::make_pair( 37 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), 38 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); 39 } 40 41 /// Verifies if the stride matches proper tile access. 
42 LogicalResult verifyStride(MemRefType mType) { 43 if (mType.getRank() < 2) 44 return failure(); 45 int64_t last = mType.getRank() - 1; 46 int64_t offset; 47 SmallVector<int64_t, 4> strides; 48 if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1) 49 return failure(); 50 return success(); 51 } 52 53 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer 54 /// shape may "envelop" the actual tile shape, and may be dynamically sized. 55 Value getStride(ConversionPatternRewriter &rewriter, 56 LLVMTypeConverter &typeConverter, MemRefType mType, Value base, 57 Location loc) { 58 assert(mType.getRank() >= 2); 59 int64_t last = mType.getRank() - 1; 60 Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); 61 unsigned width = mType.getElementType().getIntOrFloatBitWidth(); 62 assert(llvm::isPowerOf2_64(width) && width >= 8); 63 unsigned bytes = width >> 3; 64 if (mType.isDynamicDim(last)) { 65 // Dynamic size needs code to compute the stride at runtime. 66 MemRefDescriptor memrefDescriptor(base); 67 auto attr = rewriter.getI64IntegerAttr(bytes); 68 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 69 return rewriter.create<LLVM::MulOp>( 70 loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last)); 71 } 72 // Use direct constant for static size. 73 auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes); 74 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 75 } 76 77 /// Cast any pointer to the !llvm.ptr<i8> pointer type. 
Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
  auto i8Ptr =
      LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8));
  return rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, ptr);
}

/// Lowers amx.tile_zero to the x86_amx.tilezero intrinsic, passing the
/// m x n tile sizes derived from the result vector type.
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileZeroOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                       tsz.second);
    return success();
  }
};

/// Lowers amx.tile_load to the x86_amx.tileloadd64 intrinsic: sizes from the
/// result vector type, base pointer and stride from the source memref. Fails
/// to match when the memref's innermost stride is not unit.
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileLoadOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TileLoadOp::Adaptor adaptor(operands);
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.base(), op.getLoc());
    // Replace operation with intrinsic.
    // The intrinsic expects a raw i8 pointer, hence the cast.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    Type resType = typeConverter->convertType(vType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
        op, resType, tsz.first, tsz.second, ptr, stride);
    return success();
  }
};

/// Lowers amx.tile_store to the x86_amx.tilestored64 intrinsic. Mirrors
/// TileLoadConversion, but passes the tile value to store and produces no
/// result.
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TileStoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TileStoreOp::Adaptor adaptor(operands);
    MemRefType mType = op.getMemRefType();
    VectorType vType = op.getVectorType();
    // Determine m x n tile sizes.
    std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
    // Determine stride.
    if (failed(verifyStride(mType)))
      return failure();
    Value stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.base(), op.getLoc());
    // Replace operation with intrinsic.
    // The intrinsic expects a raw i8 pointer, hence the cast.
    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    ptr = castPtr(rewriter, op.getLoc(), ptr);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
        op, tsz.first, tsz.second, ptr, stride, adaptor.val());
    return success();
  }
};

/// Lowers amx.tile_mulf to the x86_amx.tdpbf16ps intrinsic. The intrinsic's
/// size operands are (m, n, k): rows of lhs, (scaled) columns of rhs, and
/// (scaled) columns of lhs, i.e. tsza.first, tszb.second, tsza.second.
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulFOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TileMulFOp::Adaptor adaptor(operands);
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes.
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(cType);
    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
        op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
        adaptor.lhs(), adaptor.rhs());
    return success();
  }
};

/// Lowers amx.tile_muli to one of the four tdpb*d intrinsics, selected by
/// the zero-extension flags of the two operands (u = zext, s = sext).
/// Size operand order (m, n, k) matches TileMulFConversion.
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(TileMulIOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TileMulIOp::Adaptor adaptor(operands);
    VectorType aType = op.getLhsVectorType();
    VectorType bType = op.getRhsVectorType();
    VectorType cType = op.getVectorType();
    // Determine m x n x k tile sizes.
    std::pair<Value, Value> tsza =
        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
    std::pair<Value, Value> tszb =
        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
    // Replace operation with intrinsic.
    Type resType = typeConverter->convertType(cType);
    bool zexta = op.isZextLhs();
    bool zextb = op.isZextRhs();
    if (zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else if (zexta && !zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else if (!zexta && zextb)
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    else
      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
          op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
          adaptor.lhs(), adaptor.rhs());
    return success();
  }
};

} // namespace

/// Registers the AMX-op to LLVM-intrinsic conversion patterns above.
void mlir::populateAMXLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
               TileMulFConversion, TileMulIConversion>(converter);
}

/// Marks the AMX intrinsic ops as legal and the high-level AMX ops as
/// illegal, so the conversion driver applies the patterns exhaustively.
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                      TileMulFOp>();
}