//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tile config for each area of key AMX intrinsics.
/// All tile operands of a key AMX intrinsic must come from tile loads, and
/// the tile defined by a key AMX intrinsic must be tile stored.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key AMX area, like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"

// An AMX intrinsic either takes or produces an x86_amx tile.
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal;
}

static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
// The intrinsic defines a tile without taking any tile operands, e.g. a
// tile load.
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

static bool brokenVolatile(Instruction *I) {
  // TODO: Identifying a normal call this way is fragile.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif
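
// For reference, the 64-byte config memory written by this pass follows the
// ldtilecfg memory layout. This sketch only lists the fields the pass
// actually touches; every other byte stays zero-initialized:
//   byte  0      : palette
//   bytes 16..47 : tmm0..tmm15 bytes-per-row (colsb), one i16 per tile,
//                  at offset 16 + 2 * N
//   bytes 48..63 : tmm0..tmm15 rows, one i8 per tile, at offset 48 + N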
namespace {
class X86PreAMXConfig {
  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};

// Write the shapes into the tilecfg memory in order. This may not be the
// final assignment: the first shape need not correspond to the first tmm
// register, so X86FastTileConfig::materializeTileCfg() fixes the mapping up
// after register allocation.
// For example:
// --------------------------------------------------------------------------
// zero-initialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i8, i8* %mem, i64 16   * (as i16*)
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i8, i8* %mem, i64 18   *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i8, i8* %mem, i64 20   *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i8, i8* %mem, i64 22   *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: We currently default to Palette = 1; other palette values may be
  // supported in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    // The intrinsics carry the row as i16, but the config encodes it as i8.
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}
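
// Materialize the config memory for one key AMX area: allocate a <16 x i32>
// buffer in the entry block, zero it, let preWriteTileCfg fill in the
// palette and shapes, and load it with ldtilecfg at the area start. The
// emitted IR looks roughly like this (a sketch; value names illustrative):
//   %cfgmem = alloca <16 x i32>, align 4              ; in the entry block
//   ...
//   %cast = bitcast <16 x i32>* %cfgmem to i8*
//   store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem
//   ... palette and shape stores from preWriteTileCfg ...
//   call void @llvm.x86.ldtilecfg.internal(i8* %cast)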
bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  // Zero-initialize the config mem, then pre-write the palette and shapes,
  // all ahead of the ldtilecfg.
  Value *Val0 = Constant::getNullValue(V512Ty);
  new StoreInst(Val0, Addr, false, Alignment, Cfg);
  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return true;
}
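
// The "volatile model" expected at O0: every tile operand of the key AMX
// intrinsic is loaded just before it, and its def is stored right after it,
// so no tile value lives across areas. An accepted area looks like this
// sketch:
//   %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %td = call x86_amx @llvm.x86.tdpbssd.internal(..., %t1, %t2, %t3)
//   call void @llvm.x86.tilestored64.internal(..., %td)
// A load that is not consumed by the key intrinsic, or a store of anything
// other than the key intrinsic's def, fails the check.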
// TODO: We may need to handle the "more than one store" case in the future.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4); // The stored tile value.

  // There are only tileloads and a tilestore, no other key AMX intrinsic.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: can a key AMX intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
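
// Collect the (row, col) shapes of one key AMX area in operand order: each
// tile operand contributes the shape of its defining tileload, and a
// non-store key intrinsic additionally contributes the shape of its own def
// (its first two operands). E.g. for
//   %td = call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, %t1, %t2, %t3)
// with loads of shapes (m,k), (k,n) and (m,n), the collected sequence is
// (m,k), (k,n), (m,n), (m,n).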
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile operands should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return !Shapes.empty();
}

// Collect the shapes and skip the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <-- PosEnd,
//                                                         record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat the TileStore as the "config position end" and check the volatile
  // model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Should reach the tile store first!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "TileStore not found!");

  // If the area has only TileLoads and a TileStore, treat the TileStore as
  // KeyAMX.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get the shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}

// Record each key AMX area's shapes, keyed by its position. The first
// tileload of the area is used as the position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) <-- pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)      /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)  (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)        (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "AMX at O0 must follow the volatile model!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}

// Insert ldtilecfg and pre-config the shapes for each area of key AMX
// intrinsics, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace
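
// Legacy-PM wrapper around X86PreAMXConfig. It only rewrites IR when
// compiling at CodeGenOpt::None; the X86 target is expected to schedule it
// among its IR passes (e.g. via createX86PreAMXConfigPass()) ahead of fast
// register allocation.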
namespace {

class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but it
      // is very difficult to classify the tile shapes at O0. So we keep
      // things simple here and pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}