//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose different numbers of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//        0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0  |v_1_1 |v_1_2 |v_1_3
//         v_2_0  |v_2_1 |v_2_2 |v_2_3
//         v_3_0  {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  Type *VecType = VectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major
///    layout). The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis &AA;
  DominatorTree &DT;
  LoopInfo &LI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy()
        : Vectors(),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(
            VectorType::get(EltTy, isColumnMajor() ?
NumRows : NumColumns))); 228 } 229 230 Value *getVector(unsigned i) const { return Vectors[i]; } 231 Value *getColumn(unsigned i) const { 232 assert(isColumnMajor() && "only supported for column-major matrixes"); 233 return Vectors[i]; 234 } 235 Value *getRow(unsigned i) const { 236 assert(!isColumnMajor() && "only supported for row-major matrixes"); 237 return Vectors[i]; 238 } 239 240 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 241 242 Type *getElementType() { return getVectorTy()->getElementType(); } 243 244 unsigned getNumVectors() const { 245 if (isColumnMajor()) 246 return getNumColumns(); 247 return getNumRows(); 248 } 249 250 unsigned getNumColumns() const { 251 if (isColumnMajor()) 252 return Vectors.size(); 253 else { 254 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 255 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 256 } 257 } 258 unsigned getNumRows() const { 259 if (isColumnMajor()) { 260 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 261 return cast<VectorType>(Vectors[0]->getType())->getNumElements(); 262 } else 263 return Vectors.size(); 264 } 265 266 void addVector(Value *V) { Vectors.push_back(V); } 267 VectorType *getColumnTy() { 268 assert(isColumnMajor() && "only supported for column-major matrixes"); 269 return getVectorTy(); 270 } 271 272 VectorType *getVectorTy() { 273 return cast<VectorType>(Vectors[0]->getType()); 274 } 275 276 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 277 assert(isColumnMajor() && 278 "columns() only supported for column-major matrixes"); 279 return make_range(Vectors.begin(), Vectors.end()); 280 } 281 282 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 283 return make_range(Vectors.begin(), Vectors.end()); 284 } 285 286 /// Embed the vectors of the matrix into a flat vector by concatenating 287 /// them. 288 Value *embedInVector(IRBuilder<> &Builder) const { 289 return Vectors.size() == 1 ? Vectors[0] 290 : concatenateVectors(Builder, Vectors); 291 } 292 293 MatrixTy &addNumLoads(unsigned N) { 294 OpInfo.NumLoads += N; 295 return *this; 296 } 297 298 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 299 300 MatrixTy &addNumStores(unsigned N) { 301 OpInfo.NumStores += N; 302 return *this; 303 } 304 305 MatrixTy &addNumComputeOps(unsigned N) { 306 OpInfo.NumComputeOps += N; 307 return *this; 308 } 309 310 unsigned getNumStores() const { return OpInfo.NumStores; } 311 unsigned getNumLoads() const { return OpInfo.NumLoads; } 312 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 313 314 const OpInfoTy &getOpInfo() const { return OpInfo; } 315 316 bool isColumnMajor() const { return IsColumnMajor; } 317 318 unsigned getStride() const { 319 if (isColumnMajor()) 320 return getNumRows(); 321 return getNumColumns(); 322 } 323 324 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 325 /// matrix is column-major, the result vector is extracted from a column 326 /// vector, otherwise from a row vector. 327 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 328 IRBuilder<> &Builder) const { 329 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 330 Value *Undef = UndefValue::get(Vec->getType()); 331 Constant *Mask = 332 createSequentialMask(Builder, isColumnMajor() ? 
I : J, NumElts, 0);
      return Builder.CreateShuffleVector(Vec, Undef, Mask, "block");
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information.
In that case, 426 // return the existing matrix, if it matches the requested shape 427 // information. If there is a mis-match, embed the result in a flat 428 // vector and split it later. 429 auto Found = Inst2ColumnMatrix.find(MatrixVal); 430 if (Found != Inst2ColumnMatrix.end()) { 431 MatrixTy &M = Found->second; 432 // Return the found matrix, if its shape matches the requested shape 433 // information 434 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 435 return M; 436 437 MatrixVal = M.embedInVector(Builder); 438 } 439 440 // Otherwise split MatrixVal. 441 SmallVector<Value *, 16> SplitVecs; 442 Value *Undef = UndefValue::get(VType); 443 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); 444 MaskStart += SI.getStride()) { 445 Constant *Mask = 446 createSequentialMask(Builder, MaskStart, SI.getStride(), 0); 447 Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split"); 448 SplitVecs.push_back(V); 449 } 450 451 return {SplitVecs}; 452 } 453 454 /// If \p V already has a known shape return false. Otherwise set the shape 455 /// for instructions that support it. 456 bool setShapeInfo(Value *V, ShapeInfo Shape) { 457 assert(Shape && "Shape not set"); 458 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 459 return false; 460 461 auto SIter = ShapeMap.find(V); 462 if (SIter != ShapeMap.end()) { 463 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 464 << SIter->second.NumRows << " " 465 << SIter->second.NumColumns << " for " << *V << "\n"); 466 return false; 467 } 468 469 ShapeMap.insert({V, Shape}); 470 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 471 << " for " << *V << "\n"); 472 return true; 473 } 474 475 bool isUniformShape(Value *V) { 476 Instruction *I = dyn_cast<Instruction>(V); 477 if (!I) 478 return true; 479 480 switch (I->getOpcode()) { 481 case Instruction::FAdd: 482 case Instruction::FSub: 483 case Instruction::FMul: // Scalar multiply. 484 case Instruction::Add: 485 case Instruction::Mul: 486 case Instruction::Sub: 487 return true; 488 default: 489 return false; 490 } 491 } 492 493 /// Returns true if shape information can be used for \p V. The supported 494 /// instructions must match the instructions that can be lowered by this pass. 495 bool supportsShapeInfo(Value *V) { 496 Instruction *Inst = dyn_cast<Instruction>(V); 497 if (!Inst) 498 return false; 499 500 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 501 if (II) 502 switch (II->getIntrinsicID()) { 503 case Intrinsic::matrix_multiply: 504 case Intrinsic::matrix_transpose: 505 case Intrinsic::matrix_columnwise_load: 506 case Intrinsic::matrix_columnwise_store: 507 return true; 508 default: 509 return false; 510 } 511 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 512 } 513 514 /// Propagate the shape information of instructions to their users. 515 /// The work list contains instructions for which we can compute the shape, 516 /// either based on the information provided by matrix intrinsics or known 517 /// shapes of operands. 518 SmallVector<Instruction *, 32> 519 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 520 SmallVector<Instruction *, 32> NewWorkList; 521 // Pop an element for which we guaranteed to have at least one of the 522 // operand shapes. Add the shape for this and then add users to the work 523 // list. 
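    // For example, given (illustrative IR, operands and mangled intrinsic
    // suffix elided)
    //   %c = call <8 x double> @llvm.matrix.multiply.*(..., i32 2, i32 4, i32 4)
    //   %d = fadd <8 x double> %c, %e
    // the multiply seeds shape 2 x 4 for %c; the fadd has a uniform shape, so
    // it inherits 2 x 4 from its operand and its own users are visited next.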
524 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 525 while (!WorkList.empty()) { 526 Instruction *Inst = WorkList.back(); 527 WorkList.pop_back(); 528 529 // New entry, set the value and insert operands 530 bool Propagate = false; 531 532 Value *MatrixA; 533 Value *MatrixB; 534 Value *M; 535 Value *N; 536 Value *K; 537 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 538 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 539 m_Value(N), m_Value(K)))) { 540 Propagate = setShapeInfo(Inst, {M, K}); 541 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 542 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 543 // Flip dimensions. 544 Propagate = setShapeInfo(Inst, {N, M}); 545 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>( 546 m_Value(MatrixA), m_Value(), m_Value(), 547 m_Value(M), m_Value(N)))) { 548 Propagate = setShapeInfo(Inst, {N, M}); 549 } else if (match(Inst, 550 m_Intrinsic<Intrinsic::matrix_columnwise_load>( 551 m_Value(), m_Value(), m_Value(M), m_Value(N)))) { 552 Propagate = setShapeInfo(Inst, {M, N}); 553 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 554 auto OpShape = ShapeMap.find(MatrixA); 555 if (OpShape != ShapeMap.end()) 556 setShapeInfo(Inst, OpShape->second); 557 continue; 558 } else if (isUniformShape(Inst)) { 559 // Find the first operand that has a known shape and use that. 560 for (auto &Op : Inst->operands()) { 561 auto OpShape = ShapeMap.find(Op.get()); 562 if (OpShape != ShapeMap.end()) { 563 Propagate |= setShapeInfo(Inst, OpShape->second); 564 break; 565 } 566 } 567 } 568 569 if (Propagate) { 570 NewWorkList.push_back(Inst); 571 for (auto *User : Inst->users()) 572 if (ShapeMap.count(User) == 0) 573 WorkList.push_back(cast<Instruction>(User)); 574 } 575 } 576 577 return NewWorkList; 578 } 579 580 /// Propagate the shape to operands of instructions with shape information. 581 /// \p Worklist contains the instruction for which we already know the shape. 582 SmallVector<Instruction *, 32> 583 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 584 SmallVector<Instruction *, 32> NewWorkList; 585 586 auto pushInstruction = [](Value *V, 587 SmallVectorImpl<Instruction *> &WorkList) { 588 Instruction *I = dyn_cast<Instruction>(V); 589 if (I) 590 WorkList.push_back(I); 591 }; 592 // Pop an element with known shape. Traverse the operands, if their shape 593 // derives from the result shape and is unknown, add it and add them to the 594 // worklist. 595 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 596 while (!WorkList.empty()) { 597 Value *V = WorkList.back(); 598 WorkList.pop_back(); 599 600 size_t BeforeProcessingV = WorkList.size(); 601 if (!isa<Instruction>(V)) 602 continue; 603 604 Value *MatrixA; 605 Value *MatrixB; 606 Value *M; 607 Value *N; 608 Value *K; 609 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 610 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 611 m_Value(N), m_Value(K)))) { 612 if (setShapeInfo(MatrixA, {M, N})) 613 pushInstruction(MatrixA, WorkList); 614 615 if (setShapeInfo(MatrixB, {N, K})) 616 pushInstruction(MatrixB, WorkList); 617 618 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 619 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 620 // Flip dimensions. 
621 if (setShapeInfo(MatrixA, {M, N})) 622 pushInstruction(MatrixA, WorkList); 623 } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>( 624 m_Value(MatrixA), m_Value(), m_Value(), 625 m_Value(M), m_Value(N)))) { 626 if (setShapeInfo(MatrixA, {M, N})) { 627 pushInstruction(MatrixA, WorkList); 628 } 629 } else if (isa<LoadInst>(V) || 630 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) { 631 // Nothing to do, no matrix input. 632 } else if (isa<StoreInst>(V)) { 633 // Nothing to do. We forward-propagated to this so we would just 634 // backward propagate to an instruction with an already known shape. 635 } else if (isUniformShape(V)) { 636 // Propagate to all operands. 637 ShapeInfo Shape = ShapeMap[V]; 638 for (Use &U : cast<Instruction>(V)->operands()) { 639 if (setShapeInfo(U.get(), Shape)) 640 pushInstruction(U.get(), WorkList); 641 } 642 } 643 // After we discovered new shape info for new instructions in the 644 // worklist, we use their users as seeds for the next round of forward 645 // propagation. 646 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 647 for (User *U : WorkList[I]->users()) 648 if (isa<Instruction>(U) && V != U) 649 NewWorkList.push_back(cast<Instruction>(U)); 650 } 651 return NewWorkList; 652 } 653 654 bool Visit() { 655 if (EnableShapePropagation) { 656 SmallVector<Instruction *, 32> WorkList; 657 658 // Initially only the shape of matrix intrinsics is known. 659 // Initialize the work list with ops carrying shape information. 660 for (BasicBlock &BB : Func) 661 for (Instruction &Inst : BB) { 662 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 663 if (!II) 664 continue; 665 666 switch (II->getIntrinsicID()) { 667 case Intrinsic::matrix_multiply: 668 case Intrinsic::matrix_transpose: 669 case Intrinsic::matrix_columnwise_load: 670 case Intrinsic::matrix_columnwise_store: 671 WorkList.push_back(&Inst); 672 break; 673 default: 674 break; 675 } 676 } 677 // Propagate shapes until nothing changes any longer. 678 while (!WorkList.empty()) { 679 WorkList = propagateShapeForward(WorkList); 680 WorkList = propagateShapeBackward(WorkList); 681 } 682 } 683 684 bool Changed = false; 685 SmallVector<CallInst *, 16> MaybeFusableInsts; 686 SmallVector<Instruction *, 16> MatrixInsts; 687 688 // First, collect all instructions with shape information and candidates for 689 // fusion (currently only matrix multiplies). 690 ReversePostOrderTraversal<Function *> RPOT(&Func); 691 for (auto *BB : RPOT) 692 for (Instruction &I : *BB) { 693 if (ShapeMap.find(&I) == ShapeMap.end()) 694 continue; 695 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 696 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 697 MatrixInsts.push_back(&I); 698 } 699 700 // Second, try to fuse candidates. 701 SmallPtrSet<Instruction *, 16> FusedInsts; 702 for (CallInst *CI : MaybeFusableInsts) 703 LowerMatrixMultiplyFused(CI, FusedInsts); 704 Changed = !FusedInsts.empty(); 705 706 // Third, lower remaining instructions with shape information. 
707 for (Instruction *Inst : MatrixInsts) { 708 if (FusedInsts.find(Inst) != FusedInsts.end()) 709 continue; 710 711 IRBuilder<> Builder(Inst); 712 713 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 714 Changed |= VisitCallInst(CInst); 715 716 Value *Op1; 717 Value *Op2; 718 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 719 Changed |= VisitBinaryOperator(BinOp); 720 if (match(Inst, m_Load(m_Value(Op1)))) 721 Changed |= VisitLoad(Inst, Op1, Builder); 722 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 723 Changed |= VisitStore(Inst, Op1, Op2, Builder); 724 } 725 726 RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); 727 RemarkGen.emitRemarks(); 728 729 for (Instruction *Inst : reverse(ToRemove)) 730 Inst->eraseFromParent(); 731 732 return Changed; 733 } 734 735 LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, 736 IRBuilder<> &Builder) { 737 return Builder.CreateAlignedLoad( 738 ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load"); 739 } 740 741 StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr, 742 Type *EltType, IRBuilder<> &Builder) { 743 return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, 744 DL.getABITypeAlign(EltType)); 745 } 746 747 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 748 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 749 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 750 Type *EltPtrType = PointerType::get(EltType, AS); 751 return Builder.CreatePointerCast(BasePtr, EltPtrType); 752 } 753 754 /// Replace intrinsic calls 755 bool VisitCallInst(CallInst *Inst) { 756 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 757 return false; 758 759 switch (Inst->getCalledFunction()->getIntrinsicID()) { 760 case Intrinsic::matrix_multiply: 761 LowerMultiply(Inst); 762 break; 763 case Intrinsic::matrix_transpose: 764 LowerTranspose(Inst); 765 break; 766 case Intrinsic::matrix_columnwise_load: 767 LowerColumnwiseLoad(Inst); 768 break; 769 case Intrinsic::matrix_columnwise_store: 770 LowerColumnwiseStore(Inst); 771 break; 772 default: 773 return false; 774 } 775 return true; 776 } 777 778 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 779 /// vectors. 780 MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape, 781 IRBuilder<> &Builder) { 782 auto VType = cast<VectorType>(Ty); 783 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 784 MatrixTy Result; 785 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 786 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt32(I), Stride, 787 Shape.getStride(), VType->getElementType(), 788 Builder); 789 Value *Vector = createVectorLoad(GEP, VType->getElementType(), Builder); 790 Result.addVector(Vector); 791 } 792 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 793 Result.getNumVectors()); 794 } 795 796 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 797 /// starting at \p MatrixPtr[I][J]. 
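  /// The element offset of the tile start is computed as
  /// J * MatrixShape.getStride() + I, and the tile vectors are then loaded
  /// using the stride of the outer matrix (its number of rows for
  /// column-major layout).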
798 MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, Value *I, 799 Value *J, ShapeInfo ResultShape, Type *EltTy, 800 IRBuilder<> &Builder) { 801 802 Value *Offset = Builder.CreateAdd( 803 Builder.CreateMul(J, Builder.getInt32(MatrixShape.getStride())), I); 804 805 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 806 Value *EltPtr = 807 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 808 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 809 Type *TileTy = 810 VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns); 811 Type *TilePtrTy = PointerType::get(TileTy, AS); 812 Value *TilePtr = 813 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 814 815 return loadMatrix(TileTy, TilePtr, 816 Builder.getInt32(MatrixShape.getStride()), ResultShape, 817 Builder); 818 } 819 820 /// Lower a load instruction with shape information. 821 void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, 822 ShapeInfo Shape) { 823 IRBuilder<> Builder(Inst); 824 finalizeLowering(Inst, 825 loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder), 826 Builder); 827 } 828 829 /// Lowers llvm.matrix.columnwise.load. 830 /// 831 /// The intrinsic loads a matrix from memory using a stride between columns. 832 void LowerColumnwiseLoad(CallInst *Inst) { 833 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 834 "Intrinsic only supports column-major layout!"); 835 Value *Ptr = Inst->getArgOperand(0); 836 Value *Stride = Inst->getArgOperand(1); 837 LowerLoad(Inst, Ptr, Stride, 838 {Inst->getArgOperand(2), Inst->getArgOperand(3)}); 839 } 840 841 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 842 /// MatrixPtr[I][J]. 843 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 844 ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy, 845 IRBuilder<> &Builder) { 846 Value *Offset = Builder.CreateAdd( 847 Builder.CreateMul(J, Builder.getInt32(MatrixShape.getStride())), I); 848 849 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 850 Value *EltPtr = 851 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 852 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 853 Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() * 854 StoreVal.getNumColumns()); 855 Type *TilePtrTy = PointerType::get(TileTy, AS); 856 Value *TilePtr = 857 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 858 859 storeMatrix(TileTy, StoreVal, TilePtr, 860 Builder.getInt32(MatrixShape.getStride()), Builder); 861 } 862 863 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 864 /// vectors. 865 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride, 866 IRBuilder<> &Builder) { 867 auto VType = cast<VectorType>(Ty); 868 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 869 for (auto Vec : enumerate(StoreVal.vectors())) { 870 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt32(Vec.index()), 871 Stride, StoreVal.getStride(), 872 VType->getElementType(), Builder); 873 createVectorStore(Vec.value(), GEP, VType->getElementType(), Builder); 874 } 875 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 876 StoreVal.getNumVectors()); 877 } 878 879 /// Lower a store instruction with shape information. 
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(
        Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
        Builder);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
966 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 967 IRBuilder<> &Builder) { 968 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 969 970 ToRemove.push_back(Inst); 971 Value *Flattened = nullptr; 972 for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { 973 Use &U = *I++; 974 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 975 if (!Flattened) 976 Flattened = Matrix.embedInVector(Builder); 977 U.set(Flattened); 978 } 979 } 980 } 981 982 /// Compute \p Result += \p A * \p B for input matrices with left-associating 983 /// addition. 984 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 985 const MatrixTy &B, bool AllowContraction, 986 IRBuilder<> &Builder, bool isTiled) { 987 const unsigned VF = std::max<unsigned>( 988 TTI.getRegisterBitWidth(true) / 989 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 990 1U); 991 unsigned R = Result.getNumRows(); 992 unsigned C = Result.getNumColumns(); 993 unsigned M = A.getNumColumns(); 994 995 bool IsFP = Result.getElementType()->isFloatingPointTy(); 996 assert(A.isColumnMajor() == B.isColumnMajor() && 997 Result.isColumnMajor() == A.isColumnMajor() && 998 "operands must agree on matrix layout"); 999 unsigned NumComputeOps = 0; 1000 if (A.isColumnMajor()) { 1001 // Multiply columns from the first operand with scalars from the second 1002 // operand. Then move along the K axes and accumulate the columns. With 1003 // this the adds can be vectorized without reassociation. 1004 for (unsigned J = 0; J < C; ++J) { 1005 unsigned BlockSize = VF; 1006 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1007 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1008 1009 for (unsigned I = 0; I < R; I += BlockSize) { 1010 // Gradually lower the vectorization factor to cover the remainder. 1011 while (I + BlockSize > R) 1012 BlockSize /= 2; 1013 1014 Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 1015 : nullptr; 1016 for (unsigned K = 0; K < M; ++K) { 1017 Value *L = A.extractVector(I, K, BlockSize, Builder); 1018 Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 1019 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1020 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1021 Result.getElementType()->isFloatingPointTy(), 1022 Builder, AllowContraction, NumComputeOps); 1023 } 1024 Result.setVector(J, 1025 insertVector(Result.getVector(J), I, Sum, Builder)); 1026 } 1027 } 1028 } else { 1029 // Multiply rows from the second operand with scalars from the first 1030 // operand. Then move along the K axes and accumulate the rows. With this 1031 // the adds can be vectorized without reassociation. 1032 for (unsigned I = 0; I < R; ++I) { 1033 unsigned BlockSize = VF; 1034 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1035 for (unsigned J = 0; J < C; J += BlockSize) { 1036 // Gradually lower the vectorization factor to cover the remainder. 1037 while (J + BlockSize > C) 1038 BlockSize /= 2; 1039 1040 Value *Sum = nullptr; 1041 for (unsigned K = 0; K < M; ++K) { 1042 Value *R = B.extractVector(K, J, BlockSize, Builder); 1043 Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 1044 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1045 Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, Splat, R,
                               IsFP, Builder, AllowContraction, NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }

  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. The new location, or otherwise the original
  /// one, is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);

    // If we can statically determine noalias we're good.
    if (!LdAliased)
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the 2nd part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT.Delete, Check0, Succ});

    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
1118 Builder.SetInsertPoint(Copy, Copy->begin()); 1119 AllocaInst *NewLd = 1120 Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1121 Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1122 Load->getPointerOperand(), Load->getAlign(), 1123 LoadLoc.Size.getValue()); 1124 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1125 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1126 PHI->addIncoming(Load->getPointerOperand(), Check0); 1127 PHI->addIncoming(Load->getPointerOperand(), Check1); 1128 PHI->addIncoming(NewLd, Copy); 1129 1130 // Adjust DT. 1131 DTUpdates.push_back({DT.Insert, Check0, Check1}); 1132 DTUpdates.push_back({DT.Insert, Check0, Fusion}); 1133 DTUpdates.push_back({DT.Insert, Check1, Copy}); 1134 DTUpdates.push_back({DT.Insert, Check1, Fusion}); 1135 DT.applyUpdates(DTUpdates); 1136 return PHI; 1137 } 1138 1139 bool isFusionProfitable(CallInst *MatMul) { 1140 if (ForceFusion) 1141 return true; 1142 1143 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1144 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1145 1146 const unsigned R = LShape.NumRows; 1147 const unsigned C = RShape.NumColumns; 1148 const unsigned M = LShape.NumColumns; 1149 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1150 1151 const unsigned VF = 1152 std::max<unsigned>(TTI.getRegisterBitWidth(true) / 1153 EltType->getPrimitiveSizeInBits().getFixedSize(), 1154 1U); 1155 1156 // Cost model for tiling 1157 // 1158 // For tiling to be beneficial, we need reuse either along the R or 1159 // the C axis. We vectorize along the R axis so that means at least 1160 // 3 elements. 1161 // TODO: Also consider cost of copying if operands alias. 1162 if (R <= VF && C == 1) 1163 return false; 1164 // Then we need enough elements to exceed the number of vector 1165 // registers we have. Note that this is an oversimplification since 1166 // fusing also takes some extra loads which may exceed the number of 1167 // reloads necessary. 
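    // E.g. with R = C = M = 8, double elements and VF = 4, each operand needs
    // ceil(8 / 4) * 8 = 16 vector registers, i.e. 32 in total, which exceeds
    // a typical 16 register file and makes tiling worthwhile.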
1168 unsigned Op0Regs = (R + VF - 1) / VF * M; 1169 unsigned Op1Regs = (M + VF - 1) / VF * C; 1170 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1171 } 1172 1173 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1174 MatrixTy Res; 1175 Type *ColumType = VectorType::get(EltType, R); 1176 for (unsigned I = 0; I < C; ++I) 1177 Res.addVector(ConstantAggregateZero::get(ColumType)); 1178 return Res; 1179 } 1180 1181 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1182 StoreInst *Store, 1183 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1184 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1185 "Tiling only supported for column-major matrixes at the moment!"); 1186 if (!isFusionProfitable(MatMul)) 1187 return; 1188 1189 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1190 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1191 1192 const unsigned R = LShape.NumRows; 1193 const unsigned C = RShape.NumColumns; 1194 const unsigned M = LShape.NumColumns; 1195 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1196 1197 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1198 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1199 Value *CPtr = Store->getPointerOperand(); 1200 1201 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1202 MatMul->hasAllowContract()); 1203 IRBuilder<> Builder(Store); 1204 for (unsigned J = 0; J < C; J += TileSize) 1205 for (unsigned I = 0; I < R; I += TileSize) { 1206 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1207 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1208 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1209 1210 for (unsigned K = 0; K < M; K += TileSize) { 1211 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1212 MatrixTy A = 1213 loadMatrix(APtr, LShape, Builder.getInt32(I), Builder.getInt32(K), 1214 {TileR, TileM}, EltType, Builder); 1215 MatrixTy B = 1216 loadMatrix(BPtr, RShape, Builder.getInt32(K), Builder.getInt32(J), 1217 {TileM, TileC}, EltType, Builder); 1218 emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); 1219 } 1220 storeMatrix(Res, CPtr, {R, M}, Builder.getInt32(I), Builder.getInt32(J), 1221 EltType, Builder); 1222 } 1223 1224 // Mark eliminated instructions as fused and remove them. 1225 FusedInsts.insert(Store); 1226 FusedInsts.insert(MatMul); 1227 Store->eraseFromParent(); 1228 MatMul->eraseFromParent(); 1229 if (LoadOp0->hasNUses(0)) { 1230 FusedInsts.insert(LoadOp0); 1231 LoadOp0->eraseFromParent(); 1232 } 1233 if (LoadOp1->hasNUses(0)) { 1234 FusedInsts.insert(LoadOp1); 1235 LoadOp1->eraseFromParent(); 1236 } 1237 } 1238 1239 /// Try to lower matrix multiply chains by fusing operations. 1240 /// 1241 /// Currently we only lower {ld, ld} -> matmul -> st chains. 1242 // 1243 /// No need to return a MatrixTy object for the result of the operation, since 1244 /// the single store user will be lowered as part of this. Instructions that 1245 /// are completely eliminated by fusion are added to \p FusedInsts. 
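  ///
  /// For example, a chain of the form
  ///   %a = load <16 x double>, ...
  ///   %b = load <16 x double>, ...
  ///   %c = call <16 x double> @llvm.matrix.multiply.*(%a, %b, i32 4, i32 4, i32 4)
  ///   store <16 x double> %c, ...
  /// is emitted as tiled loads, multiplies and stores (if profitable), and the
  /// original instructions are marked as fused and removed.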
1246 void LowerMatrixMultiplyFused(CallInst *MatMul, 1247 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1248 if (!FuseMatrix || !MatMul->hasOneUse() || 1249 MatrixLayout != MatrixLayoutTy::ColumnMajor) 1250 return; 1251 1252 auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 1253 auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 1254 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1255 if (LoadOp0 && LoadOp1 && Store) { 1256 // The store address must dominate the MatMul instruction, otherwise 1257 // we create invalid IR. 1258 // FIXME: See if we can hoist the store address computation. 1259 auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1260 if (AddrI && (!DT.dominates(AddrI, MatMul))) 1261 return; 1262 1263 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1264 return; 1265 } 1266 } 1267 1268 /// Lowers llvm.matrix.multiply. 1269 void LowerMultiply(CallInst *MatMul) { 1270 IRBuilder<> Builder(MatMul); 1271 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1272 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1273 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1274 1275 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1276 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1277 1278 const unsigned R = LShape.NumRows; 1279 const unsigned C = RShape.NumColumns; 1280 assert(LShape.NumColumns == RShape.NumRows); 1281 1282 // Initialize the output 1283 MatrixTy Result(R, C, EltType); 1284 1285 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1286 MatMul->hasAllowContract()); 1287 emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1288 finalizeLowering(MatMul, Result, Builder); 1289 } 1290 1291 /// Lowers llvm.matrix.transpose. 1292 void LowerTranspose(CallInst *Inst) { 1293 MatrixTy Result; 1294 IRBuilder<> Builder(Inst); 1295 Value *InputVal = Inst->getArgOperand(0); 1296 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1297 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1298 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1299 assert(InputMatrix.isColumnMajor() && 1300 "Row-major code-gen not supported yet!"); 1301 1302 for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { 1303 // Build a single column vector for this row. First initialize it. 1304 Value *ResultColumn = UndefValue::get( 1305 VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns)); 1306 1307 // Go through the elements of this row and insert it into the resulting 1308 // column vector. 1309 for (auto C : enumerate(InputMatrix.columns())) { 1310 Value *Elt = Builder.CreateExtractElement(C.value(), Row); 1311 // We insert at index Column since that is the row index after the 1312 // transpose. 1313 ResultColumn = 1314 Builder.CreateInsertElement(ResultColumn, Elt, C.index()); 1315 } 1316 Result.addVector(ResultColumn); 1317 } 1318 1319 // TODO: Improve estimate of operations needed for transposes. Currently we 1320 // just count the insertelement/extractelement instructions, but do not 1321 // account for later simplifications/combines. 1322 finalizeLowering( 1323 Inst, 1324 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1325 Builder); 1326 } 1327 1328 /// Lower load instructions, if shape information is available. 
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.getStride()), I->second);
    return true;
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.getStride()),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "."
    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          return 2;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we print either the constant, or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }
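
    // A few illustrative cases for the special printing above (value names are
    // hypothetical): an alloca %A prints as "stack addr %A", a pointer
    // argument %B prints as "addr %B", a ConstantInt of 8 prints as "8", other
    // constants print as "constant", and remaining values print as "scalar" or
    // "matrix" depending on whether they are part of a matrix expression.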

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CallSite(CI).arg_begin(),
                   CallSite(CI).arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
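
  // A hedged illustration of the linearized output (names, shapes and strides
  // below are made up): for a leaf that stores the result of a multiply of two
  // columnwise loads, getResult() returns roughly
  //
  //   columnwise.store.4x3.double(
  //    multiply.4x2.2x3.double(
  //     columnwise.load.4x2.double(addr %A, 5),
  //     columnwise.load.2x3.double(addr %B, 3)),
  //    addr %C,
  //    10)
  //
  // with the trailing shape arguments dropped from the printed operand lists
  // (see getNumShapeArgs) and shared or reused sub-expressions annotated
  // inline.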

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expressions
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
      return;
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p Root. Expressions used multiple times are
    /// counted once. Limit the matrix operations to the ones in
    /// \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
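
        // For illustration (the numbers and the location are hypothetical):
        // each remark emitted by the loop below has the form
        //
        //   remark: test.cpp:35:3: Lowered with 6 stores, 12 loads,
        //   24 compute ops
        //
        // optionally followed by the op counts shared with other expressions
        // and the linearized expression tree produced by ExprLinearizer.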
        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}
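
// A hedged usage sketch (not part of this file): assuming the pass is
// registered as "lower-matrix-intrinsics" in the new pass manager, the
// lowering can be exercised from the command line roughly as
//
//   opt -passes=lower-matrix-intrinsics \
//       -pass-remarks=lower-matrix-intrinsics -S input.ll
//
// where -pass-remarks enables the remarks emitted by RemarkGenerator above,
// and the flags defined at the top of this file (e.g. -matrix-default-layout,
// -fuse-matrix-tile-size) tune the lowering.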

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
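
// The snippet below is an illustrative, hedged sketch (kept in a comment so it
// is not compiled as part of this pass) of how the lowering can be driven
// programmatically through the new pass manager; the helper name
// runMatrixLoweringOn is hypothetical.
//
//   #include "llvm/Passes/PassBuilder.h"
//   #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
//
//   void runMatrixLoweringOn(llvm::Function &F) {
//     llvm::PassBuilder PB;
//     llvm::FunctionAnalysisManager FAM;
//     // Registers TargetIRAnalysis, AAManager, DominatorTreeAnalysis, etc.,
//     // which LowerMatrixIntrinsicsPass::run requests above.
//     PB.registerFunctionAnalyses(FAM);
//
//     llvm::FunctionPassManager FPM;
//     FPM.addPass(llvm::LowerMatrixIntrinsicsPass());
//     FPM.run(F, FAM);
//   }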