//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add remarks summarizing the available matrix optimization opportunities
//    (WIP).
//
//===----------------------------------------------------------------------===//
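
// As a concrete illustration of what the pass does (the mangling suffixes and
// value names below are only a sketch; the exact suffixes depend on the
// operand types), a call such as
//
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
//            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
//
// is lowered into plain shufflevector/extractelement/fmul/fadd (or fmuladd)
// instructions that operate on the individual columns of the 2 x 2 operands.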

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//        0      1      2      3
//  0   v_0_0  v_0_1  v_0_2  v_0_3
//  1   v_1_0  v_1_1  v_1_2  v_1_3
//  2   v_2_0  v_2_1  v_2_2  v_2_3
//  3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}
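
// As an example of the IR this produces: for a double matrix with 4 rows and
// a runtime column and stride, a call like
// computeColumnAddr(Base, Col, Stride, /*NumRows=*/4, DoubleTy, Builder) is
// expected to emit roughly the following (value names are illustrative):
//
//   %col.start = mul i32 %col, %stride
//   %col.gep   = getelementptr double, double* %base, i32 %col.start
//   %col.cast  = bitcast double* %col.gep to <4 x double>*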

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

    OpInfoTy OpInfo;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    VectorType *getColumnTy() {
      return cast<VectorType>(Columns[0]->getType());
    }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }

    ColumnMatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    ColumnMatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    ColumnMatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }
  };
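
  // For intuition: matrices are embedded into flat vectors in column-major
  // order, so (for example) a 2 x 3 matrix of doubles corresponds to a
  // <6 x double> value whose columns are the element ranges [0,1], [2,3] and
  // [4,5]; a ColumnMatrixTy for it holds three <2 x double> column vectors.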

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information merely indicates that those instructions should also be
  /// lowered using shape information.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// uses of a lowered instruction if shape information is available for the
  /// user; such now-obsolete instructions need to be removed after we have
  /// finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT.
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p N elements of scalar type \p ST, i.e. the number of vector registers
  /// needed to hold them.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
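
  // For example, splitting a flat <6 x double> with shape 3 x 2 (and no
  // cached lowering) is expected to produce two column vectors roughly like
  // (value names are illustrative):
  //
  //   %split  = shufflevector <6 x double> %m, <6 x double> undef,
  //                           <3 x i32> <i32 0, i32 1, i32 2>
  //   %split1 = shufflevector <6 x double> %m, <6 x double> undef,
  //                           <3 x i32> <i32 3, i32 4, i32 5>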

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this instruction and then add its
    // users to the work list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
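
  // As a small example of the propagation (types and mangling are sketched,
  // not exact): given
  //
  //   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %a,
  //                                                       i32 2, i32 2)
  //   %s = fadd <4 x double> %t, %b
  //
  // forward propagation records a 2 x 2 shape for %t and then for %s, and
  // backward propagation (below) assigns the same shape to %b, provided %b is
  // itself an instruction this pass can lower.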

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands; if their shape
    // derives from the result shape and is unknown, set it and add the
    // operands to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // The intrinsic arguments give the shape of the operand directly.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, DL);
    RemarkGen.emitRemarks();

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> &Builder) {
    return Builder.CreateAlignedLoad(
        ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> &Builder) {
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
                                      DL.getABITypeAlign(EltType));
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride elements
  /// between the start of one column and the start of the next.
  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ColumnMatrixTy Result;
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst,
                     Result.addNumLoads(getNumOps(Result.getColumnTy()) *
                                        Result.getNumColumns()),
                     Builder);
  }
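
  // For instance, a columnwise load of a 3 x 2 double matrix with a stride of
  // 5 elements (the signature and mangling below are only a sketch; the exact
  // suffixes depend on the types):
  //
  //   %m = call <6 x double> @llvm.matrix.columnwise.load.v6f64.p0f64(
  //            double* %p, i32 5, i32 3, i32 2)
  //
  // is lowered by LowerColumnwiseLoad/LowerLoad to two <3 x double> loads
  // whose start addresses are 5 elements apart.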

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

  /// Store matrix \p Matrix with \p Shape to \p Ptr, using \p Stride elements
  /// between the start of one column and the start of the next.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }
    Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
        getNumOps(LM.getColumnTy()) * LM.getNumColumns());

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> &Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  /// Set elements I..I+BlockNumElts-1 of \p Col to \p Block and return the
  /// updated column vector.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  /// Return \p A * \p B if \p Sum is nullptr, and \p Sum + \p A * \p B
  /// otherwise, using an fmuladd intrinsic for floating point operands if
  /// contraction is allowed. Adds the estimated number of emitted vector ops
  /// to \p NumComputeOps.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    unsigned NumComputeOps = 0;
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axis and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract, NumComputeOps);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    Result.addNumComputeOps(NumComputeOps);
    finalizeLowering(MatMul, Result, Builder);
  }
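
  // To sketch the generated code: for a 2x2 * 2x2 double multiply with a
  // vector width of at least two doubles and contraction enabled, each result
  // column J is built roughly as follows (value names are illustrative;
  // splats are emitted as insertelement/shufflevector pairs):
  //
  //   %block  = shufflevector ...                     ; column 0 of the LHS
  //   %acc    = fmul <2 x double> %block, %splat.b0j  ; * splat of B[0][J]
  //   %block1 = shufflevector ...                     ; column 1 of the LHS
  //   %col.j  = call <2 x double> @llvm.fmuladd.v2f64(
  //                 <2 x double> %block1, <2 x double> %splat.b1j,
  //                 <2 x double> %acc)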

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index C.index(), i.e. the input column index, which
        // becomes the row index after the transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operation to each pair of columns and collect the result
    // columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
                                             Result.getNumColumns()),
                     Builder);
    return true;
  }
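
  // E.g. an fadd of two values with a known 2 x 2 shape, each embedded in a
  // <4 x double> vector, is expected to become two column-wise fadds along
  // the lines of (illustrative names):
  //
  //   %sum.col0 = fadd <2 x double> %lhs.col0, %rhs.col0
  //   %sum.col1 = fadd <2 x double> %lhs.col1, %rhs.col1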

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to column matrices. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix),
          Shared(Shared), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value.
    bool isMatrix(Value *V) const {
      return Inst2ColumnMatrix.find(V) != Inst2ColumnMatrix.end();
    }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2ColumnMatrix.find(V);
      if (M == Inst2ColumnMatrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrices, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          return 2;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName();
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrices from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves). Leaves are lowered matrix
  ///    instructions without other matrix users (like stores).
  ///
  /// 2. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression.
  ///
  /// TODO:
  ///  * Summarize number of vector instructions generated for each expression.
  ///  * Propagate matrix remarks up the inlining chain.
  struct RemarkGenerator {
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
    OptimizationRemarkEmitter &ORE;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                    OptimizationRemarkEmitter &ORE, const DataLayout &DL)
        : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), DL(DL) {}

    /// Return all leaves of matrix expressions. Those are instructions in
    /// Inst2ColumnMatrix returning void. Currently that should only include
    /// stores.
    SmallVector<Value *, 4> getExpressionLeaves() {
      SmallVector<Value *, 4> Leaves;
      for (auto &KV : Inst2ColumnMatrix)
        if (KV.first->getType()->isVoidTy())
          Leaves.push_back(KV.first);

      return Leaves;
    }

    /// Recursively traverse the matrix expression \p V, which is reachable
    /// from \p Leaf, and add \p Leaf to \p Shared for all visited expressions.
    void collectSharedInfo(Value *Leaf, Value *V,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (Inst2ColumnMatrix.find(V) == Inst2ColumnMatrix.end())
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, Shared);
      return;
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      auto CM = Inst2ColumnMatrix.find(Root);
      if (CM == Inst2ColumnMatrix.end())
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }
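
    // The first line of each emitted remark summarizes the counts, e.g. a
    // remark for a small multiply-and-store expression might read (the exact
    // counts depend on the expression and the target's vector width):
    //
    //   remark: ...: Lowered with 2 stores, 4 loads, 6 compute ops
    //
    // followed by a linearized form of the expression produced by
    // ExprLinearizer above.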

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Find leaves of matrix expressions.
      auto Leaves = getExpressionLeaves();

      DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;

      for (Value *Leaf : Leaves)
        collectSharedInfo(Leaf, Leaf, Shared);

      // Generate remarks for each leaf.
      for (auto *L : Leaves) {
        SmallPtrSet<Value *, 8> ReusedExprs;
        OpInfoTy Counts, SharedCounts;
        std::tie(Counts, SharedCounts) = sumOpInfos(L, ReusedExprs, Shared);

        OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered",
                               cast<Instruction>(L)->getDebugLoc(),
                               cast<Instruction>(L)->getParent());

        Rem << "Lowered with ";
        Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
            << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
            << ore::NV("NumComputeOps", Counts.NumComputeOps) << " compute ops";

        if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
            SharedCounts.NumComputeOps > 0) {
          Rem << ",\nadditionally "
              << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
              << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
              << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
              << " compute ops"
              << " are shared with other expressions";
        }

        Rem << ("\n" + linearize(L, Shared, DL));
        ORE.emit(Rem);
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    LowerMatrixIntrinsics LMT(F, TTI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}