//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//    * Improve cost-modeling, e.g. choose different number of rows/columns
//      for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

// Controls whether shapes discovered at matrix intrinsics are propagated to
// connected non-intrinsic instructions (see propagateShapeForward/Backward).
static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

/// Memory layouts supported when lowering matrixes.
enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope otherwise.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
123 // 124 // v_0_0 v_0_1 {v_0_2 {v_0_3 125 // Base Col 1 Col 2 126 // | | | 127 // v_1_0 |v_1_1 |v_1_2 |v_1_3 128 // v_2_0 |v_2_1 |v_2_2 |v_2_3 129 // v_3_0 {v_3_1 {v_3_2 v_3_3 130 // 131 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 132 unsigned NumElements, Type *EltType, 133 IRBuilder<> &Builder) { 134 135 assert((!isa<ConstantInt>(Stride) || 136 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 137 "Stride must be >= the number of elements in the result vector."); 138 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 139 140 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 141 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 142 143 // Get pointer to the start of the selected vector. Skip GEP creation, 144 // if we select vector 0. 145 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 146 VecStart = BasePtr; 147 else 148 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 149 150 // Cast elementwise vector start pointer to a pointer to a vector 151 // (EltType x NumElements)*. 152 auto *VecType = FixedVectorType::get(EltType, NumElements); 153 Type *VecPtrType = PointerType::get(VecType, AS); 154 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 155 } 156 157 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 158 /// 159 /// Currently, the lowering for each matrix intrinsic is done as follows: 160 /// 1. Propagate the shape information from intrinsics to connected 161 /// instructions. 162 /// 2. Lower instructions with shape information (assuming column-major layout). 163 /// The lowering works similarly using row-major layout. 164 /// 2.1. Get column vectors for each argument. If we already lowered the 165 /// definition of an argument, use the produced column vectors directly. 
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors,
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  /// The function being transformed.
  Function &Func;
  /// Data layout of the module containing Func.
  const DataLayout &DL;
  /// Used for cost estimates (register width, see getNumOps).
  const TargetTransformInfo &TTI;
  AliasAnalysis &AA;
  DominatorTree &DT;
  LoopInfo &LI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    /// Accumulate the counts of \p RHS into this estimate.
    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
208 class MatrixTy { 209 SmallVector<Value *, 16> Vectors; 210 211 OpInfoTy OpInfo; 212 213 bool IsColumnMajor = true; 214 215 public: 216 MatrixTy() 217 : Vectors(), 218 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 219 MatrixTy(ArrayRef<Value *> Vectors) 220 : Vectors(Vectors.begin(), Vectors.end()), 221 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 222 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 223 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 224 225 unsigned D = isColumnMajor() ? NumColumns : NumRows; 226 for (unsigned J = 0; J < D; ++J) 227 addVector(UndefValue::get(FixedVectorType::get( 228 EltTy, isColumnMajor() ? NumRows : NumColumns))); 229 } 230 231 Value *getVector(unsigned i) const { return Vectors[i]; } 232 Value *getColumn(unsigned i) const { 233 assert(isColumnMajor() && "only supported for column-major matrixes"); 234 return Vectors[i]; 235 } 236 Value *getRow(unsigned i) const { 237 assert(!isColumnMajor() && "only supported for row-major matrixes"); 238 return Vectors[i]; 239 } 240 241 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 242 243 Type *getElementType() { return getVectorTy()->getElementType(); } 244 245 unsigned getNumVectors() const { 246 if (isColumnMajor()) 247 return getNumColumns(); 248 return getNumRows(); 249 } 250 251 unsigned getNumColumns() const { 252 if (isColumnMajor()) 253 return Vectors.size(); 254 else { 255 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 256 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 257 } 258 } 259 unsigned getNumRows() const { 260 if (isColumnMajor()) { 261 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 262 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 263 } else 264 return Vectors.size(); 265 } 266 267 void addVector(Value *V) { Vectors.push_back(V); } 268 VectorType *getColumnTy() { 269 assert(isColumnMajor() && "only supported for 
column-major matrixes"); 270 return getVectorTy(); 271 } 272 273 VectorType *getVectorTy() { 274 return cast<VectorType>(Vectors[0]->getType()); 275 } 276 277 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 278 assert(isColumnMajor() && 279 "columns() only supported for column-major matrixes"); 280 return make_range(Vectors.begin(), Vectors.end()); 281 } 282 283 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 284 return make_range(Vectors.begin(), Vectors.end()); 285 } 286 287 /// Embed the vectors of the matrix into a flat vector by concatenating 288 /// them. 289 Value *embedInVector(IRBuilder<> &Builder) const { 290 return Vectors.size() == 1 ? Vectors[0] 291 : concatenateVectors(Builder, Vectors); 292 } 293 294 MatrixTy &addNumLoads(unsigned N) { 295 OpInfo.NumLoads += N; 296 return *this; 297 } 298 299 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 300 301 MatrixTy &addNumStores(unsigned N) { 302 OpInfo.NumStores += N; 303 return *this; 304 } 305 306 MatrixTy &addNumComputeOps(unsigned N) { 307 OpInfo.NumComputeOps += N; 308 return *this; 309 } 310 311 unsigned getNumStores() const { return OpInfo.NumStores; } 312 unsigned getNumLoads() const { return OpInfo.NumLoads; } 313 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 314 315 const OpInfoTy &getOpInfo() const { return OpInfo; } 316 317 bool isColumnMajor() const { return IsColumnMajor; } 318 319 unsigned getStride() const { 320 if (isColumnMajor()) 321 return getNumRows(); 322 return getNumColumns(); 323 } 324 325 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 326 /// matrix is column-major, the result vector is extracted from a column 327 /// vector, otherwise from a row vector. 328 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 329 IRBuilder<> &Builder) const { 330 Value *Vec = isColumnMajor() ? 
getColumn(J) : getRow(I); 331 Value *Undef = UndefValue::get(Vec->getType()); 332 return Builder.CreateShuffleVector( 333 Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 334 "block"); 335 } 336 }; 337 338 struct ShapeInfo { 339 unsigned NumRows; 340 unsigned NumColumns; 341 342 bool IsColumnMajor; 343 344 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 345 : NumRows(NumRows), NumColumns(NumColumns), 346 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 347 348 ShapeInfo(Value *NumRows, Value *NumColumns) 349 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 350 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 351 352 bool operator==(const ShapeInfo &other) { 353 return NumRows == other.NumRows && NumColumns == other.NumColumns; 354 } 355 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 356 357 /// Returns true if shape-information is defined, meaning both dimensions 358 /// are != 0. 359 operator bool() const { 360 assert(NumRows == 0 || NumColumns != 0); 361 return NumRows != 0; 362 } 363 364 unsigned getStride() const { 365 if (IsColumnMajor) 366 return NumRows; 367 return NumColumns; 368 } 369 370 unsigned getNumVectors() const { 371 if (IsColumnMajor) 372 return NumColumns; 373 return NumRows; 374 } 375 }; 376 377 /// Maps instructions to their shape information. The shape information 378 /// describes the shape to be used while lowering. This matches the shape of 379 /// the result value of the instruction, with the only exceptions being store 380 /// instructions and the matrix_column_major_store intrinsics. For those, the 381 /// shape information indicates that those instructions should be lowered 382 /// using shape information as well. 383 DenseMap<Value *, ShapeInfo> ShapeMap; 384 385 /// List of instructions to remove. 
  /// users of a lowered instruction, if shape information is available and
  /// those need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  /// Return the estimated number of vector ops required for an operation on
  /// the vector type \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    // Divide the total bit width by the target's register width, rounding up.
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise
  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
  /// into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    // Extract one stride-sized chunk per vector of the result matrix.
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << " not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if \p V is an instruction whose result shape matches the
  /// shape of its matrix operands (element-wise ops), or is not an
  /// instruction at all.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    // Non-intrinsic instructions: element-wise ops, loads and stores.
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  ///
  /// Returns the list of instructions for which new shape information was
  /// set, to seed the next round of backward propagation.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes.  Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        // MxN * NxK result has shape MxK.
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instruction for which we already know the shape.
  ///
  /// Returns the users of newly shape-annotated instructions, to seed the
  /// next round of forward propagation.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape.  Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      // Remember where the worklist ended, so newly discovered instructions
      // (appended below) can be identified after processing V.
      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do.  We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// Run the lowering on Func: propagate shapes, fuse multiplies, lower the
  /// remaining shape-annotated instructions and erase the now-dead originals.
  /// Returns true if any change was made.
  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_column_major_load:
          case Intrinsic::matrix_column_major_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates for
    // fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    // Erase in reverse order so users are removed before their definitions.
    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls. Returns true if \p Inst was a matrix intrinsic
  /// that got lowered.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    // Unknown stride: fall back to the alignment guaranteed by the element
    // size alone.
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    MatrixTy Result;
    // Emit one load per column (or row, for row-major layouts).
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                     Shape.getStride(), VType->getElementType(),
                                     Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    // Element offset of the tile start within the outer matrix:
    // J * stride + I (column-major addressing).
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    // Load the tile using the outer matrix's stride, so the tile's vectors
    // are gathered from the right offsets.
    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
  /// MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    // Element offset of the tile start within the outer matrix:
    // J * stride + I (column-major addressing).
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    // Store the tile using the outer matrix's stride, so the tile's vectors
    // are scattered to the right offsets.
    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors. Returns a MatrixTy carrying only the store-count estimate.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    // Emit one store per column (or row, for row-major layouts).
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  /// Lower a store of \p Matrix (with shape \p Shape) to \p Ptr, using
  /// \p Stride between the matrix's vectors, and record the lowering for
  /// \p Inst.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic store a matrix back memory using a stride between columns.
  /// Arguments: (matrix, ptr, stride, volatile-flag, rows, columns).
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col by padding it with undef
    // lanes at the end; shufflevector requires both inputs to have the same
    // element count.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(
        Block, Undef,
        createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // Build a blend mask: lanes [0, I) and [I + BlockNumElts, NumElts) come
    // from Col, lanes [I, I + BlockNumElts) come from the widened Block
    // (whose lanes start at index NumElts in shuffle numbering).
    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  /// Emit \p Sum + \p A * \p B, or just \p A * \p B when \p Sum is null.
  ///
  /// For floating-point operands with \p AllowContraction set, emits a single
  /// llvm.fmuladd call and lets the backend decide whether fusing is
  /// profitable. \p NumComputeOps is incremented by the number of vector ops
  /// emitted (the extra increments in the non-contracted paths account for
  /// the separate mul and add instructions).
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      // Mul + add are two separate ops; count the second one.
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    // Integer path: mul + add, count the second op.
    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    // Flatten lazily: only materialize the flat vector if some user has no
    // shape information. Iterate with a manually advanced iterator since
    // U.set() modifies the use list.
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
1004 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1005 const MatrixTy &B, bool AllowContraction, 1006 IRBuilder<> &Builder, bool isTiled) { 1007 const unsigned VF = std::max<unsigned>( 1008 TTI.getRegisterBitWidth(true) / 1009 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1010 1U); 1011 unsigned R = Result.getNumRows(); 1012 unsigned C = Result.getNumColumns(); 1013 unsigned M = A.getNumColumns(); 1014 1015 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1016 assert(A.isColumnMajor() == B.isColumnMajor() && 1017 Result.isColumnMajor() == A.isColumnMajor() && 1018 "operands must agree on matrix layout"); 1019 unsigned NumComputeOps = 0; 1020 if (A.isColumnMajor()) { 1021 // Multiply columns from the first operand with scalars from the second 1022 // operand. Then move along the K axes and accumulate the columns. With 1023 // this the adds can be vectorized without reassociation. 1024 for (unsigned J = 0; J < C; ++J) { 1025 unsigned BlockSize = VF; 1026 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1027 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1028 1029 for (unsigned I = 0; I < R; I += BlockSize) { 1030 // Gradually lower the vectorization factor to cover the remainder. 1031 while (I + BlockSize > R) 1032 BlockSize /= 2; 1033 1034 Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 1035 : nullptr; 1036 for (unsigned K = 0; K < M; ++K) { 1037 Value *L = A.extractVector(I, K, BlockSize, Builder); 1038 Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 1039 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1040 Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, L, Splat, 1041 Result.getElementType()->isFloatingPointTy(), 1042 Builder, AllowContraction, NumComputeOps); 1043 } 1044 Result.setVector(J, 1045 insertVector(Result.getVector(J), I, Sum, Builder)); 1046 } 1047 } 1048 } else { 1049 // Multiply rows from the second operand with scalars from the first 1050 // operand. Then move along the K axes and accumulate the rows. With this 1051 // the adds can be vectorized without reassociation. 1052 for (unsigned I = 0; I < R; ++I) { 1053 unsigned BlockSize = VF; 1054 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1055 for (unsigned J = 0; J < C; J += BlockSize) { 1056 // Gradually lower the vectorization factor to cover the remainder. 1057 while (J + BlockSize > C) 1058 BlockSize /= 2; 1059 1060 Value *Sum = nullptr; 1061 for (unsigned K = 0; K < M; ++K) { 1062 Value *R = B.extractVector(K, J, BlockSize, Builder); 1063 Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 1064 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1065 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1066 IsFP, Builder, AllowContraction, NumComputeOps); 1067 } 1068 Result.setVector(I, 1069 insertVector(Result.getVector(I), J, Sum, Builder)); 1070 } 1071 } 1072 } 1073 Result.addNumComputeOps(NumComputeOps); 1074 } 1075 1076 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1077 /// copying it to a new location. This new or otherwise the original location 1078 /// is returned. 1079 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1080 CallInst *MatMul) { 1081 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1082 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1083 1084 AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); 1085 1086 // If we can statically determine noalias we're good. 
1087 if (!LdAliased) 1088 return Load->getPointerOperand(); 1089 1090 // Create code to check if the memory locations of the Load and Store 1091 // overlap and if they do, copy Load's operand to a new buffer. 1092 1093 // First, create new blocks for 2n part of the check and the copy. 1094 BasicBlock *Check0 = MatMul->getParent(); 1095 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1096 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1097 // as we adjust Check0 and Check1's branches. 1098 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1099 for (BasicBlock *Succ : successors(Check0)) 1100 DTUpdates.push_back({DT.Delete, Check0, Succ}); 1101 1102 BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1103 nullptr, "alias_cont"); 1104 BasicBlock *Copy = 1105 SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); 1106 BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1107 nullptr, "no_alias"); 1108 1109 // Check if the loaded memory location begins before the end of the store 1110 // location. If the condition holds, they might overlap, otherwise they are 1111 // guaranteed to not overlap. 1112 IRBuilder<> Builder(MatMul); 1113 Check0->getTerminator()->eraseFromParent(); 1114 Builder.SetInsertPoint(Check0); 1115 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1116 Value *StoreBegin = Builder.CreatePtrToInt( 1117 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1118 Value *StoreEnd = Builder.CreateAdd( 1119 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1120 "store.end", true, true); 1121 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1122 IntPtrTy, "load.begin"); 1123 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1124 Fusion); 1125 1126 // Check if the store begins before the end of the load location. 
If the 1127 // condition holds, they alias, otherwise they are guaranteed to not 1128 // overlap. 1129 Check1->getTerminator()->eraseFromParent(); 1130 Builder.SetInsertPoint(Check1, Check1->begin()); 1131 Value *LoadEnd = Builder.CreateAdd( 1132 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1133 "load.end", true, true); 1134 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1135 Fusion); 1136 1137 // Copy load operand to new alloca. 1138 Builder.SetInsertPoint(Copy, Copy->begin()); 1139 AllocaInst *NewLd = 1140 Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1141 Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1142 Load->getPointerOperand(), Load->getAlign(), 1143 LoadLoc.Size.getValue()); 1144 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1145 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1146 PHI->addIncoming(Load->getPointerOperand(), Check0); 1147 PHI->addIncoming(Load->getPointerOperand(), Check1); 1148 PHI->addIncoming(NewLd, Copy); 1149 1150 // Adjust DT. 
1151 DTUpdates.push_back({DT.Insert, Check0, Check1}); 1152 DTUpdates.push_back({DT.Insert, Check0, Fusion}); 1153 DTUpdates.push_back({DT.Insert, Check1, Copy}); 1154 DTUpdates.push_back({DT.Insert, Check1, Fusion}); 1155 DT.applyUpdates(DTUpdates); 1156 return PHI; 1157 } 1158 1159 bool isFusionProfitable(CallInst *MatMul) { 1160 if (ForceFusion) 1161 return true; 1162 1163 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1164 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1165 1166 const unsigned R = LShape.NumRows; 1167 const unsigned C = RShape.NumColumns; 1168 const unsigned M = LShape.NumColumns; 1169 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1170 1171 const unsigned VF = 1172 std::max<unsigned>(TTI.getRegisterBitWidth(true) / 1173 EltType->getPrimitiveSizeInBits().getFixedSize(), 1174 1U); 1175 1176 // Cost model for tiling 1177 // 1178 // For tiling to be beneficial, we need reuse either along the R or 1179 // the C axis. We vectorize along the R axis so that means at least 1180 // 3 elements. 1181 // TODO: Also consider cost of copying if operands alias. 1182 if (R <= VF && C == 1) 1183 return false; 1184 // Then we need enough elements to exceed the number of vector 1185 // registers we have. Note that this is an oversimplification since 1186 // fusing also takes some extra loads which may exceed the number of 1187 // reloads necessary. 
1188 unsigned Op0Regs = (R + VF - 1) / VF * M; 1189 unsigned Op1Regs = (M + VF - 1) / VF * C; 1190 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1191 } 1192 1193 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1194 MatrixTy Res; 1195 auto *ColumType = FixedVectorType::get(EltType, R); 1196 for (unsigned I = 0; I < C; ++I) 1197 Res.addVector(ConstantAggregateZero::get(ColumType)); 1198 return Res; 1199 } 1200 1201 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1202 StoreInst *Store, 1203 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1204 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1205 "Tiling only supported for column-major matrixes at the moment!"); 1206 if (!isFusionProfitable(MatMul)) 1207 return; 1208 1209 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1210 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1211 1212 const unsigned R = LShape.NumRows; 1213 const unsigned C = RShape.NumColumns; 1214 const unsigned M = LShape.NumColumns; 1215 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1216 1217 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1218 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1219 Value *CPtr = Store->getPointerOperand(); 1220 1221 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1222 MatMul->hasAllowContract()); 1223 IRBuilder<> Builder(Store); 1224 for (unsigned J = 0; J < C; J += TileSize) 1225 for (unsigned I = 0; I < R; I += TileSize) { 1226 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1227 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1228 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1229 1230 for (unsigned K = 0; K < M; K += TileSize) { 1231 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1232 MatrixTy A = 1233 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1234 LShape, 
Builder.getInt64(I), Builder.getInt64(K), 1235 {TileR, TileM}, EltType, Builder); 1236 MatrixTy B = 1237 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1238 RShape, Builder.getInt64(K), Builder.getInt64(J), 1239 {TileM, TileC}, EltType, Builder); 1240 emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); 1241 } 1242 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1243 Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); 1244 } 1245 1246 // Mark eliminated instructions as fused and remove them. 1247 FusedInsts.insert(Store); 1248 FusedInsts.insert(MatMul); 1249 Store->eraseFromParent(); 1250 MatMul->eraseFromParent(); 1251 if (LoadOp0->hasNUses(0)) { 1252 FusedInsts.insert(LoadOp0); 1253 LoadOp0->eraseFromParent(); 1254 } 1255 if (LoadOp1->hasNUses(0)) { 1256 FusedInsts.insert(LoadOp1); 1257 LoadOp1->eraseFromParent(); 1258 } 1259 } 1260 1261 /// Try to lower matrix multiply chains by fusing operations. 1262 /// 1263 /// Currently we only lower {ld, ld} -> matmul -> st chains. 1264 // 1265 /// No need to return a MatrixTy object for the result of the operation, since 1266 /// the single store user will be lowered as part of this. Instructions that 1267 /// are completely eliminated by fusion are added to \p FusedInsts. 1268 void LowerMatrixMultiplyFused(CallInst *MatMul, 1269 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1270 if (!FuseMatrix || !MatMul->hasOneUse() || 1271 MatrixLayout != MatrixLayoutTy::ColumnMajor) 1272 return; 1273 1274 auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 1275 auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 1276 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1277 if (LoadOp0 && LoadOp1 && Store) { 1278 // The store address must dominate the MatMul instruction, otherwise 1279 // we create invalid IR. 1280 // FIXME: See if we can hoist the store address computation. 
1281 auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1282 if (AddrI && (!DT.dominates(AddrI, MatMul))) 1283 return; 1284 1285 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1286 return; 1287 } 1288 } 1289 1290 /// Lowers llvm.matrix.multiply. 1291 void LowerMultiply(CallInst *MatMul) { 1292 IRBuilder<> Builder(MatMul); 1293 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1294 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1295 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1296 1297 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1298 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1299 1300 const unsigned R = LShape.NumRows; 1301 const unsigned C = RShape.NumColumns; 1302 assert(LShape.NumColumns == RShape.NumRows); 1303 1304 // Initialize the output 1305 MatrixTy Result(R, C, EltType); 1306 1307 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1308 MatMul->hasAllowContract()); 1309 emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1310 finalizeLowering(MatMul, Result, Builder); 1311 } 1312 1313 /// Lowers llvm.matrix.transpose. 1314 void LowerTranspose(CallInst *Inst) { 1315 MatrixTy Result; 1316 IRBuilder<> Builder(Inst); 1317 Value *InputVal = Inst->getArgOperand(0); 1318 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1319 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1320 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1321 1322 const unsigned NewNumVecs = 1323 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 1324 const unsigned NewNumElts = 1325 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1326 1327 for (unsigned I = 0; I < NewNumVecs; ++I) { 1328 // Build a single result vector. First initialize it. 
1329 Value *ResultVector = UndefValue::get( 1330 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1331 // Go through the old elements and insert it into the resulting vector. 1332 for (auto J : enumerate(InputMatrix.vectors())) { 1333 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1334 // Row and column indices are transposed. 1335 ResultVector = 1336 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1337 } 1338 Result.addVector(ResultVector); 1339 } 1340 1341 // TODO: Improve estimate of operations needed for transposes. Currently we 1342 // just count the insertelement/extractelement instructions, but do not 1343 // account for later simplifications/combines. 1344 finalizeLowering( 1345 Inst, 1346 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1347 Builder); 1348 } 1349 1350 /// Lower load instructions, if shape information is available. 1351 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1352 auto I = ShapeMap.find(Inst); 1353 if (I == ShapeMap.end()) 1354 return false; 1355 1356 LowerLoad(Inst, Ptr, Inst->getAlign(), 1357 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1358 I->second); 1359 return true; 1360 } 1361 1362 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1363 IRBuilder<> &Builder) { 1364 auto I = ShapeMap.find(StoredVal); 1365 if (I == ShapeMap.end()) 1366 return false; 1367 1368 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 1369 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1370 I->second); 1371 return true; 1372 } 1373 1374 /// Lower binary operators, if shape information is available. 
  /// Lower a binary operator with matrix shape information, applying the op
  /// vector-by-vector to the two (equally shaped) operands.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    // Apply the operator to each row/column vector pair.
    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linarized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    // Soft line-length limit; once exceeded, the next maybeIndent() breaks.
    unsigned LengthToBreak = 100;
    // Backing storage for the output and the stream that writes into it.
    std::string Str;
    raw_string_ostream Stream;
    // Length of the current output line, used for wrapping decisions.
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    /// Emit \p N spaces and account for them in LineLength.
    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    /// Start a fresh output line.
    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    /// Break the line if it grew past the limit, then indent a fresh line.
    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    /// Chase pointer operands through loads to find the underlying object
    /// (e.g. the alloca or global a loaded pointer originates from).
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS, otherwise print "unknown".
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        // NOTE(review): the cast result is used unchecked below; this relies
        // on every function named "llvm.matrix*" being an intrinsic call.
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    /// Number of trailing arguments of \p CI that encode shape information
    /// (and thus should be skipped when printing operands).
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to an
    /// (function) external address or a stack address, for other values we
    /// either print the constant or "scalar"/"matrix" for other values.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        // NOTE(review): this opens a "(" per sharing leaf but only the single
        // generic ")" below closes it — output parens may be unbalanced when
        // an expression is shared with more than one other leaf; confirm
        // whether that is intentional.
        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        // Skip the trailing shape arguments when printing operands.
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    /// Flush the stream and return the linearized expression string.
    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearizied version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    // Results of the lowering, used to find matrix instructions and their
    // operation counts.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
      return;
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      // An expression belonging to exactly one leaf is exclusive; otherwise
      // its ops are attributed to the shared bucket.
      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    /// Emit one optimization remark per matrix-expression leaf, with op counts
    /// and a linearized rendering of the expression.
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containting subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          // Walk up the inlined-at chain to attribute the remark to the
          // location inside this subprogram.
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    /// Render the expression rooted at leaf \p L as a string (see
    /// ExprLinearizer).
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    // Only the CFG-level analyses survive; instructions were rewritten.
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI =
getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1893 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 1894 bool C = LMT.Visit(); 1895 return C; 1896 } 1897 1898 void getAnalysisUsage(AnalysisUsage &AU) const override { 1899 AU.addRequired<TargetTransformInfoWrapperPass>(); 1900 AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 1901 AU.addRequired<AAResultsWrapperPass>(); 1902 AU.addRequired<DominatorTreeWrapperPass>(); 1903 AU.addPreserved<DominatorTreeWrapperPass>(); 1904 AU.addRequired<LoopInfoWrapperPass>(); 1905 AU.addPreserved<LoopInfoWrapperPass>(); 1906 } 1907 }; 1908 } // namespace 1909 1910 static const char pass_name[] = "Lower the matrix intrinsics"; 1911 char LowerMatrixIntrinsicsLegacyPass::ID = 0; 1912 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 1913 false, false) 1914 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 1915 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 1916 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 1917 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 1918 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 1919 false, false) 1920 1921 Pass *llvm::createLowerMatrixIntrinsicsPass() { 1922 return new LowerMatrixIntrinsicsLegacyPass(); 1923 } 1924