//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows),
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//        0      1      2      3
// 0    v_0_0  v_0_1  v_0_2  v_0_3
// 1    v_1_0  v_1_1  v_1_2  v_1_3
// 2    v_2_0  v_2_1  v_2_2  v_2_3
// 3    v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//      v_0_0  v_0_1 {v_0_2 {v_0_3
//             Base   Col 1  Col 2
//               |      |      |
//      v_1_0 |v_1_1 |v_1_2 |v_1_3
//      v_2_0 |v_2_1 |v_2_2 |v_2_3
//      v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update the uses of the lowered instruction. If we have shape
///       information for a user, there is nothing to do, as we will look up
///       the result column matrix when lowering the user. For other uses, we
///       embed the result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : Vectors() {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()) {}

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }

    void setColumn(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() { return getVectorTy()->getElementType(); }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<VectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<VectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Vectors; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Vectors; }

    void addColumn(Value *V) { Vectors.push_back(V); }

    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering.
  /// This matches the shape of the result value of the instruction, with the
  /// only exceptions being store instructions and the matrix_columnwise_store
  /// intrinsics. For those, the shape information indicates that they should
  /// be lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we already lowered \p MatrixVal, just return the cached result column
  /// matrix. Otherwise split the flat vector \p MatrixVal containing a matrix
  /// with shape \p SI into column vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape, return false. Otherwise set the shape
  /// for instructions that support it.
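  /// Returns true only if new shape information was recorded for \p V; the
  /// propagation code uses this result to decide whether \p V's users and
  /// operands need to be (re)visited.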
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with a known shape. Traverse the operands; if their
    // shape derives from the result shape and is unknown, set it and add
    // them to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
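        // Operations with uniform shape (element-wise adds/subs/muls, see
        // isUniformShape) use the same shape for all operands and the result,
        // so the known result shape can simply be copied to each operand.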
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> &Builder) {
    return Builder.CreateAlignedLoad(
        ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> &Builder) {
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
                                      DL.getABITypeAlign(EltType));
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
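  /// For example, for a base pointer of type <16 x double>* this emits a cast
  /// to double* in the same address space, roughly (an illustrative sketch;
  /// the value names are made up):
  ///
  ///   %elt.ptr = bitcast <16 x double>* %base to double*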
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between columns.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape,
                      IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    MatrixTy Result;
    // Stride is the distance (in elements) between the start of one column
    // and the start of the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }
    return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
                              Result.getNumColumns());
  }

  /// Loads a sub-matrix with shape \p ResultShape from the \p MatrixShape
  /// matrix, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
                      unsigned J, ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(Builder.getInt32(J),
                          Builder.getInt32(MatrixShape.NumRows)),
        Builder.getInt32(I));

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    Type *TileTy =
        VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Builder.getInt32(ResultShape.NumRows),
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

  /// Stores the sub-matrix \p StoreVal into the \p MatrixShape matrix,
  /// starting at \p MatrixPtr[I][J].
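  /// The tile start is addressed as MatrixPtr + J * MatrixShape.NumRows + I
  /// (in elements), matching the column-major layout used throughout this
  /// pass; the tile itself is then stored column by column with a stride of
  /// StoreVal.getNumRows() elements.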
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
                   IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(Builder.getInt32(J),
                          Builder.getInt32(MatrixShape.NumRows)),
        Builder.getInt32(I));

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
                                              StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr,
                Builder.getInt32(StoreVal.getNumRows()), Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// columns.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto C : enumerate(StoreVal.columns())) {
      Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
                                     Stride, StoreVal.getNumRows(),
                                     VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
                                   StoreVal.getNumColumns());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(
        Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
        Builder);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
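  /// For example, extracting a block of 2 elements starting at row I = 1 from
  /// column J of a matrix with 4 rows emits a shufflevector that selects
  /// elements 1 and 2 of that column (an illustrative sketch; the value names
  /// are made up):
  ///
  ///   %block = shufflevector <4 x double> %col, <4 x double> undef,
  ///                          <2 x i32> <i32 1, i32 2>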
  Value *extractVector(const MatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> &Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Insert the elements of \p Block into \p Col, starting at element index I.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Compute \p Result += \p A * \p B for tile-sized matrices with
  /// left-associating addition.
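  /// Each result column J is computed in row blocks of at most VF elements.
  /// For every block, the element B[K][J] is splatted and multiplied with the
  /// matching block of column K of A, accumulating into the corresponding
  /// block of the result column (a rough sketch of the per-block recurrence):
  ///
  ///   Sum += A[I..I+BlockSize-1][K] * splat(B[K][J]),  for K = 0..M-1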
  void emitChainedMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                                 const MatrixTy &B, bool AllowContraction,
                                 IRBuilder<> &Builder, bool isTiled) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(true) /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;

      // If Result is zero, we don't need to accumulate in the K==0 iteration.
      bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

      unsigned NumOps = 0;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum =
            isTiled ? extractVector(Result, I, J, BlockSize, Builder) : nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(A, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             Result.getElementType()->isFloatingPointTy(),
                             Builder, AllowContraction, NumOps);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }

      Result.addNumComputeOps(NumOps);
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    emitChainedMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the
      // resulting column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
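        // I.e. element (Row, C.index()) of the input becomes element
        // (C.index(), Row) of the result.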
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    // Apply the operation column-wise and collect the result columns.
    MatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    MatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    MatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
                                             Result.getNumColumns()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to column matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix),
          Shared(Shared), ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2ColumnMatrix.find(V);
      if (M == Inst2ColumnMatrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "."
             << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          return 2;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function) external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CallSite(CI).arg_begin(),
                   CallSite(CI).arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2ColumnMatrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2ColumnMatrix returning void or without any users
    /// in \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p
    /// Leaf to all visited expressions in \p Shared. Limit the matrix
    /// operations to the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
      return;
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p V. Expressions used multiple times are
    /// counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2ColumnMatrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2ColumnMatrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
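        // Each remark starts with the operation counts accumulated for the
        // expression rooted at the leaf, e.g. (illustrative):
        //
        //   Lowered with 6 stores, 6 loads, 24 compute ops
        //
        // followed by the linearized form of the expression.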
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    LowerMatrixIntrinsics LMT(F, TTI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}