1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AMX dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 using namespace mlir; 20 21 void amx::AMXDialect::initialize() { 22 addOperations< 23 #define GET_OP_LIST 24 #include "mlir/Dialect/AMX/AMX.cpp.inc" 25 >(); 26 } 27 28 /// Verify that AMX supports the implied tile shape. 29 static LogicalResult verifyTileSize(Operation *op, VectorType tp) { 30 const unsigned kMaxRows = 16; 31 const unsigned kBitsPerRow = 64 * 8; 32 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); 33 if (tp.getDimSize(0) > kMaxRows) 34 return op->emitOpError("bad row height: ") << tp.getDimSize(0); 35 if (col > kBitsPerRow || col & 0x1f) 36 return op->emitOpError("bad column width: ") << (col >> 3); 37 return success(); 38 } 39 40 /// Verify that AMX supports the multiplication. 41 static LogicalResult verifyMultShape(Operation *op, VectorType atp, 42 VectorType btp, VectorType ctp, 43 unsigned scale) { 44 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; 45 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; 46 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); 47 if (cm != am || cn != bn || ak != bk) 48 return op->emitOpError("bad mult shape: ") 49 << cm << " x " << cn << " x " << ak; 50 return success(); 51 } 52 53 static LogicalResult verify(amx::TileZeroOp op) { 54 return verifyTileSize(op, op.getVectorType()); 55 } 56 57 static LogicalResult verify(amx::TileLoadOp op) { 58 unsigned rank = op.getMemRefType().getRank(); 59 if (llvm::size(op.indices()) != rank) 60 return op.emitOpError("requires ") << rank << " indices"; 61 return verifyTileSize(op, op.getVectorType()); 62 } 63 64 static LogicalResult verify(amx::TileStoreOp op) { 65 unsigned rank = op.getMemRefType().getRank(); 66 if (llvm::size(op.indices()) != rank) 67 return op.emitOpError("requires ") << rank << " indices"; 68 return verifyTileSize(op, op.getVectorType()); 69 } 70 71 static LogicalResult verify(amx::TileMulFOp op) { 72 VectorType aType = op.getLhsVectorType(); 73 VectorType bType = op.getRhsVectorType(); 74 VectorType cType = op.getVectorType(); 75 if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || 76 failed(verifyTileSize(op, cType)) || 77 failed(verifyMultShape(op, aType, bType, cType, 1))) 78 return failure(); 79 Type ta = aType.getElementType(); 80 Type tb = bType.getElementType(); 81 Type tc = cType.getElementType(); 82 if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) 83 return op.emitOpError("unsupported type combination"); 84 return success(); 85 } 86 87 static LogicalResult verify(amx::TileMulIOp op) { 88 VectorType aType = op.getLhsVectorType(); 89 VectorType bType = op.getRhsVectorType(); 90 VectorType cType = op.getVectorType(); 91 if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || 92 failed(verifyTileSize(op, cType)) || 93 failed(verifyMultShape(op, aType, bType, cType, 2))) 94 return failure(); 95 Type ta = aType.getElementType(); 96 Type tb = bType.getElementType(); 97 Type tc = cType.getElementType(); 98 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) 99 return op.emitOpError("unsupported type combination"); 100 return success(); 101 } 102 103 #define GET_OP_CLASSES 104 #include "mlir/Dialect/AMX/AMX.cpp.inc" 105