//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Implement shape propagation
//  * Implement optimizations to reduce or eliminate shufflevector uses by
//    using shape information.
//  * Add remarks summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
                                            cl::init(true));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0      1      2      3
// 0  v_0_0  v_0_1  v_0_2  v_0_3
// 1  v_1_0  v_1_1  v_1_2  v_1_3
// 2  v_2_0  v_2_1  v_2_2  v_2_3
// 3  v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0  |v_1_1 |v_1_2 |v_1_3
//         v_2_0  |v_2_1 |v_2_2 |v_2_3
//         v_3_0  {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride);

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart);

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType);
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
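  /// For example, a 4 x 2 matrix of doubles is held as two column vectors,
  /// each of type <4 x double>.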
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape
  /// of the result value of the instruction, with the only exceptions being
  /// store instructions and the matrix_columnwise_store intrinsics. For
  /// those, the shape information indicates that those instructions should
  /// be lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
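  ///
  /// For example, splitting a flat <4 x double> with shape 2 x 2 yields two
  /// <2 x double> column vectors, extracted with shufflevector masks
  /// <0, 1> and <2, 3>.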
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if \p V is an operation whose result has the same shape as
  /// its operands, or if \p V is not an instruction.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  void propagateShapeForward() {
    // The work list contains instructions for which we can compute the shape,
    // either based on the information provided by matrix intrinsics or known
    // shapes of operands.
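    //
    // For example, given (mangled intrinsic name and operands elided)
    //   %c = call <4 x double> @llvm.matrix.multiply(..., i32 2, i32 2, i32 2)
    //   %d = fadd <4 x double> %c, %c
    // the multiply seeds the work list with shape 2 x 2 for %c; visiting
    // %c's users then assigns the same shape to %d.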
    SmallVector<Instruction *, 8> WorkList;

    // Initialize the work list with ops carrying shape information. Initially
    // only the shape of matrix intrinsics is known.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this instruction and then add its
    // users to the work list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
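        // E.g., for %d = fadd <4 x double> %a, %b where %a has shape 2 x 2,
        // %d gets shape 2 x 2 as well.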
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate)
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
    }
  }

  bool Visit() {
    if (EnableShapePropagation)
      propagateShapeForward();

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align);
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
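  ///
  /// For example (mangled intrinsic name and types elided), loading a 4 x 2
  /// matrix of doubles
  ///   %m = call <8 x double> @llvm.matrix.columnwise.load(double* %p,
  ///                                                       i32 %stride,
  ///                                                       i32 4, i32 2)
  /// becomes two <4 x double> loads, one per column, where column C starts
  /// %p + C * %stride elements into the underlying buffer.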
  void LowerColumnwiseLoad(CallInst *Inst) {
    IRBuilder<> Builder(Inst);
    Value *Ptr = Inst->getArgOperand(0);
    // Stride is the distance between the start of one column and the start
    // of the next.
    Value *Stride = Inst->getArgOperand(1);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ShapeInfo Shape(Inst->getArgOperand(2), Inst->getArgOperand(3));

    ColumnMatrixTy Result;
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory, using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J)
  /// from the matrix \p LM represented as a vector of column vectors.
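  ///
  /// For example, with I = 2, J = 1 and NumElts = 2 on a matrix with
  /// <4 x double> columns, this extracts elements 2 and 3 of column 1 via a
  /// shufflevector with mask <2, 3>.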
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Set elements I..I+NumElts-1 to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide if that's profitable.
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
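  ///
  /// The result is computed column by column: result column J accumulates,
  /// over K, the product of column K of the LHS and the splat of the scalar
  /// element (K, J) of the RHS, in vector blocks of at most VF elements.
  /// E.g., for a 2x2 * 2x2 multiply, result column 0 is
  ///   lhs.col(0) * splat(rhs[0][0]) + lhs.col(1) * splat(rhs[1][0])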
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows && "Inner dimensions must match");

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axes and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the
      // resulting column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    // A regular store is columnwise with a stride equal to the number of
    // rows, i.e. the columns are stored back-to-back.
    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
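  ///
  /// For example, an fadd of two values with shape 2 x 2 becomes two
  /// elementwise fadds of the corresponding <2 x double> column vectors
  /// (element type illustrative).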
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operation to each pair of columns and collect the result
    // columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}