1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AMX dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 using namespace mlir; 20 21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" 22 23 void amx::AMXDialect::initialize() { 24 addOperations< 25 #define GET_OP_LIST 26 #include "mlir/Dialect/AMX/AMX.cpp.inc" 27 >(); 28 } 29 30 /// Verify that AMX supports the implied tile shape. 31 static LogicalResult verifyTileSize(Operation *op, VectorType tp) { 32 const unsigned kMaxRows = 16; 33 const unsigned kBitsPerRow = 64 * 8; 34 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); 35 if (tp.getDimSize(0) > kMaxRows) 36 return op->emitOpError("bad row height: ") << tp.getDimSize(0); 37 if (col > kBitsPerRow || col & 0x1f) 38 return op->emitOpError("bad column width: ") << (col >> 3); 39 return success(); 40 } 41 42 /// Verify that AMX supports the multiplication. 43 static LogicalResult verifyMultShape(Operation *op, VectorType atp, 44 VectorType btp, VectorType ctp, 45 unsigned scale) { 46 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; 47 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; 48 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); 49 if (cm != am || cn != bn || ak != bk) 50 return op->emitOpError("bad mult shape: ") 51 << cm << " x " << cn << " x " << ak; 52 return success(); 53 } 54 55 static LogicalResult verify(amx::TileZeroOp op) { 56 return verifyTileSize(op, op.getVectorType()); 57 } 58 59 static LogicalResult verify(amx::TileLoadOp op) { 60 unsigned rank = op.getMemRefType().getRank(); 61 if (llvm::size(op.indices()) != rank) 62 return op.emitOpError("requires ") << rank << " indices"; 63 return verifyTileSize(op, op.getVectorType()); 64 } 65 66 static LogicalResult verify(amx::TileStoreOp op) { 67 unsigned rank = op.getMemRefType().getRank(); 68 if (llvm::size(op.indices()) != rank) 69 return op.emitOpError("requires ") << rank << " indices"; 70 return verifyTileSize(op, op.getVectorType()); 71 } 72 73 static LogicalResult verify(amx::TileMulFOp op) { 74 VectorType aType = op.getLhsVectorType(); 75 VectorType bType = op.getRhsVectorType(); 76 VectorType cType = op.getVectorType(); 77 if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || 78 failed(verifyTileSize(op, cType)) || 79 failed(verifyMultShape(op, aType, bType, cType, 1))) 80 return failure(); 81 Type ta = aType.getElementType(); 82 Type tb = bType.getElementType(); 83 Type tc = cType.getElementType(); 84 if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) 85 return op.emitOpError("unsupported type combination"); 86 return success(); 87 } 88 89 static LogicalResult verify(amx::TileMulIOp op) { 90 VectorType aType = op.getLhsVectorType(); 91 VectorType bType = op.getRhsVectorType(); 92 VectorType cType = op.getVectorType(); 93 if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || 94 failed(verifyTileSize(op, cType)) || 95 failed(verifyMultShape(op, aType, bType, cType, 2))) 96 return failure(); 97 Type ta = aType.getElementType(); 98 Type tb = bType.getElementType(); 99 Type tc = cType.getElementType(); 100 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) 101 return op.emitOpError("unsupported type combination"); 102 return success(); 103 } 104 105 #define GET_OP_CLASSES 106 #include "mlir/Dialect/AMX/AMX.cpp.inc" 107