//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like the one below:
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2.
Lower instructions with shape information (assuming column-major layout). 164 /// The lowering works similarly using row-major layout. 165 /// 2.1. Get column vectors for each argument. If we already lowered the 166 /// definition of an argument, use the produced column vectors directly. 167 /// If not, split the operand vector containing an embedded matrix into 168 /// a set of column vectors, 169 /// 2.2. Lower the instruction in terms of column major operations, which 170 /// yields a set of column vectors containing result matrix. Note that we 171 /// lower all instructions that have shape information. Besides the 172 /// intrinsics, this includes stores for example. 173 /// 2.3. Update uses of the lowered instruction. If we have shape information 174 /// for a user, there is nothing to do, as we will look up the result 175 /// column matrix when lowering the user. For other uses, we embed the 176 /// result matrix in a flat vector and update the use. 177 /// 2.4. Cache the result column matrix for the instruction we lowered 178 /// 3. After we lowered all instructions in a function, remove the now 179 /// obsolete instructions. 180 /// 181 class LowerMatrixIntrinsics { 182 Function &Func; 183 const DataLayout &DL; 184 const TargetTransformInfo &TTI; 185 AliasAnalysis &AA; 186 DominatorTree &DT; 187 LoopInfo &LI; 188 OptimizationRemarkEmitter &ORE; 189 190 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 191 struct OpInfoTy { 192 /// Number of stores emitted to generate this matrix. 193 unsigned NumStores = 0; 194 /// Number of loads emitted to generate this matrix. 195 unsigned NumLoads = 0; 196 /// Number of compute operations emitted to generate this matrix. 197 unsigned NumComputeOps = 0; 198 199 OpInfoTy &operator+=(const OpInfoTy &RHS) { 200 NumStores += RHS.NumStores; 201 NumLoads += RHS.NumLoads; 202 NumComputeOps += RHS.NumComputeOps; 203 return *this; 204 } 205 }; 206 207 /// Wrapper class representing a matrix as a set of vectors, either in row or 208 /// column major layout. All vectors must have the same vector type. 209 class MatrixTy { 210 SmallVector<Value *, 16> Vectors; 211 212 OpInfoTy OpInfo; 213 214 bool IsColumnMajor = true; 215 216 public: 217 MatrixTy() 218 : Vectors(), 219 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 220 MatrixTy(ArrayRef<Value *> Vectors) 221 : Vectors(Vectors.begin(), Vectors.end()), 222 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 223 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 224 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 225 226 unsigned D = isColumnMajor() ? NumColumns : NumRows; 227 for (unsigned J = 0; J < D; ++J) 228 addVector(UndefValue::get(FixedVectorType::get( 229 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 230 } 231 232 Value *getVector(unsigned i) const { return Vectors[i]; } 233 Value *getColumn(unsigned i) const { 234 assert(isColumnMajor() && "only supported for column-major matrixes"); 235 return Vectors[i]; 236 } 237 Value *getRow(unsigned i) const { 238 assert(!isColumnMajor() && "only supported for row-major matrixes"); 239 return Vectors[i]; 240 } 241 242 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 243 244 Type *getElementType() { return getVectorTy()->getElementType(); } 245 246 unsigned getNumVectors() const { 247 if (isColumnMajor()) 248 return getNumColumns(); 249 return getNumRows(); 250 } 251 252 unsigned getNumColumns() const { 253 if (isColumnMajor()) 254 return Vectors.size(); 255 else { 256 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 257 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 258 } 259 } 260 unsigned getNumRows() const { 261 if (isColumnMajor()) { 262 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 263 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 264 } else 265 return Vectors.size(); 266 } 267 268 void addVector(Value *V) { Vectors.push_back(V); } 269 VectorType *getColumnTy() { 270 assert(isColumnMajor() && "only supported for column-major matrixes"); 271 return getVectorTy(); 272 } 273 274 VectorType *getVectorTy() { 275 return cast<VectorType>(Vectors[0]->getType()); 276 } 277 278 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 279 assert(isColumnMajor() && 280 "columns() only supported for column-major matrixes"); 281 return make_range(Vectors.begin(), Vectors.end()); 282 } 283 284 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 285 return make_range(Vectors.begin(), Vectors.end()); 286 } 287 288 /// Embed the vectors of the matrix into a flat vector by concatenating 289 /// them. 290 Value *embedInVector(IRBuilder<> &Builder) const { 291 return Vectors.size() == 1 ? Vectors[0] 292 : concatenateVectors(Builder, Vectors); 293 } 294 295 MatrixTy &addNumLoads(unsigned N) { 296 OpInfo.NumLoads += N; 297 return *this; 298 } 299 300 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 301 302 MatrixTy &addNumStores(unsigned N) { 303 OpInfo.NumStores += N; 304 return *this; 305 } 306 307 MatrixTy &addNumComputeOps(unsigned N) { 308 OpInfo.NumComputeOps += N; 309 return *this; 310 } 311 312 unsigned getNumStores() const { return OpInfo.NumStores; } 313 unsigned getNumLoads() const { return OpInfo.NumLoads; } 314 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 315 316 const OpInfoTy &getOpInfo() const { return OpInfo; } 317 318 bool isColumnMajor() const { return IsColumnMajor; } 319 320 unsigned getStride() const { 321 if (isColumnMajor()) 322 return getNumRows(); 323 return getNumColumns(); 324 } 325 326 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 327 /// matrix is column-major, the result vector is extracted from a column 328 /// vector, otherwise from a row vector. 329 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 330 IRBuilder<> &Builder) const { 331 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 332 Value *Undef = UndefValue::get(Vec->getType()); 333 return Builder.CreateShuffleVector( 334 Vec, Undef, createSequentialMask(isColumnMajor() ? 
I : J, NumElts, 0), 335 "block"); 336 } 337 }; 338 339 struct ShapeInfo { 340 unsigned NumRows; 341 unsigned NumColumns; 342 343 bool IsColumnMajor; 344 345 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 346 : NumRows(NumRows), NumColumns(NumColumns), 347 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 348 349 ShapeInfo(Value *NumRows, Value *NumColumns) 350 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 351 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 352 353 bool operator==(const ShapeInfo &other) { 354 return NumRows == other.NumRows && NumColumns == other.NumColumns; 355 } 356 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 357 358 /// Returns true if shape-information is defined, meaning both dimensions 359 /// are != 0. 360 operator bool() const { 361 assert(NumRows == 0 || NumColumns != 0); 362 return NumRows != 0; 363 } 364 365 unsigned getStride() const { 366 if (IsColumnMajor) 367 return NumRows; 368 return NumColumns; 369 } 370 371 unsigned getNumVectors() const { 372 if (IsColumnMajor) 373 return NumColumns; 374 return NumRows; 375 } 376 }; 377 378 /// Maps instructions to their shape information. The shape information 379 /// describes the shape to be used while lowering. This matches the shape of 380 /// the result value of the instruction, with the only exceptions being store 381 /// instructions and the matrix_column_major_store intrinsics. For those, the 382 /// shape information indicates that those instructions should be lowered 383 /// using shape information as well. 384 DenseMap<Value *, ShapeInfo> ShapeMap; 385 386 /// List of instructions to remove. While lowering, we are not replacing all 387 /// users of a lowered instruction, if shape information is available and 388 /// those need to be removed after we finished lowering. 389 SmallVector<Instruction *, 16> ToRemove; 390 391 /// Map from instructions to their produced column matrix. 392 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 393 394 public: 395 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 396 AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, 397 OptimizationRemarkEmitter &ORE) 398 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 399 LI(LI), ORE(ORE) {} 400 401 unsigned getNumOps(Type *VT) { 402 assert(isa<VectorType>(VT) && "Expected vector type"); 403 return getNumOps(VT->getScalarType(), 404 cast<VectorType>(VT)->getNumElements()); 405 } 406 407 // 408 /// Return the estimated number of vector ops required for an operation on 409 /// \p VT * N. 410 unsigned getNumOps(Type *ST, unsigned N) { 411 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 412 double(TTI.getRegisterBitWidth(true))); 413 } 414 415 /// Return the set of vectors that a matrix value is lowered to. 416 /// 417 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 418 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 419 /// into vectors. 420 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 421 IRBuilder<> &Builder) { 422 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 423 assert(VType && "MatrixVal must be a vector type"); 424 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && 425 "The vector size must match the number of matrix elements"); 426 427 // Check if we lowered MatrixVal using shape information. In that case, 428 // return the existing matrix, if it matches the requested shape 429 // information. 
If there is a mis-match, embed the result in a flat 430 // vector and split it later. 431 auto Found = Inst2ColumnMatrix.find(MatrixVal); 432 if (Found != Inst2ColumnMatrix.end()) { 433 MatrixTy &M = Found->second; 434 // Return the found matrix, if its shape matches the requested shape 435 // information 436 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 437 return M; 438 439 MatrixVal = M.embedInVector(Builder); 440 } 441 442 // Otherwise split MatrixVal. 443 SmallVector<Value *, 16> SplitVecs; 444 Value *Undef = UndefValue::get(VType); 445 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); 446 MaskStart += SI.getStride()) { 447 Value *V = Builder.CreateShuffleVector( 448 MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), 449 "split"); 450 SplitVecs.push_back(V); 451 } 452 453 return {SplitVecs}; 454 } 455 456 /// If \p V already has a known shape return false. Otherwise set the shape 457 /// for instructions that support it. 458 bool setShapeInfo(Value *V, ShapeInfo Shape) { 459 assert(Shape && "Shape not set"); 460 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 461 return false; 462 463 auto SIter = ShapeMap.find(V); 464 if (SIter != ShapeMap.end()) { 465 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 466 << SIter->second.NumRows << " " 467 << SIter->second.NumColumns << " for " << *V << "\n"); 468 return false; 469 } 470 471 ShapeMap.insert({V, Shape}); 472 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 473 << " for " << *V << "\n"); 474 return true; 475 } 476 477 bool isUniformShape(Value *V) { 478 Instruction *I = dyn_cast<Instruction>(V); 479 if (!I) 480 return true; 481 482 switch (I->getOpcode()) { 483 case Instruction::FAdd: 484 case Instruction::FSub: 485 case Instruction::FMul: // Scalar multiply. 486 case Instruction::Add: 487 case Instruction::Mul: 488 case Instruction::Sub: 489 return true; 490 default: 491 return false; 492 } 493 } 494 495 /// Returns true if shape information can be used for \p V. The supported 496 /// instructions must match the instructions that can be lowered by this pass. 497 bool supportsShapeInfo(Value *V) { 498 Instruction *Inst = dyn_cast<Instruction>(V); 499 if (!Inst) 500 return false; 501 502 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 503 if (II) 504 switch (II->getIntrinsicID()) { 505 case Intrinsic::matrix_multiply: 506 case Intrinsic::matrix_transpose: 507 case Intrinsic::matrix_column_major_load: 508 case Intrinsic::matrix_column_major_store: 509 return true; 510 default: 511 return false; 512 } 513 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 514 } 515 516 /// Propagate the shape information of instructions to their users. 517 /// The work list contains instructions for which we can compute the shape, 518 /// either based on the information provided by matrix intrinsics or known 519 /// shapes of operands. 520 SmallVector<Instruction *, 32> 521 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 522 SmallVector<Instruction *, 32> NewWorkList; 523 // Pop an element for which we guaranteed to have at least one of the 524 // operand shapes. Add the shape for this and then add users to the work 525 // list. 
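    // As an illustrative sketch (shapes chosen only for exposition, intrinsic
    // type mangling omitted): given
    //   %c = call <6 x double> @llvm.matrix.multiply(<4 x double> %a,
    //                                                <6 x double> %b,
    //                                                i32 2, i32 2, i32 3)
    // the shape 2 x 3 (M x K) is recorded for %c, and users of %c without a
    // known shape are pushed onto the work list so the shape can be forwarded
    // to them in a later iteration.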
526 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 527 while (!WorkList.empty()) { 528 Instruction *Inst = WorkList.back(); 529 WorkList.pop_back(); 530 531 // New entry, set the value and insert operands 532 bool Propagate = false; 533 534 Value *MatrixA; 535 Value *MatrixB; 536 Value *M; 537 Value *N; 538 Value *K; 539 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 540 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 541 m_Value(N), m_Value(K)))) { 542 Propagate = setShapeInfo(Inst, {M, K}); 543 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 544 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 545 // Flip dimensions. 546 Propagate = setShapeInfo(Inst, {N, M}); 547 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 548 m_Value(MatrixA), m_Value(), m_Value(), 549 m_Value(), m_Value(M), m_Value(N)))) { 550 Propagate = setShapeInfo(Inst, {N, M}); 551 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 552 m_Value(), m_Value(), m_Value(), m_Value(M), 553 m_Value(N)))) { 554 Propagate = setShapeInfo(Inst, {M, N}); 555 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 556 auto OpShape = ShapeMap.find(MatrixA); 557 if (OpShape != ShapeMap.end()) 558 setShapeInfo(Inst, OpShape->second); 559 continue; 560 } else if (isUniformShape(Inst)) { 561 // Find the first operand that has a known shape and use that. 562 for (auto &Op : Inst->operands()) { 563 auto OpShape = ShapeMap.find(Op.get()); 564 if (OpShape != ShapeMap.end()) { 565 Propagate |= setShapeInfo(Inst, OpShape->second); 566 break; 567 } 568 } 569 } 570 571 if (Propagate) { 572 NewWorkList.push_back(Inst); 573 for (auto *User : Inst->users()) 574 if (ShapeMap.count(User) == 0) 575 WorkList.push_back(cast<Instruction>(User)); 576 } 577 } 578 579 return NewWorkList; 580 } 581 582 /// Propagate the shape to operands of instructions with shape information. 583 /// \p Worklist contains the instruction for which we already know the shape. 584 SmallVector<Instruction *, 32> 585 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 586 SmallVector<Instruction *, 32> NewWorkList; 587 588 auto pushInstruction = [](Value *V, 589 SmallVectorImpl<Instruction *> &WorkList) { 590 Instruction *I = dyn_cast<Instruction>(V); 591 if (I) 592 WorkList.push_back(I); 593 }; 594 // Pop an element with known shape. Traverse the operands, if their shape 595 // derives from the result shape and is unknown, add it and add them to the 596 // worklist. 597 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 598 while (!WorkList.empty()) { 599 Value *V = WorkList.back(); 600 WorkList.pop_back(); 601 602 size_t BeforeProcessingV = WorkList.size(); 603 if (!isa<Instruction>(V)) 604 continue; 605 606 Value *MatrixA; 607 Value *MatrixB; 608 Value *M; 609 Value *N; 610 Value *K; 611 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 612 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 613 m_Value(N), m_Value(K)))) { 614 if (setShapeInfo(MatrixA, {M, N})) 615 pushInstruction(MatrixA, WorkList); 616 617 if (setShapeInfo(MatrixB, {N, K})) 618 pushInstruction(MatrixB, WorkList); 619 620 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 621 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 622 // Flip dimensions. 
623 if (setShapeInfo(MatrixA, {M, N})) 624 pushInstruction(MatrixA, WorkList); 625 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 626 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 627 m_Value(M), m_Value(N)))) { 628 if (setShapeInfo(MatrixA, {M, N})) { 629 pushInstruction(MatrixA, WorkList); 630 } 631 } else if (isa<LoadInst>(V) || 632 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 633 // Nothing to do, no matrix input. 634 } else if (isa<StoreInst>(V)) { 635 // Nothing to do. We forward-propagated to this so we would just 636 // backward propagate to an instruction with an already known shape. 637 } else if (isUniformShape(V)) { 638 // Propagate to all operands. 639 ShapeInfo Shape = ShapeMap[V]; 640 for (Use &U : cast<Instruction>(V)->operands()) { 641 if (setShapeInfo(U.get(), Shape)) 642 pushInstruction(U.get(), WorkList); 643 } 644 } 645 // After we discovered new shape info for new instructions in the 646 // worklist, we use their users as seeds for the next round of forward 647 // propagation. 648 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 649 for (User *U : WorkList[I]->users()) 650 if (isa<Instruction>(U) && V != U) 651 NewWorkList.push_back(cast<Instruction>(U)); 652 } 653 return NewWorkList; 654 } 655 656 bool Visit() { 657 if (EnableShapePropagation) { 658 SmallVector<Instruction *, 32> WorkList; 659 660 // Initially only the shape of matrix intrinsics is known. 661 // Initialize the work list with ops carrying shape information. 662 for (BasicBlock &BB : Func) 663 for (Instruction &Inst : BB) { 664 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 665 if (!II) 666 continue; 667 668 switch (II->getIntrinsicID()) { 669 case Intrinsic::matrix_multiply: 670 case Intrinsic::matrix_transpose: 671 case Intrinsic::matrix_column_major_load: 672 case Intrinsic::matrix_column_major_store: 673 WorkList.push_back(&Inst); 674 break; 675 default: 676 break; 677 } 678 } 679 // Propagate shapes until nothing changes any longer. 680 while (!WorkList.empty()) { 681 WorkList = propagateShapeForward(WorkList); 682 WorkList = propagateShapeBackward(WorkList); 683 } 684 } 685 686 bool Changed = false; 687 SmallVector<CallInst *, 16> MaybeFusableInsts; 688 SmallVector<Instruction *, 16> MatrixInsts; 689 690 // First, collect all instructions with shape information and candidates for 691 // fusion (currently only matrix multiplies). 692 ReversePostOrderTraversal<Function *> RPOT(&Func); 693 for (auto *BB : RPOT) 694 for (Instruction &I : *BB) { 695 if (ShapeMap.find(&I) == ShapeMap.end()) 696 continue; 697 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 698 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 699 MatrixInsts.push_back(&I); 700 } 701 702 // Second, try to fuse candidates. 703 SmallPtrSet<Instruction *, 16> FusedInsts; 704 for (CallInst *CI : MaybeFusableInsts) 705 LowerMatrixMultiplyFused(CI, FusedInsts); 706 Changed = !FusedInsts.empty(); 707 708 // Third, lower remaining instructions with shape information. 
709 for (Instruction *Inst : MatrixInsts) { 710 if (FusedInsts.count(Inst)) 711 continue; 712 713 IRBuilder<> Builder(Inst); 714 715 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 716 Changed |= VisitCallInst(CInst); 717 718 Value *Op1; 719 Value *Op2; 720 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 721 Changed |= VisitBinaryOperator(BinOp); 722 if (match(Inst, m_Load(m_Value(Op1)))) 723 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 724 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 725 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 726 } 727 728 RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); 729 RemarkGen.emitRemarks(); 730 731 for (Instruction *Inst : reverse(ToRemove)) 732 Inst->eraseFromParent(); 733 734 return Changed; 735 } 736 737 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 738 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 739 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 740 Type *EltPtrType = PointerType::get(EltType, AS); 741 return Builder.CreatePointerCast(BasePtr, EltPtrType); 742 } 743 744 /// Replace intrinsic calls 745 bool VisitCallInst(CallInst *Inst) { 746 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 747 return false; 748 749 switch (Inst->getCalledFunction()->getIntrinsicID()) { 750 case Intrinsic::matrix_multiply: 751 LowerMultiply(Inst); 752 break; 753 case Intrinsic::matrix_transpose: 754 LowerTranspose(Inst); 755 break; 756 case Intrinsic::matrix_column_major_load: 757 LowerColumnMajorLoad(Inst); 758 break; 759 case Intrinsic::matrix_column_major_store: 760 LowerColumnMajorStore(Inst); 761 break; 762 default: 763 return false; 764 } 765 return true; 766 } 767 768 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 769 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 770 /// ConstantInt, reduce the initial alignment based on the byte offset. For 771 /// non-ConstantInt strides, return the common alignment of the initial 772 /// alignment and the element size in bytes. 773 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 774 MaybeAlign A) const { 775 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 776 if (Idx == 0) 777 return InitialAlign; 778 779 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 780 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 781 uint64_t StrideInBytes = 782 ConstStride->getZExtValue() * ElementSizeInBits / 8; 783 return commonAlignment(InitialAlign, Idx * StrideInBytes); 784 } 785 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 786 } 787 788 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 789 /// vectors. 
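  ///
  /// As a rough example (numbers purely illustrative): loading a 2 x 3 double
  /// matrix in column-major layout with a stride of 5 emits three
  /// <2 x double> loads, reading the columns at element offsets 0, 5 and 10
  /// from \p Ptr.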
790 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 791 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 792 auto VType = cast<VectorType>(Ty); 793 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 794 MatrixTy Result; 795 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 796 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, 797 Shape.getStride(), VType->getElementType(), 798 Builder); 799 Value *Vector = Builder.CreateAlignedLoad( 800 GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), 801 IsVolatile, "col.load"); 802 803 Result.addVector(Vector); 804 } 805 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 806 Result.getNumVectors()); 807 } 808 809 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 810 /// starting at \p MatrixPtr[I][J]. 811 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 812 ShapeInfo MatrixShape, Value *I, Value *J, 813 ShapeInfo ResultShape, Type *EltTy, 814 IRBuilder<> &Builder) { 815 816 Value *Offset = Builder.CreateAdd( 817 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 818 819 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 820 Value *EltPtr = 821 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 822 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 823 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 824 ResultShape.NumColumns); 825 Type *TilePtrTy = PointerType::get(TileTy, AS); 826 Value *TilePtr = 827 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 828 829 return loadMatrix(TileTy, TilePtr, Align, 830 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 831 ResultShape, Builder); 832 } 833 834 /// Lower a load instruction with shape information. 835 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 836 bool IsVolatile, ShapeInfo Shape) { 837 IRBuilder<> Builder(Inst); 838 finalizeLowering(Inst, 839 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 840 Shape, Builder), 841 Builder); 842 } 843 844 /// Lowers llvm.matrix.column.major.load. 845 /// 846 /// The intrinsic loads a matrix from memory using a stride between columns. 847 void LowerColumnMajorLoad(CallInst *Inst) { 848 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 849 "Intrinsic only supports column-major layout!"); 850 Value *Ptr = Inst->getArgOperand(0); 851 Value *Stride = Inst->getArgOperand(1); 852 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 853 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 854 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 855 } 856 857 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 858 /// MatrixPtr[I][J]. 
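  ///
  /// The element offset of the tile start is computed as J * Stride + I, where
  /// for column-major layout the stride is the number of rows of the outer
  /// matrix. For example (illustrative only), a 2x2 tile at (I, J) = (1, 2) of
  /// a column-major 4x4 matrix starts at element offset 2 * 4 + 1 = 9 from
  /// \p MatrixPtr.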
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
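  ///
  /// A sketch of the expected call, matching the argument order consumed
  /// below (intrinsic name mangling and attributes omitted; concrete types
  /// purely illustrative):
  ///   call void @llvm.matrix.column.major.store(<6 x double> %m,
  ///                                             double* %dst, i64 %stride,
  ///                                             i1 false, i32 2, i32 3)
  /// i.e. the matrix value, the destination pointer, the stride, the volatile
  /// flag, and finally the number of rows and columns.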
914 void LowerColumnMajorStore(CallInst *Inst) { 915 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 916 "Intrinsic only supports column-major layout!"); 917 Value *Matrix = Inst->getArgOperand(0); 918 Value *Ptr = Inst->getArgOperand(1); 919 Value *Stride = Inst->getArgOperand(2); 920 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 921 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 922 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 923 } 924 925 // Set elements I..I+NumElts-1 to Block 926 Value *insertVector(Value *Col, unsigned I, Value *Block, 927 IRBuilder<> &Builder) { 928 929 // First, bring Block to the same size as Col 930 unsigned BlockNumElts = 931 cast<VectorType>(Block->getType())->getNumElements(); 932 unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements(); 933 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 934 935 Value *Undef = UndefValue::get(Block->getType()); 936 Block = Builder.CreateShuffleVector( 937 Block, Undef, 938 createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 939 940 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 941 // 8, 4, 5, 6 942 SmallVector<int, 16> Mask; 943 unsigned i; 944 for (i = 0; i < I; i++) 945 Mask.push_back(i); 946 947 unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements(); 948 for (; i < I + BlockNumElts; i++) 949 Mask.push_back(i - I + VecNumElts); 950 951 for (; i < VecNumElts; i++) 952 Mask.push_back(i); 953 954 return Builder.CreateShuffleVector(Col, Block, Mask); 955 } 956 957 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 958 IRBuilder<> &Builder, bool AllowContraction, 959 unsigned &NumComputeOps) { 960 NumComputeOps += getNumOps(A->getType()); 961 if (!Sum) 962 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 963 964 if (UseFPOp) { 965 if (AllowContraction) { 966 // Use fmuladd for floating point operations and let the backend decide 967 // if that's profitable. 968 Function *FMulAdd = Intrinsic::getDeclaration( 969 Func.getParent(), Intrinsic::fmuladd, A->getType()); 970 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 971 } 972 NumComputeOps += getNumOps(A->getType()); 973 Value *Mul = Builder.CreateFMul(A, B); 974 return Builder.CreateFAdd(Sum, Mul); 975 } 976 977 NumComputeOps += getNumOps(A->getType()); 978 Value *Mul = Builder.CreateMul(A, B); 979 return Builder.CreateAdd(Sum, Mul); 980 } 981 982 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 983 /// users with shape information, there's nothing to do: the will use the 984 /// cached value when they are lowered. For other users, \p Matrix is 985 /// flattened and the uses are updated to use it. Also marks \p Inst for 986 /// deletion. 987 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 988 IRBuilder<> &Builder) { 989 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 990 991 ToRemove.push_back(Inst); 992 Value *Flattened = nullptr; 993 for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { 994 Use &U = *I++; 995 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 996 if (!Flattened) 997 Flattened = Matrix.embedInVector(Builder); 998 U.set(Flattened); 999 } 1000 } 1001 } 1002 1003 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1004 /// addition. 
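  ///
  /// A rough sketch of the column-major path for a 2x2 * 2x2 product (shapes
  /// purely illustrative): for each result column J, every column K of \p A is
  /// multiplied by a splat of the scalar B[K][J] and added to the running sum,
  /// e.g.
  ///   Result.col(0) = A.col(0) * splat(B[0][0]) + A.col(1) * splat(B[1][0])
  /// so the adds stay vectorized along the rows without reassociation.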
1005 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1006 const MatrixTy &B, bool AllowContraction, 1007 IRBuilder<> &Builder, bool isTiled) { 1008 const unsigned VF = std::max<unsigned>( 1009 TTI.getRegisterBitWidth(true) / 1010 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1011 1U); 1012 unsigned R = Result.getNumRows(); 1013 unsigned C = Result.getNumColumns(); 1014 unsigned M = A.getNumColumns(); 1015 1016 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1017 assert(A.isColumnMajor() == B.isColumnMajor() && 1018 Result.isColumnMajor() == A.isColumnMajor() && 1019 "operands must agree on matrix layout"); 1020 unsigned NumComputeOps = 0; 1021 if (A.isColumnMajor()) { 1022 // Multiply columns from the first operand with scalars from the second 1023 // operand. Then move along the K axes and accumulate the columns. With 1024 // this the adds can be vectorized without reassociation. 1025 for (unsigned J = 0; J < C; ++J) { 1026 unsigned BlockSize = VF; 1027 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1028 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1029 1030 for (unsigned I = 0; I < R; I += BlockSize) { 1031 // Gradually lower the vectorization factor to cover the remainder. 1032 while (I + BlockSize > R) 1033 BlockSize /= 2; 1034 1035 Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 1036 : nullptr; 1037 for (unsigned K = 0; K < M; ++K) { 1038 Value *L = A.extractVector(I, K, BlockSize, Builder); 1039 Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 1040 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1041 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1042 Result.getElementType()->isFloatingPointTy(), 1043 Builder, AllowContraction, NumComputeOps); 1044 } 1045 Result.setVector(J, 1046 insertVector(Result.getVector(J), I, Sum, Builder)); 1047 } 1048 } 1049 } else { 1050 // Multiply rows from the second operand with scalars from the first 1051 // operand. Then move along the K axes and accumulate the rows. With this 1052 // the adds can be vectorized without reassociation. 1053 for (unsigned I = 0; I < R; ++I) { 1054 unsigned BlockSize = VF; 1055 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1056 for (unsigned J = 0; J < C; J += BlockSize) { 1057 // Gradually lower the vectorization factor to cover the remainder. 1058 while (J + BlockSize > C) 1059 BlockSize /= 2; 1060 1061 Value *Sum = nullptr; 1062 for (unsigned K = 0; K < M; ++K) { 1063 Value *R = B.extractVector(K, J, BlockSize, Builder); 1064 Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 1065 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1066 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1067 IsFP, Builder, AllowContraction, NumComputeOps); 1068 } 1069 Result.setVector(I, 1070 insertVector(Result.getVector(I), J, Sum, Builder)); 1071 } 1072 } 1073 } 1074 Result.addNumComputeOps(NumComputeOps); 1075 } 1076 1077 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1078 /// copying it to a new location. This new or otherwise the original location 1079 /// is returned. 
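  ///
  /// A sketch of the control flow emitted when the locations may alias, using
  /// the variable and block names from the code below:
  ///   Check0:     if (load.begin < store.end) br alias_cont else no_alias
  ///   alias_cont: if (store.begin < load.end) br copy else no_alias
  ///   copy:       copy the loaded memory into a fresh alloca
  ///   no_alias:   phi [orig ptr, Check0], [orig ptr, alias_cont],
  ///                   [new alloca, copy]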
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);

    // If we can statically determine noalias we're good.
    if (!LdAliased)
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and, if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the two parts of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT.Delete, Check0, Succ});

    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    AllocaInst *NewLd =
        Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
    Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
                         Load->getPointerOperand(), Load->getAlign(),
                         LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(NewLd, Copy);

    // Adjust DT.
1152 DTUpdates.push_back({DT.Insert, Check0, Check1}); 1153 DTUpdates.push_back({DT.Insert, Check0, Fusion}); 1154 DTUpdates.push_back({DT.Insert, Check1, Copy}); 1155 DTUpdates.push_back({DT.Insert, Check1, Fusion}); 1156 DT.applyUpdates(DTUpdates); 1157 return PHI; 1158 } 1159 1160 bool isFusionProfitable(CallInst *MatMul) { 1161 if (ForceFusion) 1162 return true; 1163 1164 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1165 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1166 1167 const unsigned R = LShape.NumRows; 1168 const unsigned C = RShape.NumColumns; 1169 const unsigned M = LShape.NumColumns; 1170 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1171 1172 const unsigned VF = 1173 std::max<unsigned>(TTI.getRegisterBitWidth(true) / 1174 EltType->getPrimitiveSizeInBits().getFixedSize(), 1175 1U); 1176 1177 // Cost model for tiling 1178 // 1179 // For tiling to be beneficial, we need reuse either along the R or 1180 // the C axis. We vectorize along the R axis so that means at least 1181 // 3 elements. 1182 // TODO: Also consider cost of copying if operands alias. 1183 if (R <= VF && C == 1) 1184 return false; 1185 // Then we need enough elements to exceed the number of vector 1186 // registers we have. Note that this is an oversimplification since 1187 // fusing also takes some extra loads which may exceed the number of 1188 // reloads necessary. 1189 unsigned Op0Regs = (R + VF - 1) / VF * M; 1190 unsigned Op1Regs = (M + VF - 1) / VF * C; 1191 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1192 } 1193 1194 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1195 MatrixTy Res; 1196 auto *ColumType = FixedVectorType::get(EltType, R); 1197 for (unsigned I = 0; I < C; ++I) 1198 Res.addVector(ConstantAggregateZero::get(ColumType)); 1199 return Res; 1200 } 1201 1202 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1203 StoreInst *Store, 1204 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1205 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1206 "Tiling only supported for column-major matrixes at the moment!"); 1207 if (!isFusionProfitable(MatMul)) 1208 return; 1209 1210 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1211 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1212 1213 const unsigned R = LShape.NumRows; 1214 const unsigned C = RShape.NumColumns; 1215 const unsigned M = LShape.NumColumns; 1216 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1217 1218 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1219 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1220 Value *CPtr = Store->getPointerOperand(); 1221 1222 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1223 MatMul->hasAllowContract()); 1224 IRBuilder<> Builder(Store); 1225 for (unsigned J = 0; J < C; J += TileSize) 1226 for (unsigned I = 0; I < R; I += TileSize) { 1227 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1228 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1229 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1230 1231 for (unsigned K = 0; K < M; K += TileSize) { 1232 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1233 MatrixTy A = 1234 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1235 LShape, Builder.getInt64(I), Builder.getInt64(K), 1236 {TileR, TileM}, EltType, Builder); 1237 MatrixTy B = 1238 loadMatrix(BPtr, 
LoadOp1->getAlign(), LoadOp1->isVolatile(), 1239 RShape, Builder.getInt64(K), Builder.getInt64(J), 1240 {TileM, TileC}, EltType, Builder); 1241 emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); 1242 } 1243 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1244 Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); 1245 } 1246 1247 // Mark eliminated instructions as fused and remove them. 1248 FusedInsts.insert(Store); 1249 FusedInsts.insert(MatMul); 1250 Store->eraseFromParent(); 1251 MatMul->eraseFromParent(); 1252 if (LoadOp0->hasNUses(0)) { 1253 FusedInsts.insert(LoadOp0); 1254 LoadOp0->eraseFromParent(); 1255 } 1256 if (LoadOp1->hasNUses(0)) { 1257 FusedInsts.insert(LoadOp1); 1258 LoadOp1->eraseFromParent(); 1259 } 1260 } 1261 1262 /// Try to lower matrix multiply chains by fusing operations. 1263 /// 1264 /// Currently we only lower {ld, ld} -> matmul -> st chains. 1265 // 1266 /// No need to return a MatrixTy object for the result of the operation, since 1267 /// the single store user will be lowered as part of this. Instructions that 1268 /// are completely eliminated by fusion are added to \p FusedInsts. 1269 void LowerMatrixMultiplyFused(CallInst *MatMul, 1270 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1271 if (!FuseMatrix || !MatMul->hasOneUse() || 1272 MatrixLayout != MatrixLayoutTy::ColumnMajor) 1273 return; 1274 1275 auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 1276 auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 1277 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1278 if (LoadOp0 && LoadOp1 && Store) { 1279 // The store address must dominate the MatMul instruction, otherwise 1280 // we create invalid IR. 1281 // FIXME: See if we can hoist the store address computation. 1282 auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1283 if (AddrI && (!DT.dominates(AddrI, MatMul))) 1284 return; 1285 1286 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1287 return; 1288 } 1289 } 1290 1291 /// Lowers llvm.matrix.multiply. 1292 void LowerMultiply(CallInst *MatMul) { 1293 IRBuilder<> Builder(MatMul); 1294 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1295 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1296 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1297 1298 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1299 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1300 1301 const unsigned R = LShape.NumRows; 1302 const unsigned C = RShape.NumColumns; 1303 assert(LShape.NumColumns == RShape.NumRows); 1304 1305 // Initialize the output 1306 MatrixTy Result(R, C, EltType); 1307 1308 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1309 MatMul->hasAllowContract()); 1310 emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1311 finalizeLowering(MatMul, Result, Builder); 1312 } 1313 1314 /// Lowers llvm.matrix.transpose. 1315 void LowerTranspose(CallInst *Inst) { 1316 MatrixTy Result; 1317 IRBuilder<> Builder(Inst); 1318 Value *InputVal = Inst->getArgOperand(0); 1319 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1320 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1321 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1322 1323 const unsigned NewNumVecs = 1324 InputMatrix.isColumnMajor() ? 
ArgShape.NumRows : ArgShape.NumColumns; 1325 const unsigned NewNumElts = 1326 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1327 1328 for (unsigned I = 0; I < NewNumVecs; ++I) { 1329 // Build a single result vector. First initialize it. 1330 Value *ResultVector = UndefValue::get( 1331 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1332 // Go through the old elements and insert it into the resulting vector. 1333 for (auto J : enumerate(InputMatrix.vectors())) { 1334 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1335 // Row and column indices are transposed. 1336 ResultVector = 1337 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1338 } 1339 Result.addVector(ResultVector); 1340 } 1341 1342 // TODO: Improve estimate of operations needed for transposes. Currently we 1343 // just count the insertelement/extractelement instructions, but do not 1344 // account for later simplifications/combines. 1345 finalizeLowering( 1346 Inst, 1347 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1348 Builder); 1349 } 1350 1351 /// Lower load instructions, if shape information is available. 1352 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1353 auto I = ShapeMap.find(Inst); 1354 if (I == ShapeMap.end()) 1355 return false; 1356 1357 LowerLoad(Inst, Ptr, Inst->getAlign(), 1358 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1359 I->second); 1360 return true; 1361 } 1362 1363 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1364 IRBuilder<> &Builder) { 1365 auto I = ShapeMap.find(StoredVal); 1366 if (I == ShapeMap.end()) 1367 return false; 1368 1369 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 1370 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1371 I->second); 1372 return true; 1373 } 1374 1375 /// Lower binary operators, if shape information is available. 1376 bool VisitBinaryOperator(BinaryOperator *Inst) { 1377 auto I = ShapeMap.find(Inst); 1378 if (I == ShapeMap.end()) 1379 return false; 1380 1381 Value *Lhs = Inst->getOperand(0); 1382 Value *Rhs = Inst->getOperand(1); 1383 1384 IRBuilder<> Builder(Inst); 1385 ShapeInfo &Shape = I->second; 1386 1387 MatrixTy Result; 1388 MatrixTy A = getMatrix(Lhs, Shape, Builder); 1389 MatrixTy B = getMatrix(Rhs, Shape, Builder); 1390 assert(A.isColumnMajor() == B.isColumnMajor() && 1391 Result.isColumnMajor() == A.isColumnMajor() && 1392 "operands must agree on matrix layout"); 1393 1394 // Helper to perform binary op on vectors. 1395 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1396 switch (Inst->getOpcode()) { 1397 case Instruction::Add: 1398 return Builder.CreateAdd(LHS, RHS); 1399 case Instruction::Mul: 1400 return Builder.CreateMul(LHS, RHS); 1401 case Instruction::Sub: 1402 return Builder.CreateSub(LHS, RHS); 1403 case Instruction::FAdd: 1404 return Builder.CreateFAdd(LHS, RHS); 1405 case Instruction::FMul: 1406 return Builder.CreateFMul(LHS, RHS); 1407 case Instruction::FSub: 1408 return Builder.CreateFSub(LHS, RHS); 1409 default: 1410 llvm_unreachable("Unsupported binary operator for matrix"); 1411 } 1412 }; 1413 1414 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1415 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 1416 1417 finalizeLowering(Inst, 1418 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1419 Result.getNumVectors()), 1420 Builder); 1421 return true; 1422 } 1423 1424 /// Helper to linearize a matrix expression tree into a string. 
Currently 1425 /// matrix expressions are linarized by starting at an expression leaf and 1426 /// linearizing bottom up. 1427 struct ExprLinearizer { 1428 unsigned LengthToBreak = 100; 1429 std::string Str; 1430 raw_string_ostream Stream; 1431 unsigned LineLength = 0; 1432 const DataLayout &DL; 1433 1434 /// Mapping from instructions to matrixes. It is used to identify 1435 /// matrix instructions. 1436 const MapVector<Value *, MatrixTy> &Inst2Matrix; 1437 1438 /// Mapping from values to the leaves of all expressions that the value is 1439 /// part of. 1440 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 1441 1442 /// Set of matrix expressions in the scope of a given DISubprogram. 1443 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 1444 1445 /// Leaf node of the expression to linearize. 1446 Value *Leaf; 1447 1448 /// Used to keep track of sub-expressions that get reused while linearizing 1449 /// the expression. Re-used sub-expressions are marked as (reused). 1450 SmallPtrSet<Value *, 8> ReusedExprs; 1451 1452 ExprLinearizer(const DataLayout &DL, 1453 const MapVector<Value *, MatrixTy> &Inst2Matrix, 1454 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 1455 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1456 Value *Leaf) 1457 : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 1458 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 1459 1460 void indent(unsigned N) { 1461 LineLength += N; 1462 for (unsigned i = 0; i < N; i++) 1463 Stream << " "; 1464 } 1465 1466 void lineBreak() { 1467 Stream << "\n"; 1468 LineLength = 0; 1469 } 1470 1471 void maybeIndent(unsigned Indent) { 1472 if (LineLength >= LengthToBreak) 1473 lineBreak(); 1474 1475 if (LineLength == 0) 1476 indent(Indent); 1477 } 1478 1479 void write(StringRef S) { 1480 LineLength += S.size(); 1481 Stream << S; 1482 } 1483 1484 Value *getUnderlyingObjectThroughLoads(Value *V) { 1485 if (Value *Ptr = getPointerOperand(V)) 1486 return getUnderlyingObjectThroughLoads(Ptr); 1487 else if (V->getType()->isPointerTy()) 1488 return GetUnderlyingObject(V, DL); 1489 return V; 1490 } 1491 1492 /// Returns true if \p V is a matrix value in the given subprogram. 1493 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 1494 1495 /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 1496 /// \p SS. 1497 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 1498 auto M = Inst2Matrix.find(V); 1499 if (M == Inst2Matrix.end()) 1500 SS << "unknown"; 1501 else { 1502 SS << M->second.getNumRows(); 1503 SS << "x"; 1504 SS << M->second.getNumColumns(); 1505 } 1506 } 1507 1508 /// Write the called function name. Handles calls to llvm.matrix.* 1509 /// specially: we write the name, followed by the dimensions of the input 1510 /// matrixes, followed by the scalar type name. 
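    /// For example (an illustrative sketch, not an exhaustive list of
    /// formats), a multiply of a 2x3 by a 3x4 double matrix is written as
    ///   multiply.2x3.3x4.double
    /// i.e. the intrinsic name with the "llvm.matrix." prefix dropped,
    /// followed by the operand shapes and the element type.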
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values,
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName();
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
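      // A sub-expression reachable from more than one leaf is prefixed with
      // one marker per other leaf, e.g. (line/column values are illustrative
      // only) "shared with remark at line 5 column 3 (".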
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations by the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
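    /// For example (chain assumed for illustration), in a load -> multiply ->
    /// store chain only the store is a leaf: the load and the multiply both
    /// have matrix users within the subprogram.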
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p V. Expressions used multiple times are
    /// counted once. Limit the matrix operations to the ones in
    /// \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
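        // A typical remark roughly looks like (counts and operand names are
        // illustrative only, and line breaking may differ):
        //   Lowered with 4 stores, 8 loads, 16 compute ops
        //   store(multiply.2x2.2x2.double(load(addr %A), load(addr %B)),
        //         addr %C)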
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
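// Usage sketch (pass names assumed from the registrations above and the
// new-PM pass registry, which lives in a separate file):
//   opt -passes=lower-matrix-intrinsics -S in.ll   # new pass manager
//   opt -lower-matrix-intrinsics -S in.ll          # legacy pass manager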