1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Lower matrix intrinsics to vector operations. 10 // 11 // TODO: 12 // * Improve fusion: 13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results 14 // transposed. 15 // * Improve cost-modeling, e.g. choose different number of rows/columns 16 // for tiles, consider cost of copies on alias. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21 #include "llvm/ADT/GraphTraits.h" 22 #include "llvm/ADT/PostOrderIterator.h" 23 #include "llvm/ADT/SmallVector.h" 24 #include "llvm/Analysis/AliasAnalysis.h" 25 #include "llvm/Analysis/DomTreeUpdater.h" 26 #include "llvm/Analysis/LoopInfo.h" 27 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 28 #include "llvm/Analysis/TargetTransformInfo.h" 29 #include "llvm/Analysis/ValueTracking.h" 30 #include "llvm/Analysis/VectorUtils.h" 31 #include "llvm/IR/CFG.h" 32 #include "llvm/IR/DataLayout.h" 33 #include "llvm/IR/DebugInfoMetadata.h" 34 #include "llvm/IR/Function.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/Instructions.h" 37 #include "llvm/IR/IntrinsicInst.h" 38 #include "llvm/IR/MatrixBuilder.h" 39 #include "llvm/IR/PatternMatch.h" 40 #include "llvm/InitializePasses.h" 41 #include "llvm/Pass.h" 42 #include "llvm/Support/Alignment.h" 43 #include "llvm/Support/CommandLine.h" 44 #include "llvm/Support/Debug.h" 45 #include "llvm/Transforms/Scalar.h" 46 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 47 #include "llvm/Transforms/Utils/LoopUtils.h" 48 #include "llvm/Transforms/Utils/MatrixUtils.h" 49 50 using namespace llvm; 51 using namespace PatternMatch; 52 53 #define DEBUG_TYPE "lower-matrix-intrinsics" 54 55 static cl::opt<bool> 56 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 57 cl::desc("Enable/disable fusing matrix instructions.")); 58 // TODO: Allow and use non-square tiles. 59 static cl::opt<unsigned> TileSize( 60 "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 61 cl::desc( 62 "Tile size for matrix instruction fusion using square-shaped tiles.")); 63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 64 cl::Hidden, 65 cl::desc("Generate loop nest for tiling.")); 66 static cl::opt<bool> ForceFusion( 67 "force-fuse-matrix", cl::init(false), cl::Hidden, 68 cl::desc("Force matrix instruction fusion even if not profitable.")); 69 static cl::opt<bool> AllowContractEnabled( 70 "matrix-allow-contract", cl::init(false), cl::Hidden, 71 cl::desc("Allow the use of FMAs if available and profitable. This may " 72 "produce different results due to less rounding error.")); 73 74 enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 75 76 static cl::opt<MatrixLayoutTy> MatrixLayout( 77 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 78 cl::desc("Sets the default matrix layout"), 79 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 80 "Use column-major layout"), 81 clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 82 "Use row-major layout"))); 83 84 /// Helper function to either return Scope, if it is a subprogram, or the 85 /// attached subprogram for a local scope.
86 static DISubprogram *getSubprogram(DIScope *Scope) { 87 if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) 88 return Subprogram; 89 return cast<DILocalScope>(Scope)->getSubprogram(); 90 } 91 92 namespace { 93 94 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute 95 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) 96 // assuming \p Stride elements between the starts of two consecutive vectors. 97 // \p Stride must be >= \p NumElements. 98 // For column-major matrixes, the function computes the address of a column 99 // vector and \p NumElements must be set to the number of elements in a column 100 // (= number of rows of the matrix). For row-major matrixes, the function 101 // computes the address of a row vector and \p NumElements must be set to the 102 // number of elements in a row (= number of columns of the matrix). 103 // 104 // Consider a 4x4 matrix in column-major layout like below 105 // 106 // 0 1 2 3 107 // 0 v_0_0 v_0_1 v_0_2 v_0_3 108 // 1 v_1_0 v_1_1 v_1_2 v_1_3 109 // 2 v_2_0 v_2_1 v_2_2 v_2_3 110 // 3 v_3_0 v_3_1 v_3_2 v_3_3 111 112 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, 113 // we need a pointer to the first element of the submatrix as base pointer. 114 // Then we can use computeVectorAddr to compute the addresses for the columns 115 // of the sub-matrix. 116 // 117 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) 118 // -> just returns Base 119 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) 120 // -> returns Base + (1 * 4) 121 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 122 // -> returns Base + (2 * 4) 123 // 124 // The graphic below illustrates the number of elements in a column (marked 125 // with |) and the number of skipped elements (marked with }). 126 // 127 // v_0_0 v_0_1 {v_0_2 {v_0_3 128 // Base Col 1 Col 2 129 // | | | 130 // v_1_0 |v_1_1 |v_1_2 |v_1_3 131 // v_2_0 |v_2_1 |v_2_2 |v_2_3 132 // v_3_0 {v_3_1 {v_3_2 v_3_3 133 // 134 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 135 unsigned NumElements, Type *EltType, 136 IRBuilder<> &Builder) { 137 138 assert((!isa<ConstantInt>(Stride) || 139 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 140 "Stride must be >= the number of elements in the result vector."); 141 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 142 143 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 144 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 145 146 // Get pointer to the start of the selected vector. Skip GEP creation, 147 // if we select vector 0. 148 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 149 VecStart = BasePtr; 150 else 151 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 152 153 // Cast elementwise vector start pointer to a pointer to a vector 154 // (EltType x NumElements)*. 155 auto *VecType = FixedVectorType::get(EltType, NumElements); 156 Type *VecPtrType = PointerType::get(VecType, AS); 157 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 158 } 159 160 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 161 /// 162 /// Currently, the lowering for each matrix intrinsic is done as follows: 163 /// 1. Propagate the shape information from intrinsics to connected 164 /// instructions. 165 /// 2.
Lower instructions with shape information (assuming column-major layout). 166 /// The lowering works similarly using row-major layout. 167 /// 2.1. Get column vectors for each argument. If we already lowered the 168 /// definition of an argument, use the produced column vectors directly. 169 /// If not, split the operand vector containing an embedded matrix into 170 /// a set of column vectors, 171 /// 2.2. Lower the instruction in terms of column major operations, which 172 /// yields a set of column vectors containing result matrix. Note that we 173 /// lower all instructions that have shape information. Besides the 174 /// intrinsics, this includes stores for example. 175 /// 2.3. Update uses of the lowered instruction. If we have shape information 176 /// for a user, there is nothing to do, as we will look up the result 177 /// column matrix when lowering the user. For other uses, we embed the 178 /// result matrix in a flat vector and update the use. 179 /// 2.4. Cache the result column matrix for the instruction we lowered 180 /// 3. After we lowered all instructions in a function, remove the now 181 /// obsolete instructions. 182 /// 183 class LowerMatrixIntrinsics { 184 Function &Func; 185 const DataLayout &DL; 186 const TargetTransformInfo &TTI; 187 AliasAnalysis *AA; 188 DominatorTree *DT; 189 LoopInfo *LI; 190 OptimizationRemarkEmitter *ORE; 191 192 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 193 struct OpInfoTy { 194 /// Number of stores emitted to generate this matrix. 195 unsigned NumStores = 0; 196 /// Number of loads emitted to generate this matrix. 197 unsigned NumLoads = 0; 198 /// Number of compute operations emitted to generate this matrix. 199 unsigned NumComputeOps = 0; 200 /// Most of the time transposes can be fused with matrix multiplies or can 201 /// be folded away via algebraic simplifications. This is the number of 202 /// transposes that we failed to make "free" via such optimizations. 203 unsigned NumExposedTransposes = 0; 204 205 OpInfoTy &operator+=(const OpInfoTy &RHS) { 206 NumStores += RHS.NumStores; 207 NumLoads += RHS.NumLoads; 208 NumComputeOps += RHS.NumComputeOps; 209 NumExposedTransposes += RHS.NumExposedTransposes; 210 return *this; 211 } 212 }; 213 214 /// Wrapper class representing a matrix as a set of vectors, either in row or 215 /// column major layout. All vectors must have the same vector type. 216 class MatrixTy { 217 SmallVector<Value *, 16> Vectors; 218 219 OpInfoTy OpInfo; 220 221 bool IsColumnMajor = true; 222 223 public: 224 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 225 MatrixTy(ArrayRef<Value *> Vectors) 226 : Vectors(Vectors.begin(), Vectors.end()), 227 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 228 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 229 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 230 231 unsigned D = isColumnMajor() ? NumColumns : NumRows; 232 for (unsigned J = 0; J < D; ++J) 233 addVector(UndefValue::get(FixedVectorType::get( 234 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 235 } 236 237 Value *getVector(unsigned i) const { return Vectors[i]; } 238 Value *getColumn(unsigned i) const { 239 assert(isColumnMajor() && "only supported for column-major matrixes"); 240 return Vectors[i]; 241 } 242 Value *getRow(unsigned i) const { 243 assert(!isColumnMajor() && "only supported for row-major matrixes"); 244 return Vectors[i]; 245 } 246 247 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 248 249 Type *getElementType() const { return getVectorTy()->getElementType(); } 250 251 unsigned getNumVectors() const { 252 if (isColumnMajor()) 253 return getNumColumns(); 254 return getNumRows(); 255 } 256 257 unsigned getNumColumns() const { 258 if (isColumnMajor()) 259 return Vectors.size(); 260 else { 261 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 262 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 263 } 264 } 265 unsigned getNumRows() const { 266 if (isColumnMajor()) { 267 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 268 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 269 } else 270 return Vectors.size(); 271 } 272 273 void addVector(Value *V) { Vectors.push_back(V); } 274 VectorType *getColumnTy() { 275 assert(isColumnMajor() && "only supported for column-major matrixes"); 276 return getVectorTy(); 277 } 278 279 VectorType *getVectorTy() const { 280 return cast<VectorType>(Vectors[0]->getType()); 281 } 282 283 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 284 assert(isColumnMajor() && 285 "columns() only supported for column-major matrixes"); 286 return make_range(Vectors.begin(), Vectors.end()); 287 } 288 289 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 290 return make_range(Vectors.begin(), Vectors.end()); 291 } 292 293 /// Embed the vectors of the matrix into a flat vector by concatenating 294 /// them. 295 Value *embedInVector(IRBuilder<> &Builder) const { 296 return Vectors.size() == 1 ? Vectors[0] 297 : concatenateVectors(Builder, Vectors); 298 } 299 300 MatrixTy &addNumLoads(unsigned N) { 301 OpInfo.NumLoads += N; 302 return *this; 303 } 304 305 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 306 307 MatrixTy &addNumStores(unsigned N) { 308 OpInfo.NumStores += N; 309 return *this; 310 } 311 312 MatrixTy &addNumExposedTransposes(unsigned N) { 313 OpInfo.NumExposedTransposes += N; 314 return *this; 315 } 316 317 MatrixTy &addNumComputeOps(unsigned N) { 318 OpInfo.NumComputeOps += N; 319 return *this; 320 } 321 322 unsigned getNumStores() const { return OpInfo.NumStores; } 323 unsigned getNumLoads() const { return OpInfo.NumLoads; } 324 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 325 326 const OpInfoTy &getOpInfo() const { return OpInfo; } 327 328 bool isColumnMajor() const { return IsColumnMajor; } 329 330 unsigned getStride() const { 331 if (isColumnMajor()) 332 return getNumRows(); 333 return getNumColumns(); 334 } 335 336 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 337 /// matrix is column-major, the result vector is extracted from a column 338 /// vector, otherwise from a row vector. 339 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 340 IRBuilder<> &Builder) const { 341 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 342 return Builder.CreateShuffleVector( 343 Vec, createSequentialMask(isColumnMajor() ? 
I : J, NumElts, 0), 344 "block"); 345 } 346 }; 347 348 struct ShapeInfo { 349 unsigned NumRows; 350 unsigned NumColumns; 351 352 bool IsColumnMajor; 353 354 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 355 : NumRows(NumRows), NumColumns(NumColumns), 356 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 357 358 ShapeInfo(Value *NumRows, Value *NumColumns) 359 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 360 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 361 362 bool operator==(const ShapeInfo &other) { 363 return NumRows == other.NumRows && NumColumns == other.NumColumns; 364 } 365 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 366 367 /// Returns true if shape-information is defined, meaning both dimensions 368 /// are != 0. 369 operator bool() const { 370 assert(NumRows == 0 || NumColumns != 0); 371 return NumRows != 0; 372 } 373 374 unsigned getStride() const { 375 if (IsColumnMajor) 376 return NumRows; 377 return NumColumns; 378 } 379 380 unsigned getNumVectors() const { 381 if (IsColumnMajor) 382 return NumColumns; 383 return NumRows; 384 } 385 }; 386 387 /// Maps instructions to their shape information. The shape information 388 /// describes the shape to be used while lowering. This matches the shape of 389 /// the result value of the instruction, with the only exceptions being store 390 /// instructions and the matrix_column_major_store intrinsics. For those, the 391 /// shape information indicates that those instructions should be lowered 392 /// using shape information as well. A ValueMap is used so that when 393 /// sub-passes like optimizeTransposes perform RAUW, the map stays 394 /// up-to-date. 395 ValueMap<Value *, ShapeInfo> ShapeMap; 396 397 /// List of instructions to remove. While lowering, we do not replace all 398 /// users of a lowered instruction if shape information is available; those 399 /// users need to be removed after we have finished lowering. 400 SmallVector<Instruction *, 16> ToRemove; 401 402 /// Map from instructions to their produced column matrix. 403 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 404 405 private: 406 static FastMathFlags getFastMathFlags(Instruction *Inst) { 407 FastMathFlags FMF; 408 409 if (isa<FPMathOperator>(*Inst)) 410 FMF = Inst->getFastMathFlags(); 411 412 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 413 414 return FMF; 415 } 416 417 public: 418 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 419 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 420 OptimizationRemarkEmitter *ORE) 421 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 422 LI(LI), ORE(ORE) {} 423 424 unsigned getNumOps(Type *VT) { 425 assert(isa<VectorType>(VT) && "Expected vector type"); 426 return getNumOps(VT->getScalarType(), 427 cast<FixedVectorType>(VT)->getNumElements()); 428 } 429 430 /// Is this the minimal version executed in the backend pipelines? 431 bool isMinimal() const { 432 return !DT; 433 } 434 435 /// Return the estimated number of vector ops required for an operation on 436 /// \p VT * N. 437 unsigned getNumOps(Type *ST, unsigned N) { 438 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 439 double(TTI.getRegisterBitWidth( 440 TargetTransformInfo::RGK_FixedWidthVector) 441 .getFixedSize())); 442 } 443 444 /// Return the set of vectors that a matrix value is lowered to. 445 /// 446 /// If we lowered \p MatrixVal, just return the cached result matrix.
Otherwise 447 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 448 /// into vectors. 449 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 450 IRBuilder<> &Builder) { 451 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 452 assert(VType && "MatrixVal must be a vector type"); 453 assert(cast<FixedVectorType>(VType)->getNumElements() == 454 SI.NumRows * SI.NumColumns && 455 "The vector size must match the number of matrix elements"); 456 457 // Check if we lowered MatrixVal using shape information. In that case, 458 // return the existing matrix, if it matches the requested shape 459 // information. If there is a mis-match, embed the result in a flat 460 // vector and split it later. 461 auto Found = Inst2ColumnMatrix.find(MatrixVal); 462 if (Found != Inst2ColumnMatrix.end()) { 463 MatrixTy &M = Found->second; 464 // Return the found matrix, if its shape matches the requested shape 465 // information 466 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 467 return M; 468 469 MatrixVal = M.embedInVector(Builder); 470 } 471 472 // Otherwise split MatrixVal. 473 SmallVector<Value *, 16> SplitVecs; 474 for (unsigned MaskStart = 0; 475 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 476 MaskStart += SI.getStride()) { 477 Value *V = Builder.CreateShuffleVector( 478 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 479 "split"); 480 SplitVecs.push_back(V); 481 } 482 483 return {SplitVecs}; 484 } 485 486 /// If \p V already has a known shape return false. Otherwise set the shape 487 /// for instructions that support it. 488 bool setShapeInfo(Value *V, ShapeInfo Shape) { 489 assert(Shape && "Shape not set"); 490 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 491 return false; 492 493 auto SIter = ShapeMap.find(V); 494 if (SIter != ShapeMap.end()) { 495 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 496 << SIter->second.NumRows << " " 497 << SIter->second.NumColumns << " for " << *V << "\n"); 498 return false; 499 } 500 501 ShapeMap.insert({V, Shape}); 502 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 503 << " for " << *V << "\n"); 504 return true; 505 } 506 507 bool isUniformShape(Value *V) { 508 Instruction *I = dyn_cast<Instruction>(V); 509 if (!I) 510 return true; 511 512 switch (I->getOpcode()) { 513 case Instruction::FAdd: 514 case Instruction::FSub: 515 case Instruction::FMul: // Scalar multiply. 516 case Instruction::FNeg: 517 case Instruction::Add: 518 case Instruction::Mul: 519 case Instruction::Sub: 520 return true; 521 default: 522 return false; 523 } 524 } 525 526 /// Returns true if shape information can be used for \p V. The supported 527 /// instructions must match the instructions that can be lowered by this pass. 528 bool supportsShapeInfo(Value *V) { 529 Instruction *Inst = dyn_cast<Instruction>(V); 530 if (!Inst) 531 return false; 532 533 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 534 if (II) 535 switch (II->getIntrinsicID()) { 536 case Intrinsic::matrix_multiply: 537 case Intrinsic::matrix_transpose: 538 case Intrinsic::matrix_column_major_load: 539 case Intrinsic::matrix_column_major_store: 540 return true; 541 default: 542 return false; 543 } 544 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 545 } 546 547 /// Propagate the shape information of instructions to their users. 
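  /// For example (illustrative sketch, simplified IR with overloaded type
  /// suffixes omitted): given
  ///   %c = call <4 x double> @llvm.matrix.multiply(<6 x double> %a,
  ///            <6 x double> %b, i32 2, i32 3, i32 2)
  ///   store <4 x double> %c, ptr %p
  /// the 2x2 shape of %c is known directly from the intrinsic and is then
  /// forward-propagated to the store, so the store is also lowered using
  /// shape information.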
548 /// The work list contains instructions for which we can compute the shape, 549 /// either based on the information provided by matrix intrinsics or known 550 /// shapes of operands. 551 SmallVector<Instruction *, 32> 552 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 553 SmallVector<Instruction *, 32> NewWorkList; 554 // Pop an element for which we guaranteed to have at least one of the 555 // operand shapes. Add the shape for this and then add users to the work 556 // list. 557 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 558 while (!WorkList.empty()) { 559 Instruction *Inst = WorkList.pop_back_val(); 560 561 // New entry, set the value and insert operands 562 bool Propagate = false; 563 564 Value *MatrixA; 565 Value *MatrixB; 566 Value *M; 567 Value *N; 568 Value *K; 569 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 570 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 571 m_Value(N), m_Value(K)))) { 572 Propagate = setShapeInfo(Inst, {M, K}); 573 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 574 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 575 // Flip dimensions. 576 Propagate = setShapeInfo(Inst, {N, M}); 577 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 578 m_Value(MatrixA), m_Value(), m_Value(), 579 m_Value(), m_Value(M), m_Value(N)))) { 580 Propagate = setShapeInfo(Inst, {N, M}); 581 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 582 m_Value(), m_Value(), m_Value(), m_Value(M), 583 m_Value(N)))) { 584 Propagate = setShapeInfo(Inst, {M, N}); 585 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 586 auto OpShape = ShapeMap.find(MatrixA); 587 if (OpShape != ShapeMap.end()) 588 setShapeInfo(Inst, OpShape->second); 589 continue; 590 } else if (isUniformShape(Inst)) { 591 // Find the first operand that has a known shape and use that. 592 for (auto &Op : Inst->operands()) { 593 auto OpShape = ShapeMap.find(Op.get()); 594 if (OpShape != ShapeMap.end()) { 595 Propagate |= setShapeInfo(Inst, OpShape->second); 596 break; 597 } 598 } 599 } 600 601 if (Propagate) { 602 NewWorkList.push_back(Inst); 603 for (auto *User : Inst->users()) 604 if (ShapeMap.count(User) == 0) 605 WorkList.push_back(cast<Instruction>(User)); 606 } 607 } 608 609 return NewWorkList; 610 } 611 612 /// Propagate the shape to operands of instructions with shape information. 613 /// \p Worklist contains the instruction for which we already know the shape. 614 SmallVector<Instruction *, 32> 615 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 616 SmallVector<Instruction *, 32> NewWorkList; 617 618 auto pushInstruction = [](Value *V, 619 SmallVectorImpl<Instruction *> &WorkList) { 620 Instruction *I = dyn_cast<Instruction>(V); 621 if (I) 622 WorkList.push_back(I); 623 }; 624 // Pop an element with known shape. Traverse the operands, if their shape 625 // derives from the result shape and is unknown, add it and add them to the 626 // worklist. 
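    // Illustrative sketch: if the 3x2 result shape of
    //   %t = call <6 x double> @llvm.matrix.transpose(<6 x double> %a,
    //            i32 2, i32 3)
    // is already known, the operand %a is assigned the flipped 2x3 shape and,
    // if its shape was not known yet, %a is pushed onto the worklist so its
    // own operands are visited as well.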
627 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 628 while (!WorkList.empty()) { 629 Value *V = WorkList.pop_back_val(); 630 631 size_t BeforeProcessingV = WorkList.size(); 632 if (!isa<Instruction>(V)) 633 continue; 634 635 Value *MatrixA; 636 Value *MatrixB; 637 Value *M; 638 Value *N; 639 Value *K; 640 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 641 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 642 m_Value(N), m_Value(K)))) { 643 if (setShapeInfo(MatrixA, {M, N})) 644 pushInstruction(MatrixA, WorkList); 645 646 if (setShapeInfo(MatrixB, {N, K})) 647 pushInstruction(MatrixB, WorkList); 648 649 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 650 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 651 // Flip dimensions. 652 if (setShapeInfo(MatrixA, {M, N})) 653 pushInstruction(MatrixA, WorkList); 654 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 655 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 656 m_Value(M), m_Value(N)))) { 657 if (setShapeInfo(MatrixA, {M, N})) { 658 pushInstruction(MatrixA, WorkList); 659 } 660 } else if (isa<LoadInst>(V) || 661 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 662 // Nothing to do, no matrix input. 663 } else if (isa<StoreInst>(V)) { 664 // Nothing to do. We forward-propagated to this so we would just 665 // backward propagate to an instruction with an already known shape. 666 } else if (isUniformShape(V)) { 667 // Propagate to all operands. 668 ShapeInfo Shape = ShapeMap[V]; 669 for (Use &U : cast<Instruction>(V)->operands()) { 670 if (setShapeInfo(U.get(), Shape)) 671 pushInstruction(U.get(), WorkList); 672 } 673 } 674 // After we discovered new shape info for new instructions in the 675 // worklist, we use their users as seeds for the next round of forward 676 // propagation. 677 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 678 for (User *U : WorkList[I]->users()) 679 if (isa<Instruction>(U) && V != U) 680 NewWorkList.push_back(cast<Instruction>(U)); 681 } 682 return NewWorkList; 683 } 684 685 /// Try moving transposes in order to fold them away or into multiplies. 686 void optimizeTransposes() { 687 auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { 688 // We need to remove Old from the ShapeMap, otherwise RAUW will replace it 689 // with New. We should only add New if supportsShapeInfo(New), so we insert 690 // it conditionally instead. 691 auto S = ShapeMap.find(&Old); 692 if (S != ShapeMap.end()) { 693 ShapeInfo Shape = S->second; ShapeMap.erase(S); 694 if (supportsShapeInfo(New)) 695 ShapeMap.insert({New, Shape}); 696 } 697 Old.replaceAllUsesWith(New); 698 }; 699 700 // First sink all transposes inside matmuls, hoping that we end up with NN, 701 // NT or TN variants. 702 for (BasicBlock &BB : reverse(Func)) { 703 for (auto II = BB.rbegin(); II != BB.rend();) { 704 Instruction &I = *II; 705 // We may remove II. By default continue on the next/prev instruction. 706 ++II; 707 // If we were to erase II, move again. 708 auto EraseFromParent = [&II](Value *V) { 709 auto *Inst = cast<Instruction>(V); 710 if (Inst->use_empty()) { 711 if (Inst == &*II) { 712 ++II; 713 } 714 Inst->eraseFromParent(); 715 } 716 }; 717 718 // If we're creating a new instruction, continue from there.
719 Instruction *NewInst = nullptr; 720 721 IRBuilder<> IB(&I); 722 MatrixBuilder Builder(IB); 723 724 Value *TA, *TAMA, *TAMB; 725 ConstantInt *R, *K, *C; 726 if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { 727 728 // Transpose of a transpose is a nop 729 Value *TATA; 730 if (match(TA, 731 m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 732 ReplaceAllUsesWith(I, TATA); 733 EraseFromParent(&I); 734 EraseFromParent(TA); 735 } 736 737 // (A * B)^t -> B^t * A^t 738 // RxK KxC CxK KxR 739 else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 740 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 741 m_ConstantInt(K), m_ConstantInt(C)))) { 742 Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), 743 C->getZExtValue(), 744 TAMB->getName() + "_t"); 745 // We are being run after shape prop, add shape for newly created 746 // instructions so that we lower them later. 747 setShapeInfo(T0, {C, K}); 748 Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), 749 K->getZExtValue(), 750 TAMA->getName() + "_t"); 751 setShapeInfo(T1, {K, R}); 752 NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), 753 K->getZExtValue(), 754 R->getZExtValue(), "mmul"); 755 ReplaceAllUsesWith(I, NewInst); 756 EraseFromParent(&I); 757 EraseFromParent(TA); 758 } 759 } 760 761 // If we replaced I with a new instruction, continue from there. 762 if (NewInst) 763 II = std::next(BasicBlock::reverse_iterator(NewInst)); 764 } 765 } 766 767 // If we have a TT matmul, lift the transpose. We may be able to fold into 768 // consuming multiply. 769 for (BasicBlock &BB : Func) { 770 for (BasicBlock::iterator II = BB.begin(); II != BB.end();) { 771 Instruction *I = &*II; 772 // We may remove I. 773 ++II; 774 Value *A, *B, *AT, *BT; 775 ConstantInt *R, *K, *C; 776 // A^t * B ^t -> (B * A)^t 777 if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>( 778 m_Value(A), m_Value(B), m_ConstantInt(R), 779 m_ConstantInt(K), m_ConstantInt(C))) && 780 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 781 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 782 IRBuilder<> IB(&*I); 783 MatrixBuilder Builder(IB); 784 Value *M = Builder.CreateMatrixMultiply( 785 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 786 setShapeInfo(M, {C, R}); 787 Instruction *NewInst = Builder.CreateMatrixTranspose( 788 M, C->getZExtValue(), R->getZExtValue()); 789 ReplaceAllUsesWith(*I, NewInst); 790 if (I->use_empty()) 791 I->eraseFromParent(); 792 if (A->use_empty()) 793 cast<Instruction>(A)->eraseFromParent(); 794 if (A != B && B->use_empty()) 795 cast<Instruction>(B)->eraseFromParent(); 796 } 797 } 798 } 799 } 800 801 bool Visit() { 802 SmallVector<Instruction *, 32> WorkList; 803 804 // Initially only the shape of matrix intrinsics is known. 805 // Initialize the work list with ops carrying shape information. 806 for (BasicBlock &BB : Func) 807 for (Instruction &Inst : BB) { 808 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 809 if (!II) 810 continue; 811 812 switch (II->getIntrinsicID()) { 813 case Intrinsic::matrix_multiply: 814 case Intrinsic::matrix_transpose: 815 case Intrinsic::matrix_column_major_load: 816 case Intrinsic::matrix_column_major_store: 817 WorkList.push_back(&Inst); 818 break; 819 default: 820 break; 821 } 822 } 823 824 // Avoid unnecessary work if there are no matrix intrinsics in the function. 825 if (WorkList.empty()) 826 return false; 827 828 // Propagate shapes until nothing changes any longer. 
829 while (!WorkList.empty()) { 830 WorkList = propagateShapeForward(WorkList); 831 WorkList = propagateShapeBackward(WorkList); 832 } 833 834 if (!isMinimal()) { 835 optimizeTransposes(); 836 LLVM_DEBUG({ 837 dbgs() << "Dump after matrix transpose optimization:\n"; 838 Func.dump(); 839 }); 840 } 841 842 bool Changed = false; 843 SmallVector<CallInst *, 16> MaybeFusableInsts; 844 SmallVector<Instruction *, 16> MatrixInsts; 845 846 // First, collect all instructions with shape information and candidates for 847 // fusion (currently only matrix multiplies). 848 ReversePostOrderTraversal<Function *> RPOT(&Func); 849 for (auto *BB : RPOT) 850 for (Instruction &I : *BB) { 851 if (ShapeMap.find(&I) == ShapeMap.end()) 852 continue; 853 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 854 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 855 MatrixInsts.push_back(&I); 856 } 857 858 // Second, try to fuse candidates. 859 SmallPtrSet<Instruction *, 16> FusedInsts; 860 for (CallInst *CI : MaybeFusableInsts) 861 LowerMatrixMultiplyFused(CI, FusedInsts); 862 Changed = !FusedInsts.empty(); 863 864 // Third, lower remaining instructions with shape information. 865 for (Instruction *Inst : MatrixInsts) { 866 if (FusedInsts.count(Inst)) 867 continue; 868 869 IRBuilder<> Builder(Inst); 870 871 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 872 Changed |= VisitCallInst(CInst); 873 874 Value *Op1; 875 Value *Op2; 876 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 877 Changed |= VisitBinaryOperator(BinOp); 878 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 879 Changed |= VisitUnaryOperator(UnOp); 880 if (match(Inst, m_Load(m_Value(Op1)))) 881 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 882 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 883 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 884 } 885 886 if (ORE) { 887 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 888 RemarkGen.emitRemarks(); 889 } 890 891 // Delete the instructions backwards, as it has a reduced likelihood of 892 // having to update as many def-use and use-def chains. 893 // 894 // Because we add to ToRemove during fusion we can't guarantee that defs 895 // are before uses. Change uses to undef temporarily as these should get 896 // removed as well. 897 // 898 // For verification, we keep track of where we changed uses to undefs in 899 // UndefedInsts and then check that we in fact remove them. 900 SmallSet<Instruction *, 16> UndefedInsts; 901 for (auto *Inst : reverse(ToRemove)) { 902 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 903 if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) 904 UndefedInsts.insert(Undefed); 905 U.set(UndefValue::get(Inst->getType())); 906 } 907 Inst->eraseFromParent(); 908 UndefedInsts.erase(Inst); 909 } 910 if (!UndefedInsts.empty()) { 911 // If we didn't remove all undefed instructions, it's a hard error. 912 dbgs() << "Undefed but present instructions:\n"; 913 for (auto *I : UndefedInsts) 914 dbgs() << *I << "\n"; 915 llvm_unreachable("Undefed but instruction not removed"); 916 } 917 918 return Changed; 919 } 920 921 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
922 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 923 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 924 Type *EltPtrType = PointerType::get(EltType, AS); 925 return Builder.CreatePointerCast(BasePtr, EltPtrType); 926 } 927 928 /// Replace intrinsic calls 929 bool VisitCallInst(CallInst *Inst) { 930 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 931 return false; 932 933 switch (Inst->getCalledFunction()->getIntrinsicID()) { 934 case Intrinsic::matrix_multiply: 935 LowerMultiply(Inst); 936 break; 937 case Intrinsic::matrix_transpose: 938 LowerTranspose(Inst); 939 break; 940 case Intrinsic::matrix_column_major_load: 941 LowerColumnMajorLoad(Inst); 942 break; 943 case Intrinsic::matrix_column_major_store: 944 LowerColumnMajorStore(Inst); 945 break; 946 default: 947 return false; 948 } 949 return true; 950 } 951 952 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 953 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 954 /// ConstantInt, reduce the initial alignment based on the byte offset. For 955 /// non-ConstantInt strides, return the common alignment of the initial 956 /// alignment and the element size in bytes. 957 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 958 MaybeAlign A) const { 959 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 960 if (Idx == 0) 961 return InitialAlign; 962 963 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 964 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 965 uint64_t StrideInBytes = 966 ConstStride->getZExtValue() * ElementSizeInBits / 8; 967 return commonAlignment(InitialAlign, Idx * StrideInBytes); 968 } 969 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 970 } 971 972 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 973 /// vectors. 974 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 975 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 976 auto *VType = cast<VectorType>(Ty); 977 Type *EltTy = VType->getElementType(); 978 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 979 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 980 MatrixTy Result; 981 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 982 Value *GEP = computeVectorAddr( 983 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 984 Stride, Shape.getStride(), EltTy, Builder); 985 Value *Vector = Builder.CreateAlignedLoad( 986 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 987 IsVolatile, "col.load"); 988 989 Result.addVector(Vector); 990 } 991 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 992 Result.getNumVectors()); 993 } 994 995 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 996 /// starting at \p MatrixPtr[I][J]. 
997 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 998 ShapeInfo MatrixShape, Value *I, Value *J, 999 ShapeInfo ResultShape, Type *EltTy, 1000 IRBuilder<> &Builder) { 1001 1002 Value *Offset = Builder.CreateAdd( 1003 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1004 1005 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1006 Value *EltPtr = 1007 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1008 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1009 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1010 ResultShape.NumColumns); 1011 Type *TilePtrTy = PointerType::get(TileTy, AS); 1012 Value *TilePtr = 1013 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1014 1015 return loadMatrix(TileTy, TilePtr, Align, 1016 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1017 ResultShape, Builder); 1018 } 1019 1020 /// Lower a load instruction with shape information. 1021 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1022 bool IsVolatile, ShapeInfo Shape) { 1023 IRBuilder<> Builder(Inst); 1024 finalizeLowering(Inst, 1025 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1026 Shape, Builder), 1027 Builder); 1028 } 1029 1030 /// Lowers llvm.matrix.column.major.load. 1031 /// 1032 /// The intrinsic loads a matrix from memory using a stride between columns. 1033 void LowerColumnMajorLoad(CallInst *Inst) { 1034 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1035 "Intrinsic only supports column-major layout!"); 1036 Value *Ptr = Inst->getArgOperand(0); 1037 Value *Stride = Inst->getArgOperand(1); 1038 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1039 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1040 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1041 } 1042 1043 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1044 /// MatrixPtr[I][J]. 1045 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1046 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1047 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1048 Value *Offset = Builder.CreateAdd( 1049 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1050 1051 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1052 Value *EltPtr = 1053 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1054 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1055 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1056 StoreVal.getNumColumns()); 1057 Type *TilePtrTy = PointerType::get(TileTy, AS); 1058 Value *TilePtr = 1059 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1060 1061 storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 1062 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1063 } 1064 1065 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1066 /// vectors. 
1067 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1068 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1069 IRBuilder<> &Builder) { 1070 auto VType = cast<VectorType>(Ty); 1071 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 1072 for (auto Vec : enumerate(StoreVal.vectors())) { 1073 Value *GEP = computeVectorAddr( 1074 EltPtr, 1075 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1076 Vec.index()), 1077 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1078 Builder.CreateAlignedStore(Vec.value(), GEP, 1079 getAlignForIndex(Vec.index(), Stride, 1080 VType->getElementType(), 1081 MAlign), 1082 IsVolatile); 1083 } 1084 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1085 StoreVal.getNumVectors()); 1086 } 1087 1088 /// Lower a store instruction with shape information. 1089 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1090 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1091 IRBuilder<> Builder(Inst); 1092 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1093 finalizeLowering(Inst, 1094 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1095 IsVolatile, Builder), 1096 Builder); 1097 } 1098 1099 /// Lowers llvm.matrix.column.major.store. 1100 /// 1101 /// The intrinsic stores a matrix back to memory using a stride between columns. 1102 void LowerColumnMajorStore(CallInst *Inst) { 1103 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1104 "Intrinsic only supports column-major layout!"); 1105 Value *Matrix = Inst->getArgOperand(0); 1106 Value *Ptr = Inst->getArgOperand(1); 1107 Value *Stride = Inst->getArgOperand(2); 1108 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1109 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1110 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1111 } 1112 1113 // Set elements I..I+NumElts-1 to Block 1114 Value *insertVector(Value *Col, unsigned I, Value *Block, 1115 IRBuilder<> &Builder) { 1116 1117 // First, bring Block to the same size as Col 1118 unsigned BlockNumElts = 1119 cast<FixedVectorType>(Block->getType())->getNumElements(); 1120 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1121 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1122 1123 Block = Builder.CreateShuffleVector( 1124 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1125 1126 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1127 // 8, 4, 5, 6 1128 SmallVector<int, 16> Mask; 1129 unsigned i; 1130 for (i = 0; i < I; i++) 1131 Mask.push_back(i); 1132 1133 unsigned VecNumElts = 1134 cast<FixedVectorType>(Col->getType())->getNumElements(); 1135 for (; i < I + BlockNumElts; i++) 1136 Mask.push_back(i - I + VecNumElts); 1137 1138 for (; i < VecNumElts; i++) 1139 Mask.push_back(i); 1140 1141 return Builder.CreateShuffleVector(Col, Block, Mask); 1142 } 1143 1144 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1145 IRBuilder<> &Builder, bool AllowContraction, 1146 unsigned &NumComputeOps) { 1147 NumComputeOps += getNumOps(A->getType()); 1148 if (!Sum) 1149 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1150 1151 if (UseFPOp) { 1152 if (AllowContraction) { 1153 // Use fmuladd for floating point operations and let the backend decide 1154 // if that's profitable.
1155 Function *FMulAdd = Intrinsic::getDeclaration( 1156 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1157 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1158 } 1159 NumComputeOps += getNumOps(A->getType()); 1160 Value *Mul = Builder.CreateFMul(A, B); 1161 return Builder.CreateFAdd(Sum, Mul); 1162 } 1163 1164 NumComputeOps += getNumOps(A->getType()); 1165 Value *Mul = Builder.CreateMul(A, B); 1166 return Builder.CreateAdd(Sum, Mul); 1167 } 1168 1169 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1170 /// users with shape information, there's nothing to do: they will use the 1171 /// cached value when they are lowered. For other users, \p Matrix is 1172 /// flattened and the uses are updated to use it. Also marks \p Inst for 1173 /// deletion. 1174 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1175 IRBuilder<> &Builder) { 1176 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1177 (void)inserted; 1178 assert(inserted.second && "multiple matrix lowering mapping"); 1179 1180 ToRemove.push_back(Inst); 1181 Value *Flattened = nullptr; 1182 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1183 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1184 if (!Flattened) 1185 Flattened = Matrix.embedInVector(Builder); 1186 U.set(Flattened); 1187 } 1188 } 1189 } 1190 1191 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1192 /// addition. 1193 /// 1194 /// We can fold a transpose into the operand that is used to extract scalars. 1195 /// This is the first operand with row-major layout and the second with 1196 /// column-major layout. If \p IsScalarMatrixTransposed we assume the 1197 /// appropriate operand is transposed. 1198 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1199 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1200 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1201 const unsigned VF = std::max<unsigned>( 1202 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1203 .getFixedSize() / 1204 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1205 1U); 1206 unsigned R = Result.getNumRows(); 1207 unsigned C = Result.getNumColumns(); 1208 unsigned M = A.getNumColumns(); 1209 1210 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1211 assert(A.isColumnMajor() == B.isColumnMajor() && 1212 Result.isColumnMajor() == A.isColumnMajor() && 1213 "operands must agree on matrix layout"); 1214 unsigned NumComputeOps = 0; 1215 1216 Builder.setFastMathFlags(FMF); 1217 1218 if (A.isColumnMajor()) { 1219 // Multiply columns from the first operand with scalars from the second 1220 // operand. Then move along the K axis and accumulate the columns. With 1221 // this the adds can be vectorized without reassociation. 1222 for (unsigned J = 0; J < C; ++J) { 1223 unsigned BlockSize = VF; 1224 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1225 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1226 1227 for (unsigned I = 0; I < R; I += BlockSize) { 1228 // Gradually lower the vectorization factor to cover the remainder. 1229 while (I + BlockSize > R) 1230 BlockSize /= 2; 1231 1232 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1233 : nullptr; 1234 for (unsigned K = 0; K < M; ++K) { 1235 Value *L = A.extractVector(I, K, BlockSize, Builder); 1236 Value *RH = Builder.CreateExtractElement( 1237 B.getColumn(IsScalarMatrixTransposed ? K : J), 1238 IsScalarMatrixTransposed ?
J : K); 1239 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1240 Sum = 1241 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1242 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1243 } 1244 Result.setVector(J, 1245 insertVector(Result.getVector(J), I, Sum, Builder)); 1246 } 1247 } 1248 } else { 1249 // Multiply rows from the second operand with scalars from the first 1250 // operand. Then move along the K axis and accumulate the rows. With this 1251 // the adds can be vectorized without reassociation. 1252 for (unsigned I = 0; I < R; ++I) { 1253 unsigned BlockSize = VF; 1254 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1255 for (unsigned J = 0; J < C; J += BlockSize) { 1256 // Gradually lower the vectorization factor to cover the remainder. 1257 while (J + BlockSize > C) 1258 BlockSize /= 2; 1259 1260 Value *Sum = nullptr; 1261 for (unsigned K = 0; K < M; ++K) { 1262 Value *R = B.extractVector(K, J, BlockSize, Builder); 1263 Value *LH = Builder.CreateExtractElement( 1264 A.getVector(IsScalarMatrixTransposed ? K : I), 1265 IsScalarMatrixTransposed ? I : K); 1266 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1267 Sum = 1268 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1269 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1270 } 1271 Result.setVector(I, 1272 insertVector(Result.getVector(I), J, Sum, Builder)); 1273 } 1274 } 1275 } 1276 Result.addNumComputeOps(NumComputeOps); 1277 } 1278 1279 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1280 /// copying it to a new location. The new location, or otherwise the original 1281 /// location, is returned. 1282 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1283 CallInst *MatMul) { 1284 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1285 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1286 1287 // If we can statically determine noalias we're good. 1288 if (AA->isNoAlias(LoadLoc, StoreLoc)) 1289 return Load->getPointerOperand(); 1290 1291 // Create code to check if the memory locations of the Load and Store 1292 // overlap and if they do, copy Load's operand to a new buffer. 1293 1294 // First, create new blocks for the second part of the check and the copy. 1295 BasicBlock *Check0 = MatMul->getParent(); 1296 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1297 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1298 // as we adjust Check0 and Check1's branches. 1299 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1300 for (BasicBlock *Succ : successors(Check0)) 1301 DTUpdates.push_back({DT->Delete, Check0, Succ}); 1302 1303 BasicBlock *Check1 = 1304 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1305 nullptr, "alias_cont"); 1306 BasicBlock *Copy = 1307 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1308 nullptr, "copy"); 1309 BasicBlock *Fusion = 1310 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1311 nullptr, "no_alias"); 1312 1313 // Check if the loaded memory location begins before the end of the store 1314 // location. If the condition holds, they might overlap, otherwise they are 1315 // guaranteed to not overlap.
1316 IRBuilder<> Builder(MatMul); 1317 Check0->getTerminator()->eraseFromParent(); 1318 Builder.SetInsertPoint(Check0); 1319 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1320 Value *StoreBegin = Builder.CreatePtrToInt( 1321 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1322 Value *StoreEnd = Builder.CreateAdd( 1323 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1324 "store.end", true, true); 1325 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1326 IntPtrTy, "load.begin"); 1327 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1328 Fusion); 1329 1330 // Check if the store begins before the end of the load location. If the 1331 // condition holds, they alias, otherwise they are guaranteed to not 1332 // overlap. 1333 Check1->getTerminator()->eraseFromParent(); 1334 Builder.SetInsertPoint(Check1, Check1->begin()); 1335 Value *LoadEnd = Builder.CreateAdd( 1336 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1337 "load.end", true, true); 1338 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1339 Fusion); 1340 1341 // Copy load operand to new alloca. 1342 Builder.SetInsertPoint(Copy, Copy->begin()); 1343 auto *VT = cast<FixedVectorType>(Load->getType()); 1344 // Use an array type for the alloca, to avoid potentially huge alignment 1345 // requirements for large vector types. 1346 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1347 AllocaInst *Alloca = 1348 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1349 Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo()); 1350 1351 Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(), 1352 Load->getAlign(), LoadLoc.Size.getValue()); 1353 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1354 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1355 PHI->addIncoming(Load->getPointerOperand(), Check0); 1356 PHI->addIncoming(Load->getPointerOperand(), Check1); 1357 PHI->addIncoming(BC, Copy); 1358 1359 // Adjust DT. 1360 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1361 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1362 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1363 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1364 DT->applyUpdates(DTUpdates); 1365 return PHI; 1366 } 1367 1368 bool isFusionProfitable(CallInst *MatMul) { 1369 if (ForceFusion) 1370 return true; 1371 1372 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1373 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1374 1375 const unsigned R = LShape.NumRows; 1376 const unsigned C = RShape.NumColumns; 1377 const unsigned M = LShape.NumColumns; 1378 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1379 1380 const unsigned VF = std::max<unsigned>( 1381 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1382 .getFixedSize() / 1383 EltType->getPrimitiveSizeInBits().getFixedSize(), 1384 1U); 1385 1386 // Cost model for tiling 1387 // 1388 // For tiling to be beneficial, we need reuse either along the R or 1389 // the C axis. We vectorize along the R axis so that means at least 1390 // 3 elements. 1391 // TODO: Also consider cost of copying if operands alias. 1392 if (R <= VF && C == 1) 1393 return false; 1394 // Then we need enough elements to exceed the number of vector 1395 // registers we have. 
Note that this is an oversimplification since 1396 // fusing also takes some extra loads which may exceed the number of 1397 // reloads necessary. 1398 unsigned Op0Regs = (R + VF - 1) / VF * M; 1399 unsigned Op1Regs = (M + VF - 1) / VF * C; 1400 return Op0Regs + Op1Regs > 1401 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1402 } 1403 1404 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1405 MatrixTy Res; 1406 auto *ColumType = FixedVectorType::get(EltType, R); 1407 for (unsigned I = 0; I < C; ++I) 1408 Res.addVector(ConstantAggregateZero::get(ColumType)); 1409 return Res; 1410 } 1411 1412 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1413 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1414 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1415 1416 // Create the main tiling loop nest. 1417 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1418 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1419 Instruction *InsertI = cast<Instruction>(MatMul); 1420 BasicBlock *Start = InsertI->getParent(); 1421 BasicBlock *End = 1422 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1423 IRBuilder<> Builder(MatMul); 1424 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1425 1426 Type *TileVecTy = 1427 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1428 MatrixTy TileResult; 1429 // Insert in the inner loop header. 1430 Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); 1431 // Create PHI nodes for the result columns to accumulate across iterations. 1432 SmallVector<PHINode *, 4> ColumnPhis; 1433 for (unsigned I = 0; I < TileSize; I++) { 1434 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1435 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1436 TI.RowLoopHeader->getSingleSuccessor()); 1437 TileResult.addVector(Phi); 1438 ColumnPhis.push_back(Phi); 1439 } 1440 1441 // Insert in the inner loop body, which computes 1442 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1443 Builder.SetInsertPoint(InnerBody->getTerminator()); 1444 // Load tiles of the operands. 1445 MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, 1446 {TileSize, TileSize}, EltType, Builder); 1447 MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, 1448 {TileSize, TileSize}, EltType, Builder); 1449 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1450 getFastMathFlags(MatMul)); 1451 // Store result after the inner loop is done. 1452 Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); 1453 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1454 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1455 TI.CurrentRow, TI.CurrentCol, EltType, Builder); 1456 1457 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1458 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); 1459 1460 // Force unrolling of a few iterations of the inner loop, to make sure there 1461 // is enough work per iteration. 1462 // FIXME: The unroller should make this decision directly instead, but 1463 // currently the cost-model is not up to the task. 
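    // Illustrative sketch: with the default TileSize of 4 and an inner
    // dimension of LShape.NumColumns = 64, the inner loop runs 16 times and
    // is unrolled by min(10, 16) = 10.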
1464 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1465 addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), 1466 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1467 } 1468 1469 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1470 StoreInst *Store, 1471 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1472 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1473 "Tiling only supported for column-major matrixes at the moment!"); 1474 if (!isFusionProfitable(MatMul)) 1475 return; 1476 1477 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1478 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1479 1480 const unsigned R = LShape.NumRows; 1481 const unsigned C = RShape.NumColumns; 1482 const unsigned M = LShape.NumColumns; 1483 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1484 1485 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1486 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1487 Value *CPtr = Store->getPointerOperand(); 1488 1489 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1490 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1491 else { 1492 IRBuilder<> Builder(Store); 1493 for (unsigned J = 0; J < C; J += TileSize) 1494 for (unsigned I = 0; I < R; I += TileSize) { 1495 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1496 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1497 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1498 1499 for (unsigned K = 0; K < M; K += TileSize) { 1500 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1501 MatrixTy A = 1502 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1503 LShape, Builder.getInt64(I), Builder.getInt64(K), 1504 {TileR, TileM}, EltType, Builder); 1505 MatrixTy B = 1506 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1507 RShape, Builder.getInt64(K), Builder.getInt64(J), 1508 {TileM, TileC}, EltType, Builder); 1509 emitMatrixMultiply(Res, A, B, Builder, true, false, 1510 getFastMathFlags(MatMul)); 1511 } 1512 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1513 Builder.getInt64(I), Builder.getInt64(J), EltType, 1514 Builder); 1515 } 1516 } 1517 1518 // Mark eliminated instructions as fused and remove them. 1519 FusedInsts.insert(Store); 1520 FusedInsts.insert(MatMul); 1521 Store->eraseFromParent(); 1522 MatMul->eraseFromParent(); 1523 if (LoadOp0->hasNUses(0)) { 1524 FusedInsts.insert(LoadOp0); 1525 LoadOp0->eraseFromParent(); 1526 } 1527 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 1528 FusedInsts.insert(LoadOp1); 1529 LoadOp1->eraseFromParent(); 1530 } 1531 } 1532 1533 /// Try to lower matrix multiply chains by fusing operations. 1534 /// 1535 /// Call finalizeLowering on lowered instructions. Instructions that are 1536 /// completely eliminated by fusion are added to \p FusedInsts. 1537 void LowerMatrixMultiplyFused(CallInst *MatMul, 1538 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1539 if (!FuseMatrix || !DT) 1540 return; 1541 1542 assert(AA && LI && "Analyses should be available"); 1543 1544 Value *A = MatMul->getArgOperand(0); 1545 Value *B = MatMul->getArgOperand(1); 1546 1547 // We can fold the transpose into the operand that is used to fetch scalars. 1548 Value *T; 1549 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1550 ? 
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
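  /// Both operands are fetched as matrices using the shape arguments of the
  /// intrinsic and multiplied via emitMatrixMultiply.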
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
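  /// The operation is applied to each vector of the operands' matrix
  /// representation; operands and result share the instruction's shape.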
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
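    /// Returns a pair of (exclusive counts, shared counts).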
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
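        // For each leaf, walk the inlined-at chain of its debug location to
        // find the location that belongs to the current subprogram and attach
        // the remark there.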
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << "<";
  if (Minimal)
    OS << "minimal";
  OS << ">";
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE, like tiling, are disabled.
/// This is used to lower matrix intrinsics if the main lowering pass is not
/// run, for example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}