1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/AMX/Transforms.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/AMX/AMXDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::amx; 20 21 namespace { 22 23 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first 24 /// dimension directly translates into the number of rows of the tiles. 25 /// The second dimensions needs to be scaled by the number of bytes. 26 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, 27 LLVMTypeConverter &typeConverter, 28 VectorType vType, Location loc) { 29 Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); 30 unsigned width = vType.getElementType().getIntOrFloatBitWidth(); 31 assert(llvm::isPowerOf2_64(width) && width >= 8); 32 unsigned bytes = width >> 3; 33 auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0)); 34 auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes); 35 return std::make_pair( 36 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), 37 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); 38 } 39 40 /// Verifies if the stride matches proper tile access. 41 LogicalResult verifyStride(MemRefType mType) { 42 if (mType.getRank() < 2) 43 return failure(); 44 int64_t last = mType.getRank() - 1; 45 int64_t offset; 46 SmallVector<int64_t, 4> strides; 47 if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1) 48 return failure(); 49 return success(); 50 } 51 52 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer 53 /// shape may "envelop" the actual tile shape, and may be dynamically sized. 54 Value getStride(ConversionPatternRewriter &rewriter, 55 LLVMTypeConverter &typeConverter, MemRefType mType, Value base, 56 Location loc) { 57 assert(mType.getRank() >= 2); 58 int64_t last = mType.getRank() - 1; 59 Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); 60 unsigned width = mType.getElementType().getIntOrFloatBitWidth(); 61 assert(llvm::isPowerOf2_64(width) && width >= 8); 62 unsigned bytes = width >> 3; 63 if (mType.isDynamicDim(last)) { 64 // Dynamic size needs code to compute the stride at runtime. 65 MemRefDescriptor memrefDescriptor(base); 66 auto attr = rewriter.getI64IntegerAttr(bytes); 67 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 68 return rewriter.create<LLVM::MulOp>( 69 loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last)); 70 } 71 // Use direct constant for static size. 72 auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes); 73 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 74 } 75 76 /// Cast any pointer to the !llvm.ptr<i8> pointer type. 77 Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) { 78 auto i8Ptr = 79 LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8)); 80 return rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, ptr); 81 } 82 83 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> { 84 using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern; 85 LogicalResult 86 matchAndRewrite(TileZeroOp op, ArrayRef<Value> operands, 87 ConversionPatternRewriter &rewriter) const override { 88 VectorType vType = op.getVectorType(); 89 // Determine m x n tile sizes. 90 std::pair<Value, Value> tsz = 91 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); 92 // Replace operation with intrinsic. 93 Type resType = typeConverter->convertType(vType); 94 rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first, 95 tsz.second); 96 return success(); 97 } 98 }; 99 100 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> { 101 using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern; 102 103 LogicalResult 104 matchAndRewrite(TileLoadOp op, ArrayRef<Value> operands, 105 ConversionPatternRewriter &rewriter) const override { 106 TileLoadOp::Adaptor adaptor(operands); 107 MemRefType mType = op.getMemRefType(); 108 VectorType vType = op.getVectorType(); 109 // Determine m x n tile sizes. 110 std::pair<Value, Value> tsz = 111 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); 112 // Determine stride. 113 if (failed(verifyStride(mType))) 114 return failure(); 115 Value stride = getStride(rewriter, *getTypeConverter(), mType, 116 adaptor.base(), op.getLoc()); 117 // Replace operation with intrinsic. 118 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), 119 adaptor.indices(), rewriter); 120 ptr = castPtr(rewriter, op.getLoc(), ptr); 121 Type resType = typeConverter->convertType(vType); 122 rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>( 123 op, resType, tsz.first, tsz.second, ptr, stride); 124 return success(); 125 } 126 }; 127 128 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> { 129 using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern; 130 131 LogicalResult 132 matchAndRewrite(TileStoreOp op, ArrayRef<Value> operands, 133 ConversionPatternRewriter &rewriter) const override { 134 TileStoreOp::Adaptor adaptor(operands); 135 MemRefType mType = op.getMemRefType(); 136 VectorType vType = op.getVectorType(); 137 // Determine m x n tile sizes. 138 std::pair<Value, Value> tsz = 139 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); 140 // Determine stride. 141 if (failed(verifyStride(mType))) 142 return failure(); 143 Value stride = getStride(rewriter, *getTypeConverter(), mType, 144 adaptor.base(), op.getLoc()); 145 // Replace operation with intrinsic. 146 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), 147 adaptor.indices(), rewriter); 148 ptr = castPtr(rewriter, op.getLoc(), ptr); 149 rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>( 150 op, tsz.first, tsz.second, ptr, stride, adaptor.val()); 151 return success(); 152 } 153 }; 154 155 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> { 156 using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern; 157 LogicalResult 158 matchAndRewrite(TileMulFOp op, ArrayRef<Value> operands, 159 ConversionPatternRewriter &rewriter) const override { 160 TileMulFOp::Adaptor adaptor(operands); 161 VectorType aType = op.getLhsVectorType(); 162 VectorType bType = op.getRhsVectorType(); 163 VectorType cType = op.getVectorType(); 164 // Determine m x n x k tile sizes. 165 std::pair<Value, Value> tsza = 166 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 167 std::pair<Value, Value> tszb = 168 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 169 // Replace operation with intrinsic. 170 Type resType = typeConverter->convertType(cType); 171 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>( 172 op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), 173 adaptor.lhs(), adaptor.rhs()); 174 return success(); 175 } 176 }; 177 178 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> { 179 using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern; 180 LogicalResult 181 matchAndRewrite(TileMulIOp op, ArrayRef<Value> operands, 182 ConversionPatternRewriter &rewriter) const override { 183 TileMulIOp::Adaptor adaptor(operands); 184 VectorType aType = op.getLhsVectorType(); 185 VectorType bType = op.getRhsVectorType(); 186 VectorType cType = op.getVectorType(); 187 // Determine m x n x k tile sizes. 188 std::pair<Value, Value> tsza = 189 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 190 std::pair<Value, Value> tszb = 191 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 192 // Replace operation with intrinsic. 193 Type resType = typeConverter->convertType(cType); 194 bool zexta = op.isZextLhs(); 195 bool zextb = op.isZextRhs(); 196 if (zexta && zextb) 197 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>( 198 op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), 199 adaptor.lhs(), adaptor.rhs()); 200 else if (zexta && !zextb) 201 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>( 202 op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), 203 adaptor.lhs(), adaptor.rhs()); 204 else if (!zexta && zextb) 205 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>( 206 op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), 207 adaptor.lhs(), adaptor.rhs()); 208 else 209 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>( 210 op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), 211 adaptor.lhs(), adaptor.rhs()); 212 return success(); 213 } 214 }; 215 216 } // namespace 217 218 void mlir::populateAMXLegalizeForLLVMExportPatterns( 219 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 220 patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion, 221 TileMulFConversion, TileMulIConversion>(converter); 222 } 223 224 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { 225 target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64, 226 x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud, 227 x86_amx_tdpbusd, x86_amx_tdpbuud>(); 228 target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp, 229 TileMulFOp>(); 230 } 231