//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tile config (ldtilecfg) for each area of key AMX intrinsics.
/// Every tile operand of a key AMX intrinsic must come from a tile load, and
/// the tile defined by a key AMX intrinsic must be tile stored.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tile config before every key-amx area, something like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
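///
/// For orientation only: IR in the "volatile" form above is what the AMX
/// C API in amxintrin.h typically produces at -O0. The snippet below is an
/// illustrative sketch (the function name and exact lowering are not part of
/// this pass), assuming the __tile1024i / __tile_loadd / __tile_dpbssd /
/// __tile_stored interface and an AMX-enabled -O0 build:
///
///   void dot_i8(unsigned short m, unsigned short n, unsigned short k,
///               const void *a, const void *b, void *c,
///               size_t lda, size_t ldb, size_t ldc) {
///     __tile1024i ta = {m, k}, tb = {k, n}, tc = {m, n};
///     __tile_loadd(&ta, a, lda);   // -> @llvm.x86.tileloadd64.internal
///     __tile_loadd(&tb, b, ldb);
///     __tile_loadd(&tc, c, ldc);
///     __tile_dpbssd(&tc, ta, tb);  // -> @llvm.x86.tdpbssd.internal
///     __tile_stored(c, ldc, tc);   // -> @llvm.x86.tilestored64.internal
///   }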
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"

static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}

static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
// Returns true if II defines a tile but takes no tile operands.
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

static bool brokenVolatile(Instruction *I) {
  // TODO: Identifying a normal call this way is fragile.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif

namespace {
class X86PreAMXConfig {
  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
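
// For reference, the 64-byte memory operand of ldtilecfg is laid out roughly
// as the struct below; preWriteTileCfg() writes the palette byte and the
// per-tile rows/colsb fields at the matching offsets. This is an illustrative
// sketch only and is not used by the code:
//
//   struct __tile_config {
//     uint8_t  palette;        // offset 0, set to 1 below
//     uint8_t  start_row;      // offset 1
//     uint8_t  reserved[14];   // offsets 2..15
//     uint16_t colsb[16];      // offsets 16..47, bytes per row of tmm0..tmm15
//     uint8_t  rows[16];       // offsets 48..63, rows of tmm0..tmm15
//   };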

// Write the shapes into tilecfg's mem in order. This ordering may not be
// final, because the first shape may not correspond to the first tmm
// register; that is fixed up in X86FastTileConfig::materializeTileCfg()
// after register allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                  *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48    *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16  *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1      *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2    * pre-config
// ...                                                          *
// ... pre-config shape of %t2                                  *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49    *
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18  *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1      * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2    *
// ...                                                          *
// ... pre-config shape of %t3                                  * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50    *
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20  *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1      *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2    *
// ...                                                          * tiles
// ... pre-config shape of %td                                  *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51    *
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22  *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1      *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2    *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                      * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: We currently default Palette to 1; it may be assigned another
  // value in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}

bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  Value *Val0 = Constant::getNullValue(V512Ty);
  Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
  assert(Init0 && "Failed to zero-initialize the cfg mem!");

  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return Init0;
}

// TODO: We may need to handle the "more than one store" case in the future.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // The area only has tileload and tilestore.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: can a key AMX intrinsic have no tile def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
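
// Collect the (row, col) shape pair of each tile operand of KeyAMX, in
// operand order, followed by the shape of KeyAMX's own tile def (unless
// KeyAMX is the tile store itself). For the tdpbssd example in the file
// header this yields {m, k, k, n, m, n, m, n} for %t1, %t2, %t3 and %td.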
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return !Shapes.empty();
}

// Collect the shapes and skip the area of the current key amx intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)   record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)   record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat TileStore as the "Config Position End" and check the volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Did not find the TileStore!");

  // Treat the TileStore as KeyAMX if the area has only TileLoad and TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}
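
// Note that a basic block may contain more than one key-amx area; each area
// gets its own config position and ldtilecfg. A sketch (illustrative only):
// --------------------------------------------------------------------------
// %a1 = call x86_amx @llvm.x86.tileloadd64.internal(...)   <- pos of area #1
// ...
// call void @llvm.x86.tilestored64.internal(... %ad)       <- end of area #1
// %b1 = call x86_amx @llvm.x86.tileloadd64.internal(...)   <- pos of area #2
// ...
// call void @llvm.x86.tilestored64.internal(... %bd)       <- end of area #2
// --------------------------------------------------------------------------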

// Record a key amx area's shapes with its position.
// Use the first tileload as its position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <-- pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)   shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)   (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)         (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}

// Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
// e.g. (key amx = tdpbssd)
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace

namespace {

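// Legacy pass-manager wrapper around X86PreAMXConfig. The IR rewrite is only
// applied when compiling at CodeGenOpt::None; see the comment in
// runOnFunction() below.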
class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So here we
      // keep things simple and pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}