//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose a different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
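// Note (illustrative, not part of the original comments): with the default
// column-major layout, a 2x3 matrix
//
//   a b c
//   d e f
//
// is embedded in a flat <6 x Ty> vector as {a, d, b, e, c, f}, i.e. the
// columns are laid out one after the other. With
// -matrix-default-layout=row-major the same matrix is embedded as
// {a, b, c, d, e, f}.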
/// Helper function to either return \p Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2.
Lower instructions with shape information (assuming column-major layout). 165 /// The lowering works similarly using row-major layout. 166 /// 2.1. Get column vectors for each argument. If we already lowered the 167 /// definition of an argument, use the produced column vectors directly. 168 /// If not, split the operand vector containing an embedded matrix into 169 /// a set of column vectors, 170 /// 2.2. Lower the instruction in terms of column major operations, which 171 /// yields a set of column vectors containing result matrix. Note that we 172 /// lower all instructions that have shape information. Besides the 173 /// intrinsics, this includes stores for example. 174 /// 2.3. Update uses of the lowered instruction. If we have shape information 175 /// for a user, there is nothing to do, as we will look up the result 176 /// column matrix when lowering the user. For other uses, we embed the 177 /// result matrix in a flat vector and update the use. 178 /// 2.4. Cache the result column matrix for the instruction we lowered 179 /// 3. After we lowered all instructions in a function, remove the now 180 /// obsolete instructions. 181 /// 182 class LowerMatrixIntrinsics { 183 Function &Func; 184 const DataLayout &DL; 185 const TargetTransformInfo &TTI; 186 AliasAnalysis *AA; 187 DominatorTree *DT; 188 LoopInfo *LI; 189 OptimizationRemarkEmitter *ORE; 190 191 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 192 struct OpInfoTy { 193 /// Number of stores emitted to generate this matrix. 194 unsigned NumStores = 0; 195 /// Number of loads emitted to generate this matrix. 196 unsigned NumLoads = 0; 197 /// Number of compute operations emitted to generate this matrix. 198 unsigned NumComputeOps = 0; 199 /// Most of the time transposes can be fused with matrix multiplies or can 200 /// be folded away via algebraic simplifications. This is the number of 201 /// transposes that we failed to make "free" via such optimizations. 202 unsigned NumExposedTransposes = 0; 203 204 OpInfoTy &operator+=(const OpInfoTy &RHS) { 205 NumStores += RHS.NumStores; 206 NumLoads += RHS.NumLoads; 207 NumComputeOps += RHS.NumComputeOps; 208 NumExposedTransposes += RHS.NumExposedTransposes; 209 return *this; 210 } 211 }; 212 213 /// Wrapper class representing a matrix as a set of vectors, either in row or 214 /// column major layout. All vectors must have the same vector type. 215 class MatrixTy { 216 SmallVector<Value *, 16> Vectors; 217 218 OpInfoTy OpInfo; 219 220 bool IsColumnMajor = true; 221 222 public: 223 MatrixTy() 224 : Vectors(), 225 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 226 MatrixTy(ArrayRef<Value *> Vectors) 227 : Vectors(Vectors.begin(), Vectors.end()), 228 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 229 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 230 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 231 232 unsigned D = isColumnMajor() ? NumColumns : NumRows; 233 for (unsigned J = 0; J < D; ++J) 234 addVector(UndefValue::get(FixedVectorType::get( 235 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 236 } 237 238 Value *getVector(unsigned i) const { return Vectors[i]; } 239 Value *getColumn(unsigned i) const { 240 assert(isColumnMajor() && "only supported for column-major matrixes"); 241 return Vectors[i]; 242 } 243 Value *getRow(unsigned i) const { 244 assert(!isColumnMajor() && "only supported for row-major matrixes"); 245 return Vectors[i]; 246 } 247 248 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 249 250 Type *getElementType() const { return getVectorTy()->getElementType(); } 251 252 unsigned getNumVectors() const { 253 if (isColumnMajor()) 254 return getNumColumns(); 255 return getNumRows(); 256 } 257 258 unsigned getNumColumns() const { 259 if (isColumnMajor()) 260 return Vectors.size(); 261 else { 262 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 263 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 264 } 265 } 266 unsigned getNumRows() const { 267 if (isColumnMajor()) { 268 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 269 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 270 } else 271 return Vectors.size(); 272 } 273 274 void addVector(Value *V) { Vectors.push_back(V); } 275 VectorType *getColumnTy() { 276 assert(isColumnMajor() && "only supported for column-major matrixes"); 277 return getVectorTy(); 278 } 279 280 VectorType *getVectorTy() const { 281 return cast<VectorType>(Vectors[0]->getType()); 282 } 283 284 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 285 assert(isColumnMajor() && 286 "columns() only supported for column-major matrixes"); 287 return make_range(Vectors.begin(), Vectors.end()); 288 } 289 290 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 291 return make_range(Vectors.begin(), Vectors.end()); 292 } 293 294 /// Embed the vectors of the matrix into a flat vector by concatenating 295 /// them. 296 Value *embedInVector(IRBuilder<> &Builder) const { 297 return Vectors.size() == 1 ? Vectors[0] 298 : concatenateVectors(Builder, Vectors); 299 } 300 301 MatrixTy &addNumLoads(unsigned N) { 302 OpInfo.NumLoads += N; 303 return *this; 304 } 305 306 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 307 308 MatrixTy &addNumStores(unsigned N) { 309 OpInfo.NumStores += N; 310 return *this; 311 } 312 313 MatrixTy &addNumExposedTransposes(unsigned N) { 314 OpInfo.NumExposedTransposes += N; 315 return *this; 316 } 317 318 MatrixTy &addNumComputeOps(unsigned N) { 319 OpInfo.NumComputeOps += N; 320 return *this; 321 } 322 323 unsigned getNumStores() const { return OpInfo.NumStores; } 324 unsigned getNumLoads() const { return OpInfo.NumLoads; } 325 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 326 327 const OpInfoTy &getOpInfo() const { return OpInfo; } 328 329 bool isColumnMajor() const { return IsColumnMajor; } 330 331 unsigned getStride() const { 332 if (isColumnMajor()) 333 return getNumRows(); 334 return getNumColumns(); 335 } 336 337 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 338 /// matrix is column-major, the result vector is extracted from a column 339 /// vector, otherwise from a row vector. 340 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 341 IRBuilder<> &Builder) const { 342 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 343 return Builder.CreateShuffleVector( 344 Vec, createSequentialMask(isColumnMajor() ? 
I : J, NumElts, 0), 345 "block"); 346 } 347 }; 348 349 struct ShapeInfo { 350 unsigned NumRows; 351 unsigned NumColumns; 352 353 bool IsColumnMajor; 354 355 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 356 : NumRows(NumRows), NumColumns(NumColumns), 357 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 358 359 ShapeInfo(Value *NumRows, Value *NumColumns) 360 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 361 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 362 363 bool operator==(const ShapeInfo &other) { 364 return NumRows == other.NumRows && NumColumns == other.NumColumns; 365 } 366 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 367 368 /// Returns true if shape-information is defined, meaning both dimensions 369 /// are != 0. 370 operator bool() const { 371 assert(NumRows == 0 || NumColumns != 0); 372 return NumRows != 0; 373 } 374 375 unsigned getStride() const { 376 if (IsColumnMajor) 377 return NumRows; 378 return NumColumns; 379 } 380 381 unsigned getNumVectors() const { 382 if (IsColumnMajor) 383 return NumColumns; 384 return NumRows; 385 } 386 }; 387 388 /// Maps instructions to their shape information. The shape information 389 /// describes the shape to be used while lowering. This matches the shape of 390 /// the result value of the instruction, with the only exceptions being store 391 /// instructions and the matrix_column_major_store intrinsics. For those, the 392 /// shape information indicates that those instructions should be lowered 393 /// using shape information as well. A ValueMap is used so that when 394 /// sub-passes like optimizeTransposes performs RAUW the map stays 395 /// up-to-date. 396 ValueMap<Value *, ShapeInfo> ShapeMap; 397 398 /// List of instructions to remove. While lowering, we are not replacing all 399 /// users of a lowered instruction, if shape information is available and 400 /// those need to be removed after we finished lowering. 401 SmallVector<Instruction *, 16> ToRemove; 402 403 /// Map from instructions to their produced column matrix. 404 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 405 406 private: 407 static FastMathFlags getFastMathFlags(Instruction *Inst) { 408 FastMathFlags FMF; 409 410 if (isa<FPMathOperator>(*Inst)) 411 FMF = Inst->getFastMathFlags(); 412 413 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 414 415 return FMF; 416 } 417 418 public: 419 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 420 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 421 OptimizationRemarkEmitter *ORE) 422 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 423 LI(LI), ORE(ORE) {} 424 425 unsigned getNumOps(Type *VT) { 426 assert(isa<VectorType>(VT) && "Expected vector type"); 427 return getNumOps(VT->getScalarType(), 428 cast<FixedVectorType>(VT)->getNumElements()); 429 } 430 431 /// Is this the minimal version executed in the backend pipelines. 432 bool isMinimal() const { 433 return !DT; 434 } 435 436 /// Return the estimated number of vector ops required for an operation on 437 /// \p VT * N. 438 unsigned getNumOps(Type *ST, unsigned N) { 439 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 440 double(TTI.getRegisterBitWidth( 441 TargetTransformInfo::RGK_FixedWidthVector) 442 .getFixedSize())); 443 } 444 445 /// Return the set of vectors that a matrix value is lowered to. 446 /// 447 /// If we lowered \p MatrixVal, just return the cache result matrix. 
Otherwise 448 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 449 /// into vectors. 450 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 451 IRBuilder<> &Builder) { 452 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 453 assert(VType && "MatrixVal must be a vector type"); 454 assert(cast<FixedVectorType>(VType)->getNumElements() == 455 SI.NumRows * SI.NumColumns && 456 "The vector size must match the number of matrix elements"); 457 458 // Check if we lowered MatrixVal using shape information. In that case, 459 // return the existing matrix, if it matches the requested shape 460 // information. If there is a mis-match, embed the result in a flat 461 // vector and split it later. 462 auto Found = Inst2ColumnMatrix.find(MatrixVal); 463 if (Found != Inst2ColumnMatrix.end()) { 464 MatrixTy &M = Found->second; 465 // Return the found matrix, if its shape matches the requested shape 466 // information 467 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 468 return M; 469 470 MatrixVal = M.embedInVector(Builder); 471 } 472 473 // Otherwise split MatrixVal. 474 SmallVector<Value *, 16> SplitVecs; 475 for (unsigned MaskStart = 0; 476 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 477 MaskStart += SI.getStride()) { 478 Value *V = Builder.CreateShuffleVector( 479 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 480 "split"); 481 SplitVecs.push_back(V); 482 } 483 484 return {SplitVecs}; 485 } 486 487 /// If \p V already has a known shape return false. Otherwise set the shape 488 /// for instructions that support it. 489 bool setShapeInfo(Value *V, ShapeInfo Shape) { 490 assert(Shape && "Shape not set"); 491 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 492 return false; 493 494 auto SIter = ShapeMap.find(V); 495 if (SIter != ShapeMap.end()) { 496 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 497 << SIter->second.NumRows << " " 498 << SIter->second.NumColumns << " for " << *V << "\n"); 499 return false; 500 } 501 502 ShapeMap.insert({V, Shape}); 503 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 504 << " for " << *V << "\n"); 505 return true; 506 } 507 508 bool isUniformShape(Value *V) { 509 Instruction *I = dyn_cast<Instruction>(V); 510 if (!I) 511 return true; 512 513 switch (I->getOpcode()) { 514 case Instruction::FAdd: 515 case Instruction::FSub: 516 case Instruction::FMul: // Scalar multiply. 517 case Instruction::FNeg: 518 case Instruction::Add: 519 case Instruction::Mul: 520 case Instruction::Sub: 521 return true; 522 default: 523 return false; 524 } 525 } 526 527 /// Returns true if shape information can be used for \p V. The supported 528 /// instructions must match the instructions that can be lowered by this pass. 529 bool supportsShapeInfo(Value *V) { 530 Instruction *Inst = dyn_cast<Instruction>(V); 531 if (!Inst) 532 return false; 533 534 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 535 if (II) 536 switch (II->getIntrinsicID()) { 537 case Intrinsic::matrix_multiply: 538 case Intrinsic::matrix_transpose: 539 case Intrinsic::matrix_column_major_load: 540 case Intrinsic::matrix_column_major_store: 541 return true; 542 default: 543 return false; 544 } 545 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 546 } 547 548 /// Propagate the shape information of instructions to their users. 
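/// As an illustrative example (not from the original comment, intrinsic name
/// mangling abbreviated): for
///   %c = call <4 x double> @llvm.matrix.multiply(<6 x double> %a,
///                                                <6 x double> %b,
///                                                i32 2, i32 3, i32 2)
/// the intrinsic itself supplies the result shape 2 x 2, which is recorded
/// for %c; the users of %c are then queued so the shape can be forwarded to
/// them.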
549 /// The work list contains instructions for which we can compute the shape, 550 /// either based on the information provided by matrix intrinsics or known 551 /// shapes of operands. 552 SmallVector<Instruction *, 32> 553 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 554 SmallVector<Instruction *, 32> NewWorkList; 555 // Pop an element for which we guaranteed to have at least one of the 556 // operand shapes. Add the shape for this and then add users to the work 557 // list. 558 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 559 while (!WorkList.empty()) { 560 Instruction *Inst = WorkList.pop_back_val(); 561 562 // New entry, set the value and insert operands 563 bool Propagate = false; 564 565 Value *MatrixA; 566 Value *MatrixB; 567 Value *M; 568 Value *N; 569 Value *K; 570 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 571 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 572 m_Value(N), m_Value(K)))) { 573 Propagate = setShapeInfo(Inst, {M, K}); 574 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 575 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 576 // Flip dimensions. 577 Propagate = setShapeInfo(Inst, {N, M}); 578 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 579 m_Value(MatrixA), m_Value(), m_Value(), 580 m_Value(), m_Value(M), m_Value(N)))) { 581 Propagate = setShapeInfo(Inst, {N, M}); 582 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 583 m_Value(), m_Value(), m_Value(), m_Value(M), 584 m_Value(N)))) { 585 Propagate = setShapeInfo(Inst, {M, N}); 586 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 587 auto OpShape = ShapeMap.find(MatrixA); 588 if (OpShape != ShapeMap.end()) 589 setShapeInfo(Inst, OpShape->second); 590 continue; 591 } else if (isUniformShape(Inst)) { 592 // Find the first operand that has a known shape and use that. 593 for (auto &Op : Inst->operands()) { 594 auto OpShape = ShapeMap.find(Op.get()); 595 if (OpShape != ShapeMap.end()) { 596 Propagate |= setShapeInfo(Inst, OpShape->second); 597 break; 598 } 599 } 600 } 601 602 if (Propagate) { 603 NewWorkList.push_back(Inst); 604 for (auto *User : Inst->users()) 605 if (ShapeMap.count(User) == 0) 606 WorkList.push_back(cast<Instruction>(User)); 607 } 608 } 609 610 return NewWorkList; 611 } 612 613 /// Propagate the shape to operands of instructions with shape information. 614 /// \p Worklist contains the instruction for which we already know the shape. 615 SmallVector<Instruction *, 32> 616 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 617 SmallVector<Instruction *, 32> NewWorkList; 618 619 auto pushInstruction = [](Value *V, 620 SmallVectorImpl<Instruction *> &WorkList) { 621 Instruction *I = dyn_cast<Instruction>(V); 622 if (I) 623 WorkList.push_back(I); 624 }; 625 // Pop an element with known shape. Traverse the operands, if their shape 626 // derives from the result shape and is unknown, add it and add them to the 627 // worklist. 
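    // Illustrative example (not from the original comments): for
    //   %c = call ... @llvm.matrix.multiply(%a, %b, i32 2, i32 3, i32 4)
    // the result shape 2 x 4 implies shape 2 x 3 for %a and 3 x 4 for %b;
    // any operand whose shape was previously unknown is recorded and added
    // to the work list.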
628 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 629 while (!WorkList.empty()) { 630 Value *V = WorkList.pop_back_val(); 631 632 size_t BeforeProcessingV = WorkList.size(); 633 if (!isa<Instruction>(V)) 634 continue; 635 636 Value *MatrixA; 637 Value *MatrixB; 638 Value *M; 639 Value *N; 640 Value *K; 641 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 642 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 643 m_Value(N), m_Value(K)))) { 644 if (setShapeInfo(MatrixA, {M, N})) 645 pushInstruction(MatrixA, WorkList); 646 647 if (setShapeInfo(MatrixB, {N, K})) 648 pushInstruction(MatrixB, WorkList); 649 650 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 651 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 652 // Flip dimensions. 653 if (setShapeInfo(MatrixA, {M, N})) 654 pushInstruction(MatrixA, WorkList); 655 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 656 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 657 m_Value(M), m_Value(N)))) { 658 if (setShapeInfo(MatrixA, {M, N})) { 659 pushInstruction(MatrixA, WorkList); 660 } 661 } else if (isa<LoadInst>(V) || 662 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 663 // Nothing to do, no matrix input. 664 } else if (isa<StoreInst>(V)) { 665 // Nothing to do. We forward-propagated to this so we would just 666 // backward propagate to an instruction with an already known shape. 667 } else if (isUniformShape(V)) { 668 // Propagate to all operands. 669 ShapeInfo Shape = ShapeMap[V]; 670 for (Use &U : cast<Instruction>(V)->operands()) { 671 if (setShapeInfo(U.get(), Shape)) 672 pushInstruction(U.get(), WorkList); 673 } 674 } 675 // After we discovered new shape info for new instructions in the 676 // worklist, we use their users as seeds for the next round of forward 677 // propagation. 678 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 679 for (User *U : WorkList[I]->users()) 680 if (isa<Instruction>(U) && V != U) 681 NewWorkList.push_back(cast<Instruction>(U)); 682 } 683 return NewWorkList; 684 } 685 686 /// Try moving transposes in order to fold them away or into multiplies. 687 void optimizeTransposes() { 688 // First sink all transposes inside matmuls, hoping that we end up with NN, 689 // NT or TN variants. 690 for (BasicBlock &BB : reverse(Func)) { 691 for (auto II = BB.rbegin(); II != BB.rend();) { 692 Instruction &I = *II; 693 // We may remove II. By default continue on the next/prev instruction. 694 ++II; 695 // If we were to erase II, move again. 696 auto EraseFromParent = [&II](Value *V) { 697 auto *Inst = cast<Instruction>(V); 698 if (Inst->use_empty()) { 699 if (Inst == &*II) { 700 ++II; 701 } 702 Inst->eraseFromParent(); 703 } 704 }; 705 706 // If we're creating a new instruction, continue from there. 
707 Instruction *NewInst = nullptr; 708 709 IRBuilder<> IB(&I); 710 MatrixBuilder<IRBuilder<>> Builder(IB); 711 712 Value *TA, *TAMA, *TAMB; 713 ConstantInt *R, *K, *C; 714 if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { 715 716 // Transpose of a transpose is a nop 717 Value *TATA; 718 if (match(TA, 719 m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 720 I.replaceAllUsesWith(TATA); 721 EraseFromParent(&I); 722 EraseFromParent(TA); 723 } 724 725 // (A * B)^t -> B^t * A^t 726 // RxK KxC CxK KxR 727 else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 728 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 729 m_ConstantInt(K), m_ConstantInt(C)))) { 730 Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), 731 C->getZExtValue(), 732 TAMB->getName() + "_t"); 733 // We are being run after shape prop, add shape for newly created 734 // instructions so that we lower them later. 735 setShapeInfo(T0, {C, K}); 736 Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), 737 K->getZExtValue(), 738 TAMA->getName() + "_t"); 739 setShapeInfo(T1, {K, R}); 740 NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), 741 K->getZExtValue(), 742 R->getZExtValue(), "mmul"); 743 setShapeInfo(NewInst, {C, R}); 744 I.replaceAllUsesWith(NewInst); 745 EraseFromParent(&I); 746 EraseFromParent(TA); 747 } 748 } 749 750 // If we replaced I with a new instruction, continue from there. 751 if (NewInst) 752 II = std::next(BasicBlock::reverse_iterator(NewInst)); 753 } 754 } 755 756 // If we have a TT matmul, lift the transpose. We may be able to fold into 757 // consuming multiply. 758 for (BasicBlock &BB : Func) { 759 for (BasicBlock::iterator II = BB.begin(); II != BB.end();) { 760 Instruction *I = &*II; 761 // We may remove I. 762 ++II; 763 Value *A, *B, *AT, *BT; 764 ConstantInt *R, *K, *C; 765 if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>( 766 m_Value(A), m_Value(B), m_ConstantInt(R), 767 m_ConstantInt(K), m_ConstantInt(C))) && 768 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 769 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 770 IRBuilder<> IB(&*I); 771 MatrixBuilder<IRBuilder<>> Builder(IB); 772 Value *M = Builder.CreateMatrixMultiply( 773 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 774 setShapeInfo(M, {C, R}); 775 Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(), 776 C->getZExtValue()); 777 setShapeInfo(NewInst, {C, R}); 778 I->replaceAllUsesWith(NewInst); 779 if (I->use_empty()) 780 I->eraseFromParent(); 781 if (A->use_empty()) 782 cast<Instruction>(A)->eraseFromParent(); 783 if (A != B && B->use_empty()) 784 cast<Instruction>(B)->eraseFromParent(); 785 } 786 } 787 } 788 } 789 790 bool Visit() { 791 SmallVector<Instruction *, 32> WorkList; 792 793 // Initially only the shape of matrix intrinsics is known. 794 // Initialize the work list with ops carrying shape information. 795 for (BasicBlock &BB : Func) 796 for (Instruction &Inst : BB) { 797 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 798 if (!II) 799 continue; 800 801 switch (II->getIntrinsicID()) { 802 case Intrinsic::matrix_multiply: 803 case Intrinsic::matrix_transpose: 804 case Intrinsic::matrix_column_major_load: 805 case Intrinsic::matrix_column_major_store: 806 WorkList.push_back(&Inst); 807 break; 808 default: 809 break; 810 } 811 } 812 813 // Avoid unnecessary work if there are no matrix intrinsics in the function. 
814 if (WorkList.empty()) 815 return false; 816 817 // Propagate shapes until nothing changes any longer. 818 while (!WorkList.empty()) { 819 WorkList = propagateShapeForward(WorkList); 820 WorkList = propagateShapeBackward(WorkList); 821 } 822 823 if (!isMinimal()) { 824 optimizeTransposes(); 825 LLVM_DEBUG({ 826 dbgs() << "Dump after matrix transpose optimization:\n"; 827 Func.dump(); 828 }); 829 } 830 831 bool Changed = false; 832 SmallVector<CallInst *, 16> MaybeFusableInsts; 833 SmallVector<Instruction *, 16> MatrixInsts; 834 835 // First, collect all instructions with shape information and candidates for 836 // fusion (currently only matrix multiplies). 837 ReversePostOrderTraversal<Function *> RPOT(&Func); 838 for (auto *BB : RPOT) 839 for (Instruction &I : *BB) { 840 if (ShapeMap.find(&I) == ShapeMap.end()) 841 continue; 842 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 843 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 844 MatrixInsts.push_back(&I); 845 } 846 847 // Second, try to fuse candidates. 848 SmallPtrSet<Instruction *, 16> FusedInsts; 849 for (CallInst *CI : MaybeFusableInsts) 850 LowerMatrixMultiplyFused(CI, FusedInsts); 851 Changed = !FusedInsts.empty(); 852 853 // Third, lower remaining instructions with shape information. 854 for (Instruction *Inst : MatrixInsts) { 855 if (FusedInsts.count(Inst)) 856 continue; 857 858 IRBuilder<> Builder(Inst); 859 860 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 861 Changed |= VisitCallInst(CInst); 862 863 Value *Op1; 864 Value *Op2; 865 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 866 Changed |= VisitBinaryOperator(BinOp); 867 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 868 Changed |= VisitUnaryOperator(UnOp); 869 if (match(Inst, m_Load(m_Value(Op1)))) 870 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 871 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 872 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 873 } 874 875 if (ORE) { 876 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 877 RemarkGen.emitRemarks(); 878 } 879 880 // Delete the instructions backwards, as it has a reduced likelihood of 881 // having to update as many def-use and use-def chains. 882 for (auto *Inst : reverse(ToRemove)) { 883 if (!Inst->use_empty()) 884 Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); 885 Inst->eraseFromParent(); 886 } 887 888 return Changed; 889 } 890 891 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 892 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 893 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 894 Type *EltPtrType = PointerType::get(EltType, AS); 895 return Builder.CreatePointerCast(BasePtr, EltPtrType); 896 } 897 898 /// Replace intrinsic calls 899 bool VisitCallInst(CallInst *Inst) { 900 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 901 return false; 902 903 switch (Inst->getCalledFunction()->getIntrinsicID()) { 904 case Intrinsic::matrix_multiply: 905 LowerMultiply(Inst); 906 break; 907 case Intrinsic::matrix_transpose: 908 LowerTranspose(Inst); 909 break; 910 case Intrinsic::matrix_column_major_load: 911 LowerColumnMajorLoad(Inst); 912 break; 913 case Intrinsic::matrix_column_major_store: 914 LowerColumnMajorStore(Inst); 915 break; 916 default: 917 return false; 918 } 919 return true; 920 } 921 922 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 923 /// The address at \p Idx == 0 has alignment \p A. 
If \p Stride is a 924 /// ConstantInt, reduce the initial alignment based on the byte offset. For 925 /// non-ConstantInt strides, return the common alignment of the initial 926 /// alignment and the element size in bytes. 927 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 928 MaybeAlign A) const { 929 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 930 if (Idx == 0) 931 return InitialAlign; 932 933 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 934 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 935 uint64_t StrideInBytes = 936 ConstStride->getZExtValue() * ElementSizeInBits / 8; 937 return commonAlignment(InitialAlign, Idx * StrideInBytes); 938 } 939 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 940 } 941 942 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 943 /// vectors. 944 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 945 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 946 auto *VType = cast<VectorType>(Ty); 947 Type *EltTy = VType->getElementType(); 948 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 949 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 950 MatrixTy Result; 951 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 952 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, 953 Shape.getStride(), EltTy, Builder); 954 Value *Vector = Builder.CreateAlignedLoad( 955 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 956 IsVolatile, "col.load"); 957 958 Result.addVector(Vector); 959 } 960 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 961 Result.getNumVectors()); 962 } 963 964 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 965 /// starting at \p MatrixPtr[I][J]. 966 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 967 ShapeInfo MatrixShape, Value *I, Value *J, 968 ShapeInfo ResultShape, Type *EltTy, 969 IRBuilder<> &Builder) { 970 971 Value *Offset = Builder.CreateAdd( 972 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 973 974 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 975 Value *EltPtr = 976 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 977 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 978 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 979 ResultShape.NumColumns); 980 Type *TilePtrTy = PointerType::get(TileTy, AS); 981 Value *TilePtr = 982 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 983 984 return loadMatrix(TileTy, TilePtr, Align, 985 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 986 ResultShape, Builder); 987 } 988 989 /// Lower a load instruction with shape information. 990 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 991 bool IsVolatile, ShapeInfo Shape) { 992 IRBuilder<> Builder(Inst); 993 finalizeLowering(Inst, 994 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 995 Shape, Builder), 996 Builder); 997 } 998 999 /// Lowers llvm.matrix.column.major.load. 1000 /// 1001 /// The intrinsic loads a matrix from memory using a stride between columns. 
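  /// For example (illustrative, argument types abbreviated), a call like
  ///   %m = call <9 x double> @llvm.matrix.column.major.load(ptr %p, i64 5,
  ///                                                         i1 false,
  ///                                                         i32 3, i32 3)
  /// loads a 3x3 matrix whose columns start at element offsets 0, 5 and 10
  /// from %p, and is lowered to three strided loads of <3 x double>.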
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
  /// MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
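  /// For example (illustrative), storing a <6 x float> value with shape
  /// 2 x 3 and a stride of 4 emits three <2 x float> stores at element
  /// offsets 0, 4 and 8 from the base pointer.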
1069 void LowerColumnMajorStore(CallInst *Inst) { 1070 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1071 "Intrinsic only supports column-major layout!"); 1072 Value *Matrix = Inst->getArgOperand(0); 1073 Value *Ptr = Inst->getArgOperand(1); 1074 Value *Stride = Inst->getArgOperand(2); 1075 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1076 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1077 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1078 } 1079 1080 // Set elements I..I+NumElts-1 to Block 1081 Value *insertVector(Value *Col, unsigned I, Value *Block, 1082 IRBuilder<> &Builder) { 1083 1084 // First, bring Block to the same size as Col 1085 unsigned BlockNumElts = 1086 cast<FixedVectorType>(Block->getType())->getNumElements(); 1087 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1088 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1089 1090 Block = Builder.CreateShuffleVector( 1091 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1092 1093 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1094 // 8, 4, 5, 6 1095 SmallVector<int, 16> Mask; 1096 unsigned i; 1097 for (i = 0; i < I; i++) 1098 Mask.push_back(i); 1099 1100 unsigned VecNumElts = 1101 cast<FixedVectorType>(Col->getType())->getNumElements(); 1102 for (; i < I + BlockNumElts; i++) 1103 Mask.push_back(i - I + VecNumElts); 1104 1105 for (; i < VecNumElts; i++) 1106 Mask.push_back(i); 1107 1108 return Builder.CreateShuffleVector(Col, Block, Mask); 1109 } 1110 1111 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1112 IRBuilder<> &Builder, bool AllowContraction, 1113 unsigned &NumComputeOps) { 1114 NumComputeOps += getNumOps(A->getType()); 1115 if (!Sum) 1116 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1117 1118 if (UseFPOp) { 1119 if (AllowContraction) { 1120 // Use fmuladd for floating point operations and let the backend decide 1121 // if that's profitable. 1122 Function *FMulAdd = Intrinsic::getDeclaration( 1123 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1124 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1125 } 1126 NumComputeOps += getNumOps(A->getType()); 1127 Value *Mul = Builder.CreateFMul(A, B); 1128 return Builder.CreateFAdd(Sum, Mul); 1129 } 1130 1131 NumComputeOps += getNumOps(A->getType()); 1132 Value *Mul = Builder.CreateMul(A, B); 1133 return Builder.CreateAdd(Sum, Mul); 1134 } 1135 1136 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1137 /// users with shape information, there's nothing to do: they will use the 1138 /// cached value when they are lowered. For other users, \p Matrix is 1139 /// flattened and the uses are updated to use it. Also marks \p Inst for 1140 /// deletion. 1141 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1142 IRBuilder<> &Builder) { 1143 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1144 1145 ToRemove.push_back(Inst); 1146 Value *Flattened = nullptr; 1147 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1148 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1149 if (!Flattened) 1150 Flattened = Matrix.embedInVector(Builder); 1151 U.set(Flattened); 1152 } 1153 } 1154 } 1155 1156 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1157 /// addition. 1158 /// 1159 /// We can fold a transpose into the operand that is used to extract scalars. 
1160 /// This is the first operands with row-major and the second with 1161 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1162 /// operand is transposed. 1163 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1164 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1165 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1166 const unsigned VF = std::max<unsigned>( 1167 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1168 .getFixedSize() / 1169 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1170 1U); 1171 unsigned R = Result.getNumRows(); 1172 unsigned C = Result.getNumColumns(); 1173 unsigned M = A.getNumColumns(); 1174 1175 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1176 assert(A.isColumnMajor() == B.isColumnMajor() && 1177 Result.isColumnMajor() == A.isColumnMajor() && 1178 "operands must agree on matrix layout"); 1179 unsigned NumComputeOps = 0; 1180 1181 Builder.setFastMathFlags(FMF); 1182 1183 if (A.isColumnMajor()) { 1184 // Multiply columns from the first operand with scalars from the second 1185 // operand. Then move along the K axes and accumulate the columns. With 1186 // this the adds can be vectorized without reassociation. 1187 for (unsigned J = 0; J < C; ++J) { 1188 unsigned BlockSize = VF; 1189 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1190 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1191 1192 for (unsigned I = 0; I < R; I += BlockSize) { 1193 // Gradually lower the vectorization factor to cover the remainder. 1194 while (I + BlockSize > R) 1195 BlockSize /= 2; 1196 1197 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1198 : nullptr; 1199 for (unsigned K = 0; K < M; ++K) { 1200 Value *L = A.extractVector(I, K, BlockSize, Builder); 1201 Value *RH = Builder.CreateExtractElement( 1202 B.getColumn(IsScalarMatrixTransposed ? K : J), 1203 IsScalarMatrixTransposed ? J : K); 1204 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1205 Sum = 1206 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1207 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1208 } 1209 Result.setVector(J, 1210 insertVector(Result.getVector(J), I, Sum, Builder)); 1211 } 1212 } 1213 } else { 1214 // Multiply rows from the second operand with scalars from the first 1215 // operand. Then move along the K axes and accumulate the rows. With this 1216 // the adds can be vectorized without reassociation. 1217 for (unsigned I = 0; I < R; ++I) { 1218 unsigned BlockSize = VF; 1219 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1220 for (unsigned J = 0; J < C; J += BlockSize) { 1221 // Gradually lower the vectorization factor to cover the remainder. 1222 while (J + BlockSize > C) 1223 BlockSize /= 2; 1224 1225 Value *Sum = nullptr; 1226 for (unsigned K = 0; K < M; ++K) { 1227 Value *R = B.extractVector(K, J, BlockSize, Builder); 1228 Value *LH = Builder.CreateExtractElement( 1229 A.getVector(IsScalarMatrixTransposed ? K : I), 1230 IsScalarMatrixTransposed ? I : K); 1231 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1232 Sum = 1233 createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, Splat, R, 1234 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1235 } 1236 Result.setVector(I, 1237 insertVector(Result.getVector(I), J, Sum, Builder)); 1238 } 1239 } 1240 } 1241 Result.addNumComputeOps(NumComputeOps); 1242 } 1243 1244 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1245 /// copying it to a new location. This new or otherwise the original location 1246 /// is returned. 1247 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1248 CallInst *MatMul) { 1249 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1250 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1251 1252 // If we can statically determine noalias we're good. 1253 if (AA->isNoAlias(LoadLoc, StoreLoc)) 1254 return Load->getPointerOperand(); 1255 1256 // Create code to check if the memory locations of the Load and Store 1257 // overlap and if they do, copy Load's operand to a new buffer. 1258 1259 // First, create new blocks for 2n part of the check and the copy. 1260 BasicBlock *Check0 = MatMul->getParent(); 1261 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1262 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1263 // as we adjust Check0 and Check1's branches. 1264 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1265 for (BasicBlock *Succ : successors(Check0)) 1266 DTUpdates.push_back({DT->Delete, Check0, Succ}); 1267 1268 BasicBlock *Check1 = 1269 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1270 nullptr, "alias_cont"); 1271 BasicBlock *Copy = 1272 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1273 nullptr, "copy"); 1274 BasicBlock *Fusion = 1275 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1276 nullptr, "no_alias"); 1277 1278 // Check if the loaded memory location begins before the end of the store 1279 // location. If the condition holds, they might overlap, otherwise they are 1280 // guaranteed to not overlap. 1281 IRBuilder<> Builder(MatMul); 1282 Check0->getTerminator()->eraseFromParent(); 1283 Builder.SetInsertPoint(Check0); 1284 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1285 Value *StoreBegin = Builder.CreatePtrToInt( 1286 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1287 Value *StoreEnd = Builder.CreateAdd( 1288 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1289 "store.end", true, true); 1290 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1291 IntPtrTy, "load.begin"); 1292 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1293 Fusion); 1294 1295 // Check if the store begins before the end of the load location. If the 1296 // condition holds, they alias, otherwise they are guaranteed to not 1297 // overlap. 1298 Check1->getTerminator()->eraseFromParent(); 1299 Builder.SetInsertPoint(Check1, Check1->begin()); 1300 Value *LoadEnd = Builder.CreateAdd( 1301 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1302 "load.end", true, true); 1303 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1304 Fusion); 1305 1306 // Copy load operand to new alloca. 
1307 Builder.SetInsertPoint(Copy, Copy->begin()); 1308 AllocaInst *NewLd = 1309 Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1310 Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1311 Load->getPointerOperand(), Load->getAlign(), 1312 LoadLoc.Size.getValue()); 1313 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1314 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1315 PHI->addIncoming(Load->getPointerOperand(), Check0); 1316 PHI->addIncoming(Load->getPointerOperand(), Check1); 1317 PHI->addIncoming(NewLd, Copy); 1318 1319 // Adjust DT. 1320 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1321 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1322 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1323 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1324 DT->applyUpdates(DTUpdates); 1325 return PHI; 1326 } 1327 1328 bool isFusionProfitable(CallInst *MatMul) { 1329 if (ForceFusion) 1330 return true; 1331 1332 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1333 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1334 1335 const unsigned R = LShape.NumRows; 1336 const unsigned C = RShape.NumColumns; 1337 const unsigned M = LShape.NumColumns; 1338 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1339 1340 const unsigned VF = std::max<unsigned>( 1341 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1342 .getFixedSize() / 1343 EltType->getPrimitiveSizeInBits().getFixedSize(), 1344 1U); 1345 1346 // Cost model for tiling 1347 // 1348 // For tiling to be beneficial, we need reuse either along the R or 1349 // the C axis. We vectorize along the R axis so that means at least 1350 // 3 elements. 1351 // TODO: Also consider cost of copying if operands alias. 1352 if (R <= VF && C == 1) 1353 return false; 1354 // Then we need enough elements to exceed the number of vector 1355 // registers we have. Note that this is an oversimplification since 1356 // fusing also takes some extra loads which may exceed the number of 1357 // reloads necessary. 1358 unsigned Op0Regs = (R + VF - 1) / VF * M; 1359 unsigned Op1Regs = (M + VF - 1) / VF * C; 1360 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1361 } 1362 1363 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1364 MatrixTy Res; 1365 auto *ColumType = FixedVectorType::get(EltType, R); 1366 for (unsigned I = 0; I < C; ++I) 1367 Res.addVector(ConstantAggregateZero::get(ColumType)); 1368 return Res; 1369 } 1370 1371 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1372 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1373 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1374 1375 // Create the main tiling loop nest. 1376 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1377 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1378 Instruction *InsertI = cast<Instruction>(MatMul); 1379 BasicBlock *Start = InsertI->getParent(); 1380 BasicBlock *End = 1381 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1382 IRBuilder<> Builder(MatMul); 1383 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1384 1385 Type *TileVecTy = 1386 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1387 MatrixTy TileResult; 1388 // Insert in the inner loop header. 
1389 Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); 1390 // Create PHI nodes for the result columns to accumulate across iterations. 1391 SmallVector<PHINode *, 4> ColumnPhis; 1392 for (unsigned I = 0; I < TileSize; I++) { 1393 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1394 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1395 TI.RowLoopHeader->getSingleSuccessor()); 1396 TileResult.addVector(Phi); 1397 ColumnPhis.push_back(Phi); 1398 } 1399 1400 // Insert in the inner loop body, which computes 1401 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1402 Builder.SetInsertPoint(InnerBody->getTerminator()); 1403 // Load tiles of the operands. 1404 MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, 1405 {TileSize, TileSize}, EltType, Builder); 1406 MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, 1407 {TileSize, TileSize}, EltType, Builder); 1408 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1409 getFastMathFlags(MatMul)); 1410 // Store result after the inner loop is done. 1411 Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); 1412 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1413 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1414 TI.CurrentRow, TI.CurrentCol, EltType, Builder); 1415 1416 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1417 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); 1418 1419 // Force unrolling of a few iterations of the inner loop, to make sure there 1420 // is enough work per iteration. 1421 // FIXME: The unroller should make this decision directly instead, but 1422 // currently the cost-model is not up to the task. 1423 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1424 addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), 1425 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1426 } 1427 1428 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1429 StoreInst *Store, 1430 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1431 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1432 "Tiling only supported for column-major matrixes at the moment!"); 1433 if (!isFusionProfitable(MatMul)) 1434 return; 1435 1436 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1437 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1438 1439 const unsigned R = LShape.NumRows; 1440 const unsigned C = RShape.NumColumns; 1441 const unsigned M = LShape.NumColumns; 1442 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1443 1444 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1445 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1446 Value *CPtr = Store->getPointerOperand(); 1447 1448 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1449 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1450 else { 1451 IRBuilder<> Builder(Store); 1452 for (unsigned J = 0; J < C; J += TileSize) 1453 for (unsigned I = 0; I < R; I += TileSize) { 1454 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1455 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1456 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1457 1458 for (unsigned K = 0; K < M; K += TileSize) { 1459 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1460 MatrixTy A = 1461 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1462 LShape, 
Builder.getInt64(I), Builder.getInt64(K), 1463 {TileR, TileM}, EltType, Builder); 1464 MatrixTy B = 1465 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1466 RShape, Builder.getInt64(K), Builder.getInt64(J), 1467 {TileM, TileC}, EltType, Builder); 1468 emitMatrixMultiply(Res, A, B, Builder, true, false, 1469 getFastMathFlags(MatMul)); 1470 } 1471 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1472 Builder.getInt64(I), Builder.getInt64(J), EltType, 1473 Builder); 1474 } 1475 } 1476 1477 // Mark eliminated instructions as fused and remove them. 1478 FusedInsts.insert(Store); 1479 FusedInsts.insert(MatMul); 1480 Store->eraseFromParent(); 1481 MatMul->eraseFromParent(); 1482 if (LoadOp0->hasNUses(0)) { 1483 FusedInsts.insert(LoadOp0); 1484 LoadOp0->eraseFromParent(); 1485 } 1486 if (LoadOp1->hasNUses(0)) { 1487 FusedInsts.insert(LoadOp1); 1488 LoadOp1->eraseFromParent(); 1489 } 1490 } 1491 1492 /// Try to lower matrix multiply chains by fusing operations. 1493 /// 1494 /// Call finalizeLowering on lowered instructions. Instructions that are 1495 /// completely eliminated by fusion are added to \p FusedInsts. 1496 void LowerMatrixMultiplyFused(CallInst *MatMul, 1497 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1498 if (!FuseMatrix || !DT) 1499 return; 1500 1501 assert(AA && LI && "Analyses should be available"); 1502 1503 Value *A = MatMul->getArgOperand(0); 1504 Value *B = MatMul->getArgOperand(1); 1505 1506 // We can fold the transpose into the operand that is used to fetch scalars. 1507 Value *T; 1508 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1509 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T))) 1510 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) { 1511 IRBuilder<> Builder(MatMul); 1512 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1513 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1514 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1515 const unsigned R = LShape.NumRows; 1516 const unsigned M = LShape.NumColumns; 1517 const unsigned C = RShape.NumColumns; 1518 1519 MatrixTy MA; 1520 MatrixTy MB; 1521 1522 Value *Transpose; 1523 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) { 1524 MA = getMatrix(A, ShapeInfo(R, M), Builder); 1525 MB = getMatrix(T, ShapeInfo(C, M), Builder); 1526 Transpose = B; 1527 } else { 1528 MA = getMatrix(T, ShapeInfo(R, M), Builder); 1529 MB = getMatrix(B, ShapeInfo(C, M), Builder); 1530 Transpose = A; 1531 } 1532 1533 // Initialize the output 1534 MatrixTy Result(R, C, EltType); 1535 1536 emitMatrixMultiply(Result, MA, MB, Builder, false, true, 1537 getFastMathFlags(MatMul)); 1538 1539 FusedInsts.insert(MatMul); 1540 if (Transpose->hasOneUse()) { 1541 FusedInsts.insert(cast<Instruction>(Transpose)); 1542 ToRemove.push_back(cast<Instruction>(Transpose)); 1543 } 1544 finalizeLowering(MatMul, Result, Builder); 1545 // TODO: add a fake entry for the folded instruction so that this is 1546 // included in the expression in the remark. 1547 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType); 1548 return; 1549 } 1550 1551 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor) 1552 return; 1553 1554 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1555 // since the single store user will be lowered as part of this. 
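    // Conceptually (illustrative sketch, not from the original comments),
    // the fused code emitted by emitSIMDTiling for the column-major case is:
    //   for (J = 0; J < C; J += TileSize)
    //     for (I = 0; I < R; I += TileSize) {
    //       Res = 0;
    //       for (K = 0; K < M; K += TileSize)
    //         Res += load(A, I, K) * load(B, K, J);  // TileSize x TileSize tiles
    //       store(Res, result matrix, I, J);
    //     }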
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      // FIXME: See if we can hoist the store address computation.
      auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
      if (AddrI && (!DT->dominates(AddrI, MatMul)))
        return;

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
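  ///
  /// As a rough example (assuming the default column-major layout), a load of
  /// a <6 x double> value with an associated 2x3 shape is lowered to three
  /// loads of <2 x double> column vectors, using a stride of 2 elements.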
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
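    /// For example, a multiply of a 2x3 by a 3x2 double matrix would be
    /// written roughly as "multiply.2x3.3x2.double" (illustrative; the shapes
    /// come from Inst2Matrix and print as "unknown" if no entry exists).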
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an address external to the function or to a stack address; for
    /// other values we print either the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
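      // For example, if the same multiply result feeds two different stores,
      // the linearized expression for each store notes the line/column of the
      // remark it is shared with (sketch; the exact text is produced below
      // from the DebugLoc of the other leaf).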
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
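    ///
    /// For example, in a chain load -> multiply -> store, only the store is a
    /// leaf: the load and the multiply both have matrix users inside the
    /// subprogram (illustrative).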
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
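        // A remark emitted here might look roughly like the following
        // (illustrative values, whitespace condensed):
        //
        //   remark: test.cpp:35:3: Lowered with 4 stores, 8 loads, 112 compute
        //   ops, 0 exposed transposes
        //   store(multiply.4x4.4x4.double(addr %A, addr %B), addr %C)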
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features like tiling that require DT, AA or ORE are disabled. This
/// is used to lower matrix intrinsics if the main lowering pass is not run,
/// for example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}
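// Usage sketch (illustrative, not part of the lowering itself): the minimal
// variant above covers builds where the main pass does not run (e.g. -O0),
// while the full lowering can typically be exercised with the new pass manager
// via something like
//   opt -passes=lower-matrix-intrinsics -S in.ll
// The registered pass names live in PassRegistry.def; the legacy passes are
// constructed through createLowerMatrixIntrinsicsPass() and
// createLowerMatrixIntrinsicsMinimalPass().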