1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/AMX/AMXDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 
21 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
22 
/// Registers all AMX operations with this dialect.
///
/// The operation list is produced by TableGen: defining GET_OP_LIST before
/// including AMX.cpp.inc makes that generated file expand to the list of op
/// classes, which is spliced in as the template argument list of
/// addOperations.
void amx::AMXDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
29 
30 /// Verify that AMX supports the implied tile shape.
verifyTileSize(Operation * op,VectorType tp)31 static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
32   const unsigned kMaxRows = 16;
33   const unsigned kBitsPerRow = 64 * 8;
34   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
35   if (tp.getDimSize(0) > kMaxRows)
36     return op->emitOpError("bad row height: ") << tp.getDimSize(0);
37   if (col > kBitsPerRow || col & 0x1f)
38     return op->emitOpError("bad column width: ") << (col >> 3);
39   return success();
40 }
41 
42 /// Verify that AMX supports the multiplication.
verifyMultShape(Operation * op,VectorType atp,VectorType btp,VectorType ctp,unsigned scale)43 static LogicalResult verifyMultShape(Operation *op, VectorType atp,
44                                      VectorType btp, VectorType ctp,
45                                      unsigned scale) {
46   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
47   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
48   unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
49   if (cm != am || cn != bn || ak != bk)
50     return op->emitOpError("bad mult shape: ")
51            << cm << " x " << cn << " x " << ak;
52   return success();
53 }
54 
verify()55 LogicalResult amx::TileZeroOp::verify() {
56   return verifyTileSize(*this, getVectorType());
57 }
58 
verify()59 LogicalResult amx::TileLoadOp::verify() {
60   unsigned rank = getMemRefType().getRank();
61   if (getIndices().size() != rank)
62     return emitOpError("requires ") << rank << " indices";
63   return verifyTileSize(*this, getVectorType());
64 }
65 
verify()66 LogicalResult amx::TileStoreOp::verify() {
67   unsigned rank = getMemRefType().getRank();
68   if (getIndices().size() != rank)
69     return emitOpError("requires ") << rank << " indices";
70   return verifyTileSize(*this, getVectorType());
71 }
72 
verify()73 LogicalResult amx::TileMulFOp::verify() {
74   VectorType aType = getLhsVectorType();
75   VectorType bType = getRhsVectorType();
76   VectorType cType = getVectorType();
77   if (failed(verifyTileSize(*this, aType)) ||
78       failed(verifyTileSize(*this, bType)) ||
79       failed(verifyTileSize(*this, cType)) ||
80       failed(verifyMultShape(*this, aType, bType, cType, 1)))
81     return failure();
82   Type ta = aType.getElementType();
83   Type tb = bType.getElementType();
84   Type tc = cType.getElementType();
85   if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
86     return emitOpError("unsupported type combination");
87   return success();
88 }
89 
verify()90 LogicalResult amx::TileMulIOp::verify() {
91   VectorType aType = getLhsVectorType();
92   VectorType bType = getRhsVectorType();
93   VectorType cType = getVectorType();
94   if (failed(verifyTileSize(*this, aType)) ||
95       failed(verifyTileSize(*this, bType)) ||
96       failed(verifyTileSize(*this, cType)) ||
97       failed(verifyMultShape(*this, aType, bType, cType, 2)))
98     return failure();
99   Type ta = aType.getElementType();
100   Type tb = bType.getElementType();
101   Type tc = cType.getElementType();
102   if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
103     return emitOpError("unsupported type combination");
104   return success();
105 }
106 
107 #define GET_OP_CLASSES
108 #include "mlir/Dialect/AMX/AMX.cpp.inc"
109