1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AMX dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 using namespace mlir; 20 21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" 22 23 void amx::AMXDialect::initialize() { 24 addOperations< 25 #define GET_OP_LIST 26 #include "mlir/Dialect/AMX/AMX.cpp.inc" 27 >(); 28 } 29 30 /// Verify that AMX supports the implied tile shape. 31 static LogicalResult verifyTileSize(Operation *op, VectorType tp) { 32 const unsigned kMaxRows = 16; 33 const unsigned kBitsPerRow = 64 * 8; 34 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); 35 if (tp.getDimSize(0) > kMaxRows) 36 return op->emitOpError("bad row height: ") << tp.getDimSize(0); 37 if (col > kBitsPerRow || col & 0x1f) 38 return op->emitOpError("bad column width: ") << (col >> 3); 39 return success(); 40 } 41 42 /// Verify that AMX supports the multiplication. 43 static LogicalResult verifyMultShape(Operation *op, VectorType atp, 44 VectorType btp, VectorType ctp, 45 unsigned scale) { 46 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; 47 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; 48 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); 49 if (cm != am || cn != bn || ak != bk) 50 return op->emitOpError("bad mult shape: ") 51 << cm << " x " << cn << " x " << ak; 52 return success(); 53 } 54 55 LogicalResult amx::TileZeroOp::verify() { 56 return verifyTileSize(*this, getVectorType()); 57 } 58 59 LogicalResult amx::TileLoadOp::verify() { 60 unsigned rank = getMemRefType().getRank(); 61 if (indices().size() != rank) 62 return emitOpError("requires ") << rank << " indices"; 63 return verifyTileSize(*this, getVectorType()); 64 } 65 66 LogicalResult amx::TileStoreOp::verify() { 67 unsigned rank = getMemRefType().getRank(); 68 if (indices().size() != rank) 69 return emitOpError("requires ") << rank << " indices"; 70 return verifyTileSize(*this, getVectorType()); 71 } 72 73 LogicalResult amx::TileMulFOp::verify() { 74 VectorType aType = getLhsVectorType(); 75 VectorType bType = getRhsVectorType(); 76 VectorType cType = getVectorType(); 77 if (failed(verifyTileSize(*this, aType)) || 78 failed(verifyTileSize(*this, bType)) || 79 failed(verifyTileSize(*this, cType)) || 80 failed(verifyMultShape(*this, aType, bType, cType, 1))) 81 return failure(); 82 Type ta = aType.getElementType(); 83 Type tb = bType.getElementType(); 84 Type tc = cType.getElementType(); 85 if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) 86 return emitOpError("unsupported type combination"); 87 return success(); 88 } 89 90 LogicalResult amx::TileMulIOp::verify() { 91 VectorType aType = getLhsVectorType(); 92 VectorType bType = getRhsVectorType(); 93 VectorType cType = getVectorType(); 94 if (failed(verifyTileSize(*this, aType)) || 95 failed(verifyTileSize(*this, bType)) || 96 failed(verifyTileSize(*this, cType)) || 97 failed(verifyMultShape(*this, aType, bType, cType, 2))) 98 return failure(); 99 Type ta = aType.getElementType(); 100 Type tb = bType.getElementType(); 101 Type tc = cType.getElementType(); 102 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) 103 return emitOpError("unsupported type combination"); 104 return success(); 105 } 106 107 #define GET_OP_CLASSES 108 #include "mlir/Dialect/AMX/AMX.cpp.inc" 109