1d4bdeca5SXiang1 Zhang //===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
2d4bdeca5SXiang1 Zhang //
3d4bdeca5SXiang1 Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d4bdeca5SXiang1 Zhang // See https://llvm.org/LICENSE.txt for license information.
5d4bdeca5SXiang1 Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d4bdeca5SXiang1 Zhang //
7d4bdeca5SXiang1 Zhang //===----------------------------------------------------------------------===//
8d4bdeca5SXiang1 Zhang //
9d4bdeca5SXiang1 Zhang /// Insert tilecfg for each area of key AMX intrinsic.
10d4bdeca5SXiang1 Zhang /// All the key AMX intrinsic's tile operand must come from tileload. And the
11d4bdeca5SXiang1 Zhang /// def tile of key AMX intrinsic must be tilestored.
12d4bdeca5SXiang1 Zhang /// take tdpbssd for example:
13d4bdeca5SXiang1 Zhang /// --------------------------------------------------------------------------
14d4bdeca5SXiang1 Zhang /// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
15d4bdeca5SXiang1 Zhang /// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
16d4bdeca5SXiang1 Zhang /// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
17d4bdeca5SXiang1 Zhang /// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
18d4bdeca5SXiang1 Zhang /// call void @llvm.x86.tilestored64.internal(... td)                     area
19d4bdeca5SXiang1 Zhang /// --------------------------------------------------------------------------
/// This pass will insert tilecfg before every key-amx-area, something like:
21d4bdeca5SXiang1 Zhang /// --------------------------------------------------------------------------
22d4bdeca5SXiang1 Zhang /// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
23d4bdeca5SXiang1 Zhang /// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
24d4bdeca5SXiang1 Zhang /// ...
25d4bdeca5SXiang1 Zhang /// ... pre-config shape of %t1                                 *
26d4bdeca5SXiang1 Zhang /// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
27d4bdeca5SXiang1 Zhang /// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
28d4bdeca5SXiang1 Zhang /// ...                                                         *
29d4bdeca5SXiang1 Zhang /// ... pre-config shape of %t2                                 * shapes
30d4bdeca5SXiang1 Zhang /// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
31d4bdeca5SXiang1 Zhang /// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
32d4bdeca5SXiang1 Zhang /// ...
33d4bdeca5SXiang1 Zhang /// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
34d4bdeca5SXiang1 Zhang //
35d4bdeca5SXiang1 Zhang //===----------------------------------------------------------------------===//
36d4bdeca5SXiang1 Zhang //
37d4bdeca5SXiang1 Zhang #include "X86.h"
38d4bdeca5SXiang1 Zhang #include "llvm/ADT/SmallSet.h"
39d4bdeca5SXiang1 Zhang #include "llvm/Analysis/TargetTransformInfo.h"
40d4bdeca5SXiang1 Zhang #include "llvm/CodeGen/Passes.h"
41d4bdeca5SXiang1 Zhang #include "llvm/CodeGen/TargetPassConfig.h"
42d4bdeca5SXiang1 Zhang #include "llvm/CodeGen/ValueTypes.h"
43d4bdeca5SXiang1 Zhang #include "llvm/IR/DataLayout.h"
44d4bdeca5SXiang1 Zhang #include "llvm/IR/Function.h"
45d4bdeca5SXiang1 Zhang #include "llvm/IR/IRBuilder.h"
46d4bdeca5SXiang1 Zhang #include "llvm/IR/Instructions.h"
47d4bdeca5SXiang1 Zhang #include "llvm/IR/IntrinsicInst.h"
48d4bdeca5SXiang1 Zhang #include "llvm/IR/IntrinsicsX86.h"
49d4bdeca5SXiang1 Zhang #include "llvm/IR/PatternMatch.h"
50d4bdeca5SXiang1 Zhang #include "llvm/InitializePasses.h"
51d4bdeca5SXiang1 Zhang #include "llvm/Pass.h"
52d4bdeca5SXiang1 Zhang #include "llvm/Support/raw_ostream.h"
53d4bdeca5SXiang1 Zhang #include "llvm/Target/TargetMachine.h"
54d4bdeca5SXiang1 Zhang 
55d4bdeca5SXiang1 Zhang using namespace llvm;
56d4bdeca5SXiang1 Zhang using namespace PatternMatch;
57d4bdeca5SXiang1 Zhang 
58d4bdeca5SXiang1 Zhang #define DEBUG_TYPE "pre-amx-config"
59d4bdeca5SXiang1 Zhang 
isAMXIntrinsic(IntrinsicInst * II)60d4bdeca5SXiang1 Zhang static bool isAMXIntrinsic(IntrinsicInst *II) {
61d4bdeca5SXiang1 Zhang   for (Value *Operand : II->operands())
62d4bdeca5SXiang1 Zhang     if (Operand->getType()->isX86_AMXTy())
63d4bdeca5SXiang1 Zhang       return true;
64d4bdeca5SXiang1 Zhang   return II->getType()->isX86_AMXTy();
65d4bdeca5SXiang1 Zhang }
66d4bdeca5SXiang1 Zhang 
isTileLoad(IntrinsicInst * II)67d4bdeca5SXiang1 Zhang static bool isTileLoad(IntrinsicInst *II) {
6856d5c46bSBing1 Yu   return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
6956d5c46bSBing1 Yu          II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
70d4bdeca5SXiang1 Zhang }
71d4bdeca5SXiang1 Zhang 
isTileStore(IntrinsicInst * II)72d4bdeca5SXiang1 Zhang static bool isTileStore(IntrinsicInst *II) {
73d4bdeca5SXiang1 Zhang   return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
74d4bdeca5SXiang1 Zhang }
75d4bdeca5SXiang1 Zhang 
76d4bdeca5SXiang1 Zhang #ifndef NDEBUG
onlyTileDef(IntrinsicInst * II)77d4bdeca5SXiang1 Zhang static bool onlyTileDef(IntrinsicInst *II) {
78d4bdeca5SXiang1 Zhang   for (Value *Operand : II->operands())
79d4bdeca5SXiang1 Zhang     if (Operand->getType()->isX86_AMXTy())
80d4bdeca5SXiang1 Zhang       return false;
81d4bdeca5SXiang1 Zhang   return II->getType()->isX86_AMXTy();
82d4bdeca5SXiang1 Zhang }
83d4bdeca5SXiang1 Zhang 
brokenVolatile(Instruction * I)84d4bdeca5SXiang1 Zhang static bool brokenVolatile(Instruction *I) {
85d4bdeca5SXiang1 Zhang   // Todo: it is weak to identify a normal call here.
86d4bdeca5SXiang1 Zhang   if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
87d4bdeca5SXiang1 Zhang     return true;
88d4bdeca5SXiang1 Zhang   return false;
89d4bdeca5SXiang1 Zhang }
90d4bdeca5SXiang1 Zhang #endif
91d4bdeca5SXiang1 Zhang 
92d4bdeca5SXiang1 Zhang namespace {
class X86PreAMXConfig {
  // Maps the position (first instruction) of each key AMX area to the shape
  // values (row/col pairs) of all tiles used in that area.
  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;

  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  // Entry point: insert ldtilecfg plus shape pre-writes in front of every
  // key AMX area in F. Returns true if the IR was changed.
  bool preTileConfig();
  // Emit the config-memory alloca, zero-init, shape stores and the
  // ldtilecfg call in front of ModelStart.
  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  // Collect the position and tile shapes of every key AMX area in F.
  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
  // Gather the (row, col) shape operands of KeyAMX's tile operands (and of
  // KeyAMX itself when it is not a tile store).
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  // Write the palette byte and per-tile shapes into config memory at I8Ptr.
  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                       SmallVector<Value *, 8> &Shapes);
  // Scan from Iter to the area-terminating tile store, collecting shapes;
  // returns the iterator of that store (the "config position end").
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  // Verify the load/compute/store "volatile model" of one area.
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
112d4bdeca5SXiang1 Zhang 
113d4bdeca5SXiang1 Zhang // Orderly write the shapes in tilecfg's mem. This maybe not right.
114d4bdeca5SXiang1 Zhang // Because the first shape may not corresponding to the first tmm register,
// so we need to handle that at X86FastTileConfig::materializeTileCfg()
116d4bdeca5SXiang1 Zhang // after register allocation.
117d4bdeca5SXiang1 Zhang // For example:
118d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
119d4bdeca5SXiang1 Zhang // zeroinitialize tilecfg's mem (of ldtilecfg)
120d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
121d4bdeca5SXiang1 Zhang // ... pre-config shape of %t1                                 *
122d4bdeca5SXiang1 Zhang // %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
123d4bdeca5SXiang1 Zhang // %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
124d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
125d4bdeca5SXiang1 Zhang // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
126d4bdeca5SXiang1 Zhang // ...                                                         *
127d4bdeca5SXiang1 Zhang // ... pre-config shape of %t2                                 *
128d4bdeca5SXiang1 Zhang // %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
129d4bdeca5SXiang1 Zhang // %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
130d4bdeca5SXiang1 Zhang // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
131d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
132d4bdeca5SXiang1 Zhang // ...                                                         *
133d4bdeca5SXiang1 Zhang // ... pre-config shape of %t3                                 * of
134d4bdeca5SXiang1 Zhang // %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
135d4bdeca5SXiang1 Zhang // %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
136d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
137d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
138d4bdeca5SXiang1 Zhang // ...                                                         * tiles
139d4bdeca5SXiang1 Zhang // ... pre-config shape of %td                                 *
140d4bdeca5SXiang1 Zhang // %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
141d4bdeca5SXiang1 Zhang // %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
142d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
143d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
144d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
145d4bdeca5SXiang1 Zhang // call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
146d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
147d4bdeca5SXiang1 Zhang // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
148d4bdeca5SXiang1 Zhang // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
149d4bdeca5SXiang1 Zhang // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
150d4bdeca5SXiang1 Zhang // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
151d4bdeca5SXiang1 Zhang // call void @llvm.x86.tilestored64.internal(... td)                     area
152d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
preWriteTileCfg(Value * I8Ptr,IRBuilderBase & Builder,SmallVector<Value *,8> & Shapes)153*10615110SNikita Popov void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
154d4bdeca5SXiang1 Zhang                                       SmallVector<Value *, 8> &Shapes) {
155*10615110SNikita Popov   LLVMContext &Ctx = Builder.getContext();
156d4bdeca5SXiang1 Zhang   Type *I8Ty = Type::getInt8Ty(Ctx);
157d4bdeca5SXiang1 Zhang   Type *I16Ty = Type::getInt16Ty(Ctx);
158d4bdeca5SXiang1 Zhang 
159d4bdeca5SXiang1 Zhang   // TODO: Currently we defaultly set Palette = 1, it may be assigned to
160d4bdeca5SXiang1 Zhang   // other value in the future.
161d4bdeca5SXiang1 Zhang   Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
162d4bdeca5SXiang1 Zhang   Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
163*10615110SNikita Popov   Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
164*10615110SNikita Popov   Builder.CreateStore(PaletteValue, PalettePos);
165d4bdeca5SXiang1 Zhang 
166d4bdeca5SXiang1 Zhang   for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
167d4bdeca5SXiang1 Zhang     Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
168d4bdeca5SXiang1 Zhang     Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
169d4bdeca5SXiang1 Zhang     const std::string ShapeName = "amx.tmm." + itostr(I);
170*10615110SNikita Popov     Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
171*10615110SNikita Popov                                       ShapeName + ".shape.row");
172*10615110SNikita Popov     Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
173*10615110SNikita Popov     ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
174*10615110SNikita Popov                                    ShapeName + ".shape.col");
175d4bdeca5SXiang1 Zhang     Value *Row = Shapes[I * 2];
176d4bdeca5SXiang1 Zhang     Value *Col = Shapes[I * 2 + 1];
177*10615110SNikita Popov     Row = Builder.CreateTrunc(Row, I8Ty);
178*10615110SNikita Popov     Builder.CreateStore(Row, RowPos);
179*10615110SNikita Popov     Builder.CreateStore(Col, ColPos);
180d4bdeca5SXiang1 Zhang   }
181d4bdeca5SXiang1 Zhang }
182d4bdeca5SXiang1 Zhang 
addTileConfig(Instruction * ModelStart,SmallVector<Value *,8> & Shapes)183*10615110SNikita Popov void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
184d4bdeca5SXiang1 Zhang                                     SmallVector<Value *, 8> &Shapes) {
185d4bdeca5SXiang1 Zhang   Module *M = F.getParent();
186d4bdeca5SXiang1 Zhang   IRBuilder<> Builder(ModelStart);
187d4bdeca5SXiang1 Zhang   const DataLayout &DL = M->getDataLayout();
188d4bdeca5SXiang1 Zhang   unsigned AddrSpace = DL.getAllocaAddrSpace();
189d4bdeca5SXiang1 Zhang   LLVMContext &Ctx = Builder.getContext();
190d4bdeca5SXiang1 Zhang   Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
191d4bdeca5SXiang1 Zhang   Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
192d4bdeca5SXiang1 Zhang 
193d4bdeca5SXiang1 Zhang   AllocaInst *Addr =
194d4bdeca5SXiang1 Zhang       new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
195d4bdeca5SXiang1 Zhang   Addr->setAlignment(Alignment);
196d4bdeca5SXiang1 Zhang   Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
197d4bdeca5SXiang1 Zhang 
198*10615110SNikita Popov   Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);
199d4bdeca5SXiang1 Zhang 
200*10615110SNikita Popov   preWriteTileCfg(I8Ptr, Builder, Shapes);
201d4bdeca5SXiang1 Zhang 
202*10615110SNikita Popov   Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, {I8Ptr});
203d4bdeca5SXiang1 Zhang }
204d4bdeca5SXiang1 Zhang 
205d4bdeca5SXiang1 Zhang // Todo: We may need to handle "more than one store" case in the future.
checkVolatileModel(SmallSet<Value *,4> & Loads,IntrinsicInst * Store,IntrinsicInst * KeyAMX)206d4bdeca5SXiang1 Zhang bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
207d4bdeca5SXiang1 Zhang                                          IntrinsicInst *Store,
208d4bdeca5SXiang1 Zhang                                          IntrinsicInst *KeyAMX) {
209d4bdeca5SXiang1 Zhang   Value *ST = Store->getOperand(4);
210d4bdeca5SXiang1 Zhang 
211d4bdeca5SXiang1 Zhang   // Only has tileload and tilestore.
212d4bdeca5SXiang1 Zhang   if (!KeyAMX)
213d4bdeca5SXiang1 Zhang     return (Loads.size() == 1) && Loads.contains(ST);
214d4bdeca5SXiang1 Zhang 
215d4bdeca5SXiang1 Zhang   // All Loads should be operands of KeyAMX.
216d4bdeca5SXiang1 Zhang   // All tile operands of KeyAMX should come from Loads.
217d4bdeca5SXiang1 Zhang   for (Value *Op : KeyAMX->operands()) {
218d4bdeca5SXiang1 Zhang     if (Op->getType()->isX86_AMXTy())
219d4bdeca5SXiang1 Zhang       if (!Loads.erase(Op))
220d4bdeca5SXiang1 Zhang         return false;
221d4bdeca5SXiang1 Zhang   }
222d4bdeca5SXiang1 Zhang 
223d4bdeca5SXiang1 Zhang   // The def of KeyAMX should be stored into mem.
224d4bdeca5SXiang1 Zhang   // Todo: is it key amx can be no def?
225d4bdeca5SXiang1 Zhang   return Loads.empty() && (ST == cast<Value>(KeyAMX));
226d4bdeca5SXiang1 Zhang }
227d4bdeca5SXiang1 Zhang 
getKeyAMXShapes(IntrinsicInst * KeyAMX,SmallVector<Value *,8> & Shapes)228d4bdeca5SXiang1 Zhang bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
229d4bdeca5SXiang1 Zhang                                       SmallVector<Value *, 8> &Shapes) {
230d4bdeca5SXiang1 Zhang   for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
231d4bdeca5SXiang1 Zhang     Value *Op = KeyAMX->getOperand(I);
232d4bdeca5SXiang1 Zhang     if (!Op->getType()->isX86_AMXTy())
233d4bdeca5SXiang1 Zhang       continue;
234d4bdeca5SXiang1 Zhang     IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
235d4bdeca5SXiang1 Zhang     assert((TileDef && isTileLoad(TileDef)) &&
236d4bdeca5SXiang1 Zhang            "All KeyAMX's tile definiation should comes from TileLoad!");
237d4bdeca5SXiang1 Zhang     Shapes.push_back(TileDef->getOperand(0));
238d4bdeca5SXiang1 Zhang     Shapes.push_back(TileDef->getOperand(1));
239d4bdeca5SXiang1 Zhang   }
240d4bdeca5SXiang1 Zhang   if (!isTileStore(KeyAMX)) {
241d4bdeca5SXiang1 Zhang     Shapes.push_back(KeyAMX->getOperand(0));
242d4bdeca5SXiang1 Zhang     Shapes.push_back(KeyAMX->getOperand(1));
243d4bdeca5SXiang1 Zhang   }
244d4bdeca5SXiang1 Zhang   return Shapes.size() != 0;
245d4bdeca5SXiang1 Zhang }
246d4bdeca5SXiang1 Zhang 
247d4bdeca5SXiang1 Zhang // Collect the shapes and skip the area of current key amx intrinsic.
248d4bdeca5SXiang1 Zhang //
249d4bdeca5SXiang1 Zhang // For example:
250d4bdeca5SXiang1 Zhang // ...
251d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
252d4bdeca5SXiang1 Zhang // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
253d4bdeca5SXiang1 Zhang // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (m,k)
254d4bdeca5SXiang1 Zhang // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,k)
255d4bdeca5SXiang1 Zhang // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
256d4bdeca5SXiang1 Zhang // call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
257d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  // The single non-load/store AMX intrinsic of this area (e.g. tdpbssd),
  // if there is one.
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  // Tiles defined by tile loads inside this area.
  SmallSet<Value *, 4> Loads;

  // See TileStore as "Config Position End" and check volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Not reach tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      // The tile store terminates the area; verify the load/compute/store
      // form before accepting it.
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      // At most one key intrinsic is expected per area.
      assert(!KeyAMX && "Too many key amx intrinsic!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Not find TileStore!");

  // See KeyAMX as TileStore if only TileLoad and TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  // Return the iterator of the terminating tile store; the caller resumes
  // its scan from here.
  return PosEnd;
}
297d4bdeca5SXiang1 Zhang 
298d4bdeca5SXiang1 Zhang // Record a key amx area's shapes with its position.
299d4bdeca5SXiang1 Zhang // Use the first tileload as its position.
300d4bdeca5SXiang1 Zhang // For example:
301d4bdeca5SXiang1 Zhang // ...
302d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
303d4bdeca5SXiang1 Zhang // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
304d4bdeca5SXiang1 Zhang // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
305d4bdeca5SXiang1 Zhang // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
306d4bdeca5SXiang1 Zhang // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
307d4bdeca5SXiang1 Zhang // call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
308d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      // The first AMX intrinsic of an area must be a pure tile def (a tile
      // load) under the O0 volatile model.
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      // Record the area's shapes keyed by its first instruction, and move I
      // to the area's tile store so the scan continues after the area.
      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  // True if at least one key AMX area was found.
  return Find;
}
326d4bdeca5SXiang1 Zhang 
327d4bdeca5SXiang1 Zhang // Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
328d4bdeca5SXiang1 Zhang // e.g. (key amx = tdpbssd)
329d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
330d4bdeca5SXiang1 Zhang // %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
331d4bdeca5SXiang1 Zhang // store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
332d4bdeca5SXiang1 Zhang // ...
333d4bdeca5SXiang1 Zhang // ... pre-config shape of %t1                                 *
334d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
335d4bdeca5SXiang1 Zhang // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
336d4bdeca5SXiang1 Zhang // ...                                                         *
337d4bdeca5SXiang1 Zhang // ... pre-config shape of %t2                                 *
338d4bdeca5SXiang1 Zhang // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
339d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
340d4bdeca5SXiang1 Zhang // ...                                                         *
341d4bdeca5SXiang1 Zhang // ... pre-config shape of %t3                                 * of
342d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
343d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
344d4bdeca5SXiang1 Zhang // ...                                                         * tiles
345d4bdeca5SXiang1 Zhang // ... pre-config shape of %td                                 *
346d4bdeca5SXiang1 Zhang // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
347d4bdeca5SXiang1 Zhang // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
348d4bdeca5SXiang1 Zhang //
349d4bdeca5SXiang1 Zhang // call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
350d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
351d4bdeca5SXiang1 Zhang // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
352d4bdeca5SXiang1 Zhang // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
353d4bdeca5SXiang1 Zhang // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
354d4bdeca5SXiang1 Zhang // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
355d4bdeca5SXiang1 Zhang // call void @llvm.x86.tilestored64.internal(... td)                     area
356d4bdeca5SXiang1 Zhang // --------------------------------------------------------------------------
preTileConfig()357d4bdeca5SXiang1 Zhang bool X86PreAMXConfig::preTileConfig() {
358fbb72530SNikita Popov   PosAndShapesMap PosAndShapes;
359d4bdeca5SXiang1 Zhang   bool NeedCfg = findConfigShapes(PosAndShapes);
360d4bdeca5SXiang1 Zhang   if (!NeedCfg)
361d4bdeca5SXiang1 Zhang     return false;
362d4bdeca5SXiang1 Zhang   for (auto &IPAndShapes : PosAndShapes)
363d4bdeca5SXiang1 Zhang     addTileConfig(IPAndShapes.first, IPAndShapes.second);
364d4bdeca5SXiang1 Zhang 
365d4bdeca5SXiang1 Zhang   return true;
366d4bdeca5SXiang1 Zhang }
367d4bdeca5SXiang1 Zhang } // anonymous namespace
368d4bdeca5SXiang1 Zhang 
369d4bdeca5SXiang1 Zhang namespace {
370d4bdeca5SXiang1 Zhang 
371d4bdeca5SXiang1 Zhang class X86PreAMXConfigPass : public FunctionPass {
372d4bdeca5SXiang1 Zhang public:
373d4bdeca5SXiang1 Zhang   static char ID;
374d4bdeca5SXiang1 Zhang 
X86PreAMXConfigPass()375d4bdeca5SXiang1 Zhang   X86PreAMXConfigPass() : FunctionPass(ID) {
376d4bdeca5SXiang1 Zhang     initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
377d4bdeca5SXiang1 Zhang   }
378d4bdeca5SXiang1 Zhang 
runOnFunction(Function & F)379d4bdeca5SXiang1 Zhang   bool runOnFunction(Function &F) override {
380d4bdeca5SXiang1 Zhang     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
381d4bdeca5SXiang1 Zhang     bool C = false;
382d4bdeca5SXiang1 Zhang 
383d4bdeca5SXiang1 Zhang     // Prepare for fast register allocation at O0.
384d4bdeca5SXiang1 Zhang     if (TM->getOptLevel() == CodeGenOpt::None) {
385d4bdeca5SXiang1 Zhang 
386d4bdeca5SXiang1 Zhang       // We pre-config each key AMX intrinsic at O0.
387d4bdeca5SXiang1 Zhang       // In theory, one tile config can cover several AMX intrinsics, but
388d4bdeca5SXiang1 Zhang       // it is very diffcult to classify the tile shapes at O0. So here we
389d4bdeca5SXiang1 Zhang       // let thing be easy, pre-config every key AMX intrinsic.
390d4bdeca5SXiang1 Zhang       X86PreAMXConfig PCFG(F);
391d4bdeca5SXiang1 Zhang       C = PCFG.preTileConfig();
392d4bdeca5SXiang1 Zhang     }
393d4bdeca5SXiang1 Zhang 
394d4bdeca5SXiang1 Zhang     return C;
395d4bdeca5SXiang1 Zhang   }
396d4bdeca5SXiang1 Zhang 
getAnalysisUsage(AnalysisUsage & AU) const397d4bdeca5SXiang1 Zhang   void getAnalysisUsage(AnalysisUsage &AU) const override {
398d4bdeca5SXiang1 Zhang     AU.setPreservesCFG();
399d4bdeca5SXiang1 Zhang     AU.addRequired<TargetPassConfig>();
400d4bdeca5SXiang1 Zhang   }
401d4bdeca5SXiang1 Zhang };
402d4bdeca5SXiang1 Zhang 
403d4bdeca5SXiang1 Zhang } // anonymous namespace
404d4bdeca5SXiang1 Zhang 
static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
// Register the pass; it depends on TargetPassConfig to query the opt level.
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

// Factory used by the X86 target to add this pass to the codegen pipeline.
FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}
414