//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the start of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
89 // 90 // v_0_0 v_0_1 {v_0_2 {v_0_3 91 // Base Col 1 Col 2 92 // | | | 93 // v_1_0 |v_1_1 |v_1_2 |v_1_3 94 // v_2_0 |v_2_1 |v_2_2 |v_2_3 95 // v_3_0 {v_3_1 {v_3_2 v_3_3 96 // 97 Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, 98 unsigned NumRows, Type *EltType, 99 IRBuilder<> &Builder) { 100 101 assert((!isa<ConstantInt>(Stride) || 102 cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) && 103 "Stride must be >= the number of rows."); 104 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 105 106 // Compute the start of the column with index Col as Col * Stride. 107 Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start"); 108 109 // Get pointer to the start of the selected column. Skip GEP creation, 110 // if we select column 0. 111 if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) 112 ColumnStart = BasePtr; 113 else 114 ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep"); 115 116 // Cast elementwise column start pointer to a pointer to a column 117 // (EltType x NumRows)*. 118 Type *ColumnType = VectorType::get(EltType, NumRows); 119 Type *ColumnPtrType = PointerType::get(ColumnType, AS); 120 return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast"); 121 } 122 123 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 124 /// 125 /// Currently, the lowering for each matrix intrinsic is done as follows: 126 /// 1. Propagate the shape information from intrinsics to connected 127 /// instructions. 128 /// 2. Lower instructions with shape information. 129 /// 2.1. Get column vectors for each argument. If we already lowered the 130 /// definition of an argument, use the produced column vectors directly. 131 /// If not, split the operand vector containing an embedded matrix into 132 /// a set of column vectors, 133 /// 2.2. 
/// Lower the instruction in terms of columnwise operations, which yields
///      a set of column vectors containing result matrix. Note that we lower
///      all instructions that have shape information. Besides the intrinsics,
///      this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  /// The function being transformed.
  Function &Func;
  /// Data layout of Func's module, used for alignment/size queries.
  const DataLayout &DL;
  /// Used to estimate the number of vector registers an operation needs.
  const TargetTransformInfo &TTI;
  /// Used to emit optimization remarks about the lowered matrix operations.
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    /// Accumulate the counts of \p RHS into this estimate.
    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
170 class ColumnMatrixTy { 171 SmallVector<Value *, 16> Columns; 172 173 OpInfoTy OpInfo; 174 175 public: 176 ColumnMatrixTy() : Columns() {} 177 ColumnMatrixTy(ArrayRef<Value *> Cols) 178 : Columns(Cols.begin(), Cols.end()) {} 179 180 Value *getColumn(unsigned i) const { return Columns[i]; } 181 182 void setColumn(unsigned i, Value *V) { Columns[i] = V; } 183 184 size_t getNumColumns() const { return Columns.size(); } 185 size_t getNumRows() const { 186 assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); 187 return cast<VectorType>(Columns[0]->getType())->getNumElements(); 188 } 189 190 const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; } 191 192 SmallVectorImpl<Value *> &getColumnVectors() { return Columns; } 193 194 void addColumn(Value *V) { Columns.push_back(V); } 195 196 VectorType *getColumnTy() { 197 return cast<VectorType>(Columns[0]->getType()); 198 } 199 200 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 201 return make_range(Columns.begin(), Columns.end()); 202 } 203 204 /// Embed the columns of the matrix into a flat vector by concatenating 205 /// them. 206 Value *embedInVector(IRBuilder<> &Builder) const { 207 return Columns.size() == 1 ? 
Columns[0] 208 : concatenateVectors(Builder, Columns); 209 } 210 211 ColumnMatrixTy &addNumLoads(unsigned N) { 212 OpInfo.NumLoads += N; 213 return *this; 214 } 215 216 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 217 218 ColumnMatrixTy &addNumStores(unsigned N) { 219 OpInfo.NumStores += N; 220 return *this; 221 } 222 223 ColumnMatrixTy &addNumComputeOps(unsigned N) { 224 OpInfo.NumComputeOps += N; 225 return *this; 226 } 227 228 unsigned getNumStores() const { return OpInfo.NumStores; } 229 unsigned getNumLoads() const { return OpInfo.NumLoads; } 230 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 231 232 const OpInfoTy &getOpInfo() const { return OpInfo; } 233 }; 234 235 struct ShapeInfo { 236 unsigned NumRows; 237 unsigned NumColumns; 238 239 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 240 : NumRows(NumRows), NumColumns(NumColumns) {} 241 242 ShapeInfo(Value *NumRows, Value *NumColumns) 243 : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()), 244 NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {} 245 246 bool operator==(const ShapeInfo &other) { 247 return NumRows == other.NumRows && NumColumns == other.NumColumns; 248 } 249 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 250 251 /// Returns true if shape-information is defined, meaning both dimensions 252 /// are != 0. 253 operator bool() const { 254 assert(NumRows == 0 || NumColumns != 0); 255 return NumRows != 0; 256 } 257 }; 258 259 /// Maps instructions to their shape information. The shape information 260 /// describes the shape to be used while lowering. This matches the shape of 261 /// the result value of the instruction, with the only exceptions being store 262 /// instructions and the matrix_columnwise_store intrinsics. For those, the 263 /// shape information indicates that those instructions should be lowered 264 /// using shape information as well. 
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction, if shape information is available and
  /// those need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}

  /// Return the estimated number of vector ops required for the vector type
  /// \p VT, forwarding to the scalar-type overload below.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal. Each column is a contiguous slice of NumRows
    // elements, extracted with a shuffle.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if \p V is an instruction whose result has the same shape
  /// as its matrix operands (element-wise binary operators). Non-instruction
  /// values are conservatively treated as uniform.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V.
  /// The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        // Result of an MxN * NxK multiply is MxK.
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        // NOTE(review): this records {N, M} while backward propagation below
        // uses {M, N} for the stored operand — verify the intended order of
        // the dimensions here.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        // A store takes the shape of the stored value; it produces no value
        // itself, so there is nothing to propagate to users.
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instruction for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        // Operands of an MxN * NxK multiply are MxN and NxK respectively.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                       m_Value(MatrixA), m_Value(), m_Value(),
                       m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// Run the lowering: propagate shape information to a fixed point, then
  /// lower all instructions with known shape and finally erase the replaced
  /// instructions. Returns true if any instruction was lowered.
  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    // Visit in reverse-post-order, so operands are lowered before their users.
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    // NOTE(review): RemarkGenerator is presumably declared later in this file
    // — emits optimization remarks for the lowered expressions.
    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    // Erase in reverse order so users are removed before their operands.
    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  /// Load a column of NumRows elements of type \p EltType from \p ColumnPtr.
  LoadInst *createColumnLoad(Value
*ColumnPtr, Type *EltType, 594 IRBuilder<> &Builder) { 595 return Builder.CreateAlignedLoad( 596 ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load"); 597 } 598 599 StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, 600 Type *EltType, IRBuilder<> &Builder) { 601 return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, 602 DL.getABITypeAlign(EltType)); 603 } 604 605 606 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 607 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 608 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 609 Type *EltPtrType = PointerType::get(EltType, AS); 610 return Builder.CreatePointerCast(BasePtr, EltPtrType); 611 } 612 613 /// Replace intrinsic calls 614 bool VisitCallInst(CallInst *Inst) { 615 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 616 return false; 617 618 switch (Inst->getCalledFunction()->getIntrinsicID()) { 619 case Intrinsic::matrix_multiply: 620 LowerMultiply(Inst); 621 break; 622 case Intrinsic::matrix_transpose: 623 LowerTranspose(Inst); 624 break; 625 case Intrinsic::matrix_columnwise_load: 626 LowerColumnwiseLoad(Inst); 627 break; 628 case Intrinsic::matrix_columnwise_store: 629 LowerColumnwiseStore(Inst); 630 break; 631 default: 632 return false; 633 } 634 return true; 635 } 636 637 void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, 638 ShapeInfo Shape) { 639 IRBuilder<> Builder(Inst); 640 auto VType = cast<VectorType>(Inst->getType()); 641 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 642 ColumnMatrixTy Result; 643 // Distance between start of one column and the start of the next 644 for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { 645 Value *GEP = 646 computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows, 647 VType->getElementType(), Builder); 648 Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); 649 
Result.addColumn(Column); 650 } 651 652 finalizeLowering(Inst, 653 Result.addNumLoads(getNumOps(Result.getColumnTy()) * 654 Result.getNumColumns()), 655 Builder); 656 } 657 658 /// Lowers llvm.matrix.columnwise.load. 659 /// 660 /// The intrinsic loads a matrix from memory using a stride between columns. 661 void LowerColumnwiseLoad(CallInst *Inst) { 662 Value *Ptr = Inst->getArgOperand(0); 663 Value *Stride = Inst->getArgOperand(1); 664 LowerLoad(Inst, Ptr, Stride, 665 {Inst->getArgOperand(2), Inst->getArgOperand(3)}); 666 } 667 668 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, 669 ShapeInfo Shape) { 670 IRBuilder<> Builder(Inst); 671 auto VType = cast<VectorType>(Matrix->getType()); 672 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 673 auto LM = getMatrix(Matrix, Shape, Builder); 674 for (auto C : enumerate(LM.columns())) { 675 Value *GEP = 676 computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride, 677 Shape.NumRows, VType->getElementType(), Builder); 678 createColumnStore(C.value(), GEP, VType->getElementType(), Builder); 679 } 680 Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores( 681 getNumOps(LM.getColumnTy()) * LM.getNumColumns()); 682 683 ToRemove.push_back(Inst); 684 } 685 686 /// Lowers llvm.matrix.columnwise.store. 687 /// 688 /// The intrinsic store a matrix back memory using a stride between columns. 689 void LowerColumnwiseStore(CallInst *Inst) { 690 Value *Matrix = Inst->getArgOperand(0); 691 Value *Ptr = Inst->getArgOperand(1); 692 Value *Stride = Inst->getArgOperand(2); 693 LowerStore(Inst, Matrix, Ptr, Stride, 694 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 695 } 696 697 /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from 698 /// the matrix \p LM represented as a vector of column vectors. 
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> &Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    // Widen Block with undef elements so both shuffle inputs have NumElts
    // elements.
    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    // Indices >= VecNumElts select from the second shuffle operand (Block).
    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  /// Emit Sum + A * B (or just A * B if \p Sum is null), using fmuladd when
  /// \p UseFPOp and \p AllowContraction permit. \p NumComputeOps is
  /// incremented by the estimated number of vector ops emitted: one for the
  /// multiply (or the fused multiply-add), plus one for a separate add.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    // Iterate with a copy of the use iterator, as U.set() mutates the use
    // list.
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        // Flatten lazily, only if there is at least one shape-less user.
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    // LHS is M x N (args 2, 3), RHS is N x K (args 3, 4).
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    // Vectorization factor: how many elements of a column fit into one
    // register (at least 1).
    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                 EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    unsigned NumComputeOps = 0;
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axes and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          // Result[I..I+BlockSize-1][J] += Lhs[I..][K] * Rhs[K][J].
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract, NumComputeOps);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    Result.addNumComputeOps(NumComputeOps);
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert it into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    // A plain load has no explicit stride; columns are packed, so the stride
    // equals the number of rows.
    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available for the
  /// stored value.
  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    // Columns are packed for a plain store, so the stride equals the number
    // of rows.
    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Add each column and store the result back into the opmapping
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
Result.addColumn( 933 BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C))); 934 935 finalizeLowering(Inst, 936 Result.addNumComputeOps(getNumOps(Result.getColumnTy()) * 937 Result.getNumColumns()), 938 Builder); 939 return true; 940 } 941 942 /// Helper to linearize a matrix expression tree into a string. Currently 943 /// matrix expressions are linarized by starting at an expression leaf and 944 /// linearizing bottom up. 945 struct ExprLinearizer { 946 unsigned LengthToBreak = 100; 947 std::string Str; 948 raw_string_ostream Stream; 949 unsigned LineLength = 0; 950 const DataLayout &DL; 951 952 /// Mapping from instructions to column matrixes. It is used to identify 953 /// matrix instructions. 954 const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix; 955 956 /// Mapping from values to the leaves of all expressions that the value is 957 /// part of. 958 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 959 960 /// Set of matrix expressions in the scope of a given DISubprogram. 961 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 962 963 /// Leaf node of the expression to linearize. 964 Value *Leaf; 965 966 /// Used to keep track of sub-expressions that get reused while linearizing 967 /// the expression. Re-used sub-expressions are marked as (reused). 
968 SmallPtrSet<Value *, 8> ReusedExprs; 969 970 ExprLinearizer(const DataLayout &DL, 971 const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix, 972 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 973 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 974 Value *Leaf) 975 : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix), 976 Shared(Shared), ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 977 978 void indent(unsigned N) { 979 LineLength += N; 980 for (unsigned i = 0; i < N; i++) 981 Stream << " "; 982 } 983 984 void lineBreak() { 985 Stream << "\n"; 986 LineLength = 0; 987 } 988 989 void maybeIndent(unsigned Indent) { 990 if (LineLength >= LengthToBreak) 991 lineBreak(); 992 993 if (LineLength == 0) 994 indent(Indent); 995 } 996 997 void write(StringRef S) { 998 LineLength += S.size(); 999 Stream << S; 1000 } 1001 1002 Value *getUnderlyingObjectThroughLoads(Value *V) { 1003 if (Value *Ptr = getPointerOperand(V)) 1004 return getUnderlyingObjectThroughLoads(Ptr); 1005 else if (V->getType()->isPointerTy()) 1006 return GetUnderlyingObject(V, DL); 1007 return V; 1008 } 1009 1010 /// Returns true if \p V is a matrix value in the given subprogram. 1011 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 1012 1013 /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 1014 /// \p SS. 1015 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 1016 auto M = Inst2ColumnMatrix.find(V); 1017 if (M == Inst2ColumnMatrix.end()) 1018 SS << "unknown"; 1019 else { 1020 SS << M->second.getNumRows(); 1021 SS << "x"; 1022 SS << M->second.getNumColumns(); 1023 } 1024 } 1025 1026 /// Write the called function name. Handles calls to llvm.matrix.* 1027 /// specially: we write the name, followed by the dimensions of the input 1028 /// matrixes, followed by the scalar type name. 
1029 void writeFnName(CallInst *CI) { 1030 if (!CI->getCalledFunction()) 1031 write("<no called fn>"); 1032 else { 1033 StringRef Name = CI->getCalledFunction()->getName(); 1034 if (!Name.startswith("llvm.matrix")) { 1035 write(Name); 1036 return; 1037 } 1038 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 1039 write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) 1040 .drop_front(StringRef("llvm.matrix.").size())); 1041 write("."); 1042 std::string Tmp = ""; 1043 raw_string_ostream SS(Tmp); 1044 1045 switch (II->getIntrinsicID()) { 1046 case Intrinsic::matrix_multiply: 1047 prettyPrintMatrixType(II->getOperand(0), SS); 1048 SS << "."; 1049 prettyPrintMatrixType(II->getOperand(1), SS); 1050 SS << "." << *II->getType()->getScalarType(); 1051 break; 1052 case Intrinsic::matrix_transpose: 1053 prettyPrintMatrixType(II->getOperand(0), SS); 1054 SS << "." << *II->getType()->getScalarType(); 1055 break; 1056 case Intrinsic::matrix_columnwise_load: 1057 prettyPrintMatrixType(II, SS); 1058 SS << "." << *II->getType()->getScalarType(); 1059 break; 1060 case Intrinsic::matrix_columnwise_store: 1061 prettyPrintMatrixType(II->getOperand(0), SS); 1062 SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 1063 break; 1064 default: 1065 llvm_unreachable("Unhandled case"); 1066 } 1067 SS.flush(); 1068 write(Tmp); 1069 } 1070 } 1071 1072 unsigned getNumShapeArgs(CallInst *CI) const { 1073 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 1074 switch (II->getIntrinsicID()) { 1075 case Intrinsic::matrix_multiply: 1076 return 3; 1077 case Intrinsic::matrix_transpose: 1078 case Intrinsic::matrix_columnwise_load: 1079 case Intrinsic::matrix_columnwise_store: 1080 return 2; 1081 default: 1082 return 0; 1083 } 1084 } 1085 return 0; 1086 } 1087 1088 /// Special printing for values: for pointers, we print if they refer to an 1089 /// (function) external address or a stack address, for other values we 1090 /// either print the constant or "scalar"/"matrix" for other values. 1091 void write(Value *V) { 1092 V = getUnderlyingObjectThroughLoads(V); 1093 if (V->getType()->isPointerTy()) { 1094 if (isa<AllocaInst>(V)) { 1095 Stream << "stack addr"; 1096 LineLength += StringRef("stack addr").size(); 1097 } else { 1098 Stream << "addr"; 1099 LineLength += StringRef("addr").size(); 1100 } 1101 if (!V->getName().empty()) { 1102 Stream << " %" << V->getName() << ""; 1103 LineLength += V->getName().size() + 2; 1104 } 1105 return; 1106 } 1107 1108 std::string Tmp; 1109 raw_string_ostream TmpStream(Tmp); 1110 1111 if (auto *CI = dyn_cast<ConstantInt>(V)) 1112 TmpStream << CI->getValue(); 1113 else if (isa<Constant>(V)) 1114 TmpStream << "constant"; 1115 else { 1116 if (isMatrix(V)) 1117 TmpStream << "matrix"; 1118 else 1119 TmpStream << "scalar"; 1120 } 1121 TmpStream.flush(); 1122 Tmp = std::string(StringRef(Tmp).trim()); 1123 LineLength += Tmp.size(); 1124 Stream << Tmp; 1125 } 1126 1127 /// Linearize expression \p Expr starting at an indentation of \p Indent. 1128 /// Expressions that are re-used multiple times are prefixed with (reused) 1129 /// at the re-used root instruction. 
1130 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 1131 bool ParentShared) { 1132 auto *I = cast<Instruction>(Expr); 1133 maybeIndent(Indent); 1134 SmallVector<Value *, 8> Ops; 1135 1136 // Is Expr shared with other expression leaves? 1137 bool ExprShared = false; 1138 1139 // Deal with shared subtrees. Mark them as shared, if required. 1140 if (!ParentShared) { 1141 auto SI = Shared.find(Expr); 1142 assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end()); 1143 1144 for (Value *S : SI->second) { 1145 if (S == Leaf) 1146 continue; 1147 DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 1148 write("shared with remark at line " + std::to_string(DL.getLine()) + 1149 " column " + std::to_string(DL.getCol()) + " ("); 1150 } 1151 ExprShared = SI->second.size() > 1; 1152 } 1153 1154 bool Reused = !ReusedExprs.insert(Expr).second; 1155 if (Reused && !ParentReused) 1156 write("(reused) "); 1157 1158 if (auto *CI = dyn_cast<CallInst>(I)) { 1159 writeFnName(CI); 1160 1161 Ops.append(CallSite(CI).arg_begin(), 1162 CallSite(CI).arg_end() - getNumShapeArgs(CI)); 1163 } else if (isa<BitCastInst>(Expr)) { 1164 // Special case bitcasts, which are used to materialize matrixes from 1165 // non-matrix ops. 
1166 write("matrix"); 1167 return; 1168 } else { 1169 Ops.append(I->value_op_begin(), I->value_op_end()); 1170 write(std::string(I->getOpcodeName())); 1171 } 1172 1173 write(std::string("(")); 1174 1175 unsigned NumOpsToBreak = 1; 1176 if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) 1177 NumOpsToBreak = 2; 1178 1179 for (Value *Op : Ops) { 1180 if (Ops.size() > NumOpsToBreak) 1181 lineBreak(); 1182 1183 maybeIndent(Indent + 1); 1184 if (isMatrix(Op)) 1185 linearizeExpr(Op, Indent + 1, Reused, ExprShared); 1186 else 1187 write(Op); 1188 if (Op != Ops.back()) 1189 write(", "); 1190 } 1191 1192 write(")"); 1193 } 1194 1195 const std::string &getResult() { 1196 Stream.flush(); 1197 return Str; 1198 } 1199 }; 1200 1201 /// Generate remarks for matrix operations in a function. To generate remarks 1202 /// for matrix expressions, the following approach is used: 1203 /// 1. Use the inlined-at debug information to group matrix operations to the 1204 /// DISubprograms they are contained in. 1205 /// 2. Collect leaves of matrix expressions (done in 1206 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 1207 // mapping. Leaves are lowered matrix instructions without other matrix 1208 // users (like stores) in the current subprogram. 1209 /// 3. For each leaf, create a remark containing a linearizied version of the 1210 /// matrix expression. The expression is linearized by a recursive 1211 /// bottom-up traversal of the matrix operands, starting at a leaf. Note 1212 /// that multiple leaves can share sub-expressions. Shared subexpressions 1213 /// are explicitly marked as shared(). 
  struct RemarkGenerator {
    /// Mapping from lowered matrix instructions to their column-matrix
    /// representation, including per-instruction op counts.
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;

    /// Emitter used to report the generated remarks.
    OptimizationRemarkEmitter &ORE;

    /// Function the remarks are generated for.
    Function &Func;

    /// Data layout of the enclosing module.
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2ColumnMatrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
      return;
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    /// Returns {exclusive counts, shared counts}.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      // NOTE(review): assumes collectSharedInfo has populated Shared for every
      // expression in ExprsInSubprogram, so find() cannot fail here — confirm.
      auto I = Shared.find(Root);
      auto CM = Inst2ColumnMatrix.find(Root);
      // An expression reachable from exactly one leaf is counted as
      // exclusive; otherwise its ops are attributed to the shared bucket.
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    /// Emit one optimization remark per expression leaf, containing the op
    /// counts and a linearized form of the expression.
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2ColumnMatrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          // Register the instruction with every subprogram on its inlined-at
          // chain (inner scope shadows the outer I deliberately).
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          // Prefer the debug location belonging to the current subprogram on
          // the inlined-at chain; fall back to the leaf's own location.
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    /// Linearize the expression rooted at leaf \p L into a string using
    /// ExprLinearizer.
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, ORE);
  if (LMT.Visit()) {
    // Only the CFG-level analyses survive; instruction-level ones are stale.
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  // Nothing changed, all analyses remain valid.
  return PreservedAnalyses::all();
}

namespace {

/// Legacy pass manager wrapper around LowerMatrixIntrinsics.
class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    LowerMatrixIntrinsics LMT(F, TTI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}