//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add remarks summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows),
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//        0      1      2      3
//  0   v_0_0  v_0_1  v_0_2  v_0_3
//  1   v_1_0  v_1_1  v_1_2  v_1_3
//  2   v_2_0  v_2_1  v_2_2  v_2_3
//  3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
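  /// For example, a 2 x 3 matrix of double that is embedded in a flat
  /// <6 x double> value is held here as three <2 x double> column vectors.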
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
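  ///
  /// For example, splitting a flat <6 x double> with shape 2 x 3 yields three
  /// <2 x double> column vectors covering elements 0-1, 2-3 and 4-5.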
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
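  ///
  /// For example, for a matrix_multiply call with dimension arguments
  /// M = 2, N = 4, K = 3, the result shape 2 x 3 is recorded for the call and
  /// its users are queued for further propagation.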
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with a known shape. For each operand whose shape can be
    // derived from the result shape and is still unknown, set the shape and
    // add the operand to the worklist.
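    // For example, for a multiply with dimension arguments (M, N, K), the
    // first operand is assigned shape M x N and the second operand shape
    // N x K, if their shapes are not known already.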
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
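      // Each round first pushes shapes to users (forward), then to operands
      // (backward); the instructions returned by one phase seed the other
      // until a fixpoint is reached.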
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ColumnMatrixTy Result;
    // Stride is the distance between the start of one column and the start of
    // the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
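  ///
  /// As a rough sketch of the lowering (not the verbatim IR), loading a 2 x 3
  /// matrix with stride S becomes three column loads of <2 x elt> vectors
  /// from element offsets 0, S and 2 * S relative to the base pointer.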
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Insert Block into Col, overwriting elements I..I+BlockNumElts-1.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axis and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
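  ///
  /// Each result column is built by extracting one element per input column:
  /// result column R holds element R of every input column, i.e. input rows
  /// become result columns.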
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
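  ///
  /// The operation is applied column by column; e.g. an fadd of two 2 x 3
  /// matrices is lowered to three fadds on the <2 x elt> column vectors.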
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Build the result matrix by applying the operation to matching columns
    // of the lowered operands.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}