1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/AMX/AMXDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 
/// Dialect initialization hook: registers the AMX operations with the dialect.
/// The operation class list is not written out by hand. It comes from
/// including AMX.cpp.inc (presumably TableGen-generated) with GET_OP_LIST
/// defined. Whatever that include expands to becomes the template argument
/// list of addOperations<...>().
void amx::AMXDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
27 
28 /// Verify that AMX supports the implied tile shape.
29 static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
30   const unsigned kMaxRows = 16;
31   const unsigned kBitsPerRow = 64 * 8;
32   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
33   if (tp.getDimSize(0) > kMaxRows)
34     return op->emitOpError("bad row height: ") << tp.getDimSize(0);
35   if (col > kBitsPerRow || col & 0x1f)
36     return op->emitOpError("bad column width: ") << (col >> 3);
37   return success();
38 }
39 
40 /// Verify that AMX supports the multiplication.
41 static LogicalResult verifyMultShape(Operation *op, VectorType atp,
42                                      VectorType btp, VectorType ctp,
43                                      unsigned scale) {
44   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
45   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
46   unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
47   if (cm != am || cn != bn || ak != bk)
48     return op->emitOpError("bad mult shape: ")
49            << cm << " x " << cn << " x " << ak;
50   return success();
51 }
52 
53 static LogicalResult verify(amx::TileZeroOp op) {
54   return verifyTileSize(op, op.getVectorType());
55 }
56 
57 static LogicalResult verify(amx::TileLoadOp op) {
58   unsigned rank = op.getMemRefType().getRank();
59   if (llvm::size(op.indices()) != rank)
60     return op.emitOpError("requires ") << rank << " indices";
61   return verifyTileSize(op, op.getVectorType());
62 }
63 
64 static LogicalResult verify(amx::TileStoreOp op) {
65   unsigned rank = op.getMemRefType().getRank();
66   if (llvm::size(op.indices()) != rank)
67     return op.emitOpError("requires ") << rank << " indices";
68   return verifyTileSize(op, op.getVectorType());
69 }
70 
71 static LogicalResult verify(amx::TileMulFOp op) {
72   VectorType aType = op.getLhsVectorType();
73   VectorType bType = op.getRhsVectorType();
74   VectorType cType = op.getVectorType();
75   if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
76       failed(verifyTileSize(op, cType)) ||
77       failed(verifyMultShape(op, aType, bType, cType, 1)))
78     return failure();
79   Type ta = aType.getElementType();
80   Type tb = bType.getElementType();
81   Type tc = cType.getElementType();
82   if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
83     return op.emitOpError("unsupported type combination");
84   return success();
85 }
86 
87 static LogicalResult verify(amx::TileMulIOp op) {
88   VectorType aType = op.getLhsVectorType();
89   VectorType bType = op.getRhsVectorType();
90   VectorType cType = op.getVectorType();
91   if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
92       failed(verifyTileSize(op, cType)) ||
93       failed(verifyMultShape(op, aType, bType, cType, 2)))
94     return failure();
95   Type ta = aType.getElementType();
96   Type tb = bType.getElementType();
97   Type tc = cType.getElementType();
98   if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
99     return op.emitOpError("unsupported type combination");
100   return success();
101 }
102 
103 #define GET_OP_CLASSES
104 #include "mlir/Dialect/AMX/AMX.cpp.inc"
105