1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Lower matrix intrinsics to vector operations. 10 // 11 // TODO: 12 // * Implement multiply & add fusion 13 // * Implement shape propagation 14 // * Implement optimizations to reduce or eliminateshufflevector uses by using 15 // shape information. 16 // * Add remark, summarizing the available matrix optimization opportunities. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21 #include "llvm/ADT/GraphTraits.h" 22 #include "llvm/ADT/PostOrderIterator.h" 23 #include "llvm/ADT/SmallVector.h" 24 #include "llvm/Analysis/TargetTransformInfo.h" 25 #include "llvm/Analysis/VectorUtils.h" 26 #include "llvm/IR/CFG.h" 27 #include "llvm/IR/DataLayout.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/Instructions.h" 31 #include "llvm/IR/IntrinsicInst.h" 32 #include "llvm/InitializePasses.h" 33 #include "llvm/Pass.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Transforms/Scalar.h" 36 37 using namespace llvm; 38 39 #define DEBUG_TYPE "lower-matrix-intrinsics" 40 41 namespace { 42 43 // Given an element poitner \p BasePtr to the start of a (sub) matrix, compute 44 // the start address of column \p Col with type (\p EltType x \p NumRows) 45 // assuming \p Stride elements between start two consecutive columns. 46 // \p Stride must be >= \p NumRows. 47 // 48 // Consider a 4x4 matrix like below 49 // 50 // 0 1 2 3 51 // 0 v_0_0 v_0_1 v_0_2 v_0_3 52 // 1 v_1_0 v_1_1 v_1_2 v_1_3 53 // 2 v_2_0 v_2_1 v_2_2 v_2_3 54 // 3 v_3_0 v_3_1 v_3_2 v_3_3 55 56 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, 57 // we need a pointer to the first element of the submatrix as base pointer. 58 // Then we can use computeColumnAddr to compute the addresses for the columns 59 // of the sub-matrix. 60 // 61 // Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) 62 // -> just returns Base 63 // Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) 64 // -> returns Base + (1 * 4) 65 // Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 66 // -> returns Base + (2 * 4) 67 // 68 // The graphic below illustrates the number of elements in a column (marked 69 // with |) and the number of skipped elements (marked with }). 70 // 71 // v_0_0 v_0_1 {v_0_2 {v_0_3 72 // Base Col 1 Col 2 73 // | | | 74 // v_1_0 |v_1_1 |v_1_2 |v_1_3 75 // v_2_0 |v_2_1 |v_2_2 |v_2_3 76 // v_3_0 {v_3_1 {v_3_2 v_3_3 77 // 78 Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, 79 unsigned NumRows, Type *EltType, 80 IRBuilder<> &Builder) { 81 82 assert((!isa<ConstantInt>(Stride) || 83 cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) && 84 "Stride must be >= the number of rows."); 85 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 86 87 // Compute the start of the column with index Col as Col * Stride. 88 Value *ColumnStart = Builder.CreateMul(Col, Stride); 89 90 // Get pointer to the start of the selected column. Skip GEP creation, 91 // if we select column 0. 92 if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) 93 ColumnStart = BasePtr; 94 else 95 ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart); 96 97 // Cast elementwise column start pointer to a pointer to a column 98 // (EltType x NumRows)*. 99 Type *ColumnType = VectorType::get(EltType, NumRows); 100 Type *ColumnPtrType = PointerType::get(ColumnType, AS); 101 return Builder.CreatePointerCast(ColumnStart, ColumnPtrType); 102 } 103 104 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 105 /// 106 /// Currently, the lowering for each matrix intrinsic is done as follows: 107 /// 1. Split the operand vectors containing an embedded matrix into a set of 108 /// column vectors, based on the shape information from the intrinsic. 109 /// 2. Apply the transformation described by the intrinsic on the column 110 /// vectors, which yields a set of column vectors containing result matrix. 111 /// 3. Embed the columns of the result matrix in a flat vector and replace all 112 /// uses of the intrinsic result with it. 113 class LowerMatrixIntrinsics { 114 Function &Func; 115 const DataLayout &DL; 116 const TargetTransformInfo &TTI; 117 118 /// Wrapper class representing a matrix as a set of column vectors. 119 /// All column vectors must have the same vector type. 120 class ColumnMatrixTy { 121 SmallVector<Value *, 16> Columns; 122 123 public: 124 ColumnMatrixTy() : Columns() {} 125 ColumnMatrixTy(ArrayRef<Value *> Cols) 126 : Columns(Cols.begin(), Cols.end()) {} 127 128 Value *getColumn(unsigned i) const { return Columns[i]; } 129 130 void setColumn(unsigned i, Value *V) { Columns[i] = V; } 131 132 size_t getNumColumns() const { return Columns.size(); } 133 134 const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; } 135 136 SmallVectorImpl<Value *> &getColumnVectors() { return Columns; } 137 138 void addColumn(Value *V) { Columns.push_back(V); } 139 140 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 141 return make_range(Columns.begin(), Columns.end()); 142 } 143 144 /// Embed the columns of the matrix into a flat vector by concatenating 145 /// them. 146 Value *embedInVector(IRBuilder<> &Builder) const { 147 return Columns.size() == 1 ? Columns[0] 148 : concatenateVectors(Builder, Columns); 149 } 150 }; 151 152 struct ShapeInfo { 153 unsigned NumRows; 154 unsigned NumColumns; 155 156 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 157 : NumRows(NumRows), NumColumns(NumColumns) {} 158 159 ShapeInfo(ConstantInt *NumRows, ConstantInt *NumColumns) 160 : NumRows(NumRows->getZExtValue()), 161 NumColumns(NumColumns->getZExtValue()) {} 162 }; 163 164 public: 165 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI) 166 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {} 167 168 /// Return the set of column vectors that a matrix value is lowered to. 169 /// 170 /// We split the flat vector \p MatrixVal containing a matrix with shape \p SI 171 /// into column vectors. 172 ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 173 IRBuilder<> Builder) { 174 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 175 assert(VType && "MatrixVal must be a vector type"); 176 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && 177 "The vector size must match the number of matrix elements"); 178 SmallVector<Value *, 16> SplitVecs; 179 Value *Undef = UndefValue::get(VType); 180 181 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); 182 MaskStart += SI.NumRows) { 183 Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0); 184 Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split"); 185 SplitVecs.push_back(V); 186 } 187 188 return {SplitVecs}; 189 } 190 191 // Replace intrinsic calls 192 bool VisitCallInst(CallInst *Inst) { 193 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 194 return false; 195 196 switch (Inst->getCalledFunction()->getIntrinsicID()) { 197 case Intrinsic::matrix_multiply: 198 LowerMultiply(Inst); 199 break; 200 case Intrinsic::matrix_transpose: 201 LowerTranspose(Inst); 202 break; 203 case Intrinsic::matrix_columnwise_load: 204 LowerColumnwiseLoad(Inst); 205 break; 206 case Intrinsic::matrix_columnwise_store: 207 LowerColumnwiseStore(Inst); 208 break; 209 default: 210 return false; 211 } 212 Inst->eraseFromParent(); 213 return true; 214 } 215 216 bool Visit() { 217 ReversePostOrderTraversal<Function *> RPOT(&Func); 218 bool Changed = false; 219 for (auto *BB : RPOT) { 220 for (Instruction &Inst : make_early_inc_range(*BB)) { 221 if (CallInst *CInst = dyn_cast<CallInst>(&Inst)) 222 Changed |= VisitCallInst(CInst); 223 } 224 } 225 226 return Changed; 227 } 228 229 LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, 230 IRBuilder<> Builder) { 231 unsigned Align = DL.getABITypeAlignment(EltType); 232 return Builder.CreateAlignedLoad(ColumnPtr, Align); 233 } 234 235 StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, 236 Type *EltType, IRBuilder<> Builder) { 237 unsigned Align = DL.getABITypeAlignment(EltType); 238 return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align); 239 } 240 241 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 242 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 243 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 244 Type *EltPtrType = PointerType::get(EltType, AS); 245 return Builder.CreatePointerCast(BasePtr, EltPtrType); 246 } 247 248 /// Lowers llvm.matrix.columnwise.load. 249 /// 250 /// The intrinsic loads a matrix from memory using a stride between columns. 251 void LowerColumnwiseLoad(CallInst *Inst) { 252 IRBuilder<> Builder(Inst); 253 Value *Ptr = Inst->getArgOperand(0); 254 Value *Stride = Inst->getArgOperand(1); 255 auto VType = cast<VectorType>(Inst->getType()); 256 ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(2)), 257 cast<ConstantInt>(Inst->getArgOperand(3))); 258 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 259 260 ColumnMatrixTy Result; 261 // Distance between start of one column and the start of the next 262 for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { 263 Value *GEP = 264 computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows, 265 VType->getElementType(), Builder); 266 Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); 267 Result.addColumn(Column); 268 } 269 270 Inst->replaceAllUsesWith(Result.embedInVector(Builder)); 271 } 272 273 /// Lowers llvm.matrix.columnwise.store. 274 /// 275 /// The intrinsic store a matrix back memory using a stride between columns. 276 void LowerColumnwiseStore(CallInst *Inst) { 277 IRBuilder<> Builder(Inst); 278 Value *Matrix = Inst->getArgOperand(0); 279 Value *Ptr = Inst->getArgOperand(1); 280 Value *Stride = Inst->getArgOperand(2); 281 ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(3)), 282 cast<ConstantInt>(Inst->getArgOperand(4))); 283 auto VType = cast<VectorType>(Matrix->getType()); 284 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 285 286 auto LM = getMatrix(Matrix, Shape, Builder); 287 for (auto C : enumerate(LM.columns())) { 288 Value *GEP = 289 computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride, 290 Shape.NumRows, VType->getElementType(), Builder); 291 createColumnStore(C.value(), GEP, VType->getElementType(), Builder); 292 } 293 } 294 295 /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from 296 /// the matrix \p LM represented as a vector of column vectors. 297 Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, 298 unsigned NumElts, IRBuilder<> Builder) { 299 Value *Col = LM.getColumn(J); 300 Value *Undef = UndefValue::get(Col->getType()); 301 Constant *Mask = createSequentialMask(Builder, I, NumElts, 0); 302 return Builder.CreateShuffleVector(Col, Undef, Mask, "block"); 303 } 304 305 // Set elements I..I+NumElts-1 to Block 306 Value *insertVector(Value *Col, unsigned I, Value *Block, 307 IRBuilder<> Builder) { 308 309 // First, bring Block to the same size as Col 310 unsigned BlockNumElts = 311 cast<VectorType>(Block->getType())->getNumElements(); 312 unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements(); 313 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 314 315 Value *ExtendMask = 316 createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts); 317 Value *Undef = UndefValue::get(Block->getType()); 318 Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask); 319 320 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 321 // 8, 4, 5, 6 322 SmallVector<Constant *, 16> Mask; 323 unsigned i; 324 for (i = 0; i < I; i++) 325 Mask.push_back(Builder.getInt32(i)); 326 327 unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements(); 328 for (; i < I + BlockNumElts; i++) 329 Mask.push_back(Builder.getInt32(i - I + VecNumElts)); 330 331 for (; i < VecNumElts; i++) 332 Mask.push_back(Builder.getInt32(i)); 333 334 Value *MaskVal = ConstantVector::get(Mask); 335 336 return Builder.CreateShuffleVector(Col, Block, MaskVal); 337 } 338 339 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 340 IRBuilder<> &Builder) { 341 Value *Mul = UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 342 if (!Sum) 343 return Mul; 344 345 return UseFPOp ? Builder.CreateFAdd(Sum, Mul) : Builder.CreateAdd(Sum, Mul); 346 } 347 348 /// Lowers llvm.matrix.multiply. 349 void LowerMultiply(CallInst *MatMul) { 350 IRBuilder<> Builder(MatMul); 351 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 352 ShapeInfo LShape(cast<ConstantInt>(MatMul->getArgOperand(2)), 353 cast<ConstantInt>(MatMul->getArgOperand(3))); 354 ShapeInfo RShape(cast<ConstantInt>(MatMul->getArgOperand(3)), 355 cast<ConstantInt>(MatMul->getArgOperand(4))); 356 357 const ColumnMatrixTy &Lhs = 358 getMatrix(MatMul->getArgOperand(0), LShape, Builder); 359 const ColumnMatrixTy &Rhs = 360 getMatrix(MatMul->getArgOperand(1), RShape, Builder); 361 362 const unsigned R = LShape.NumRows; 363 const unsigned M = LShape.NumColumns; 364 const unsigned C = RShape.NumColumns; 365 assert(M == RShape.NumRows); 366 367 // Initialize the output 368 ColumnMatrixTy Result; 369 for (unsigned J = 0; J < C; ++J) 370 Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); 371 372 const unsigned VF = std::max(TTI.getRegisterBitWidth(true) / 373 EltType->getPrimitiveSizeInBits(), 374 uint64_t(1)); 375 376 // Multiply columns from the first operand with scalars from the second 377 // operand. Then move along the K axes and accumulate the columns. With 378 // this the adds can be vectorized without reassociation. 379 for (unsigned J = 0; J < C; ++J) { 380 unsigned BlockSize = VF; 381 for (unsigned I = 0; I < R; I += BlockSize) { 382 // Gradually lower the vectorization factor to cover the remainder. 383 while (I + BlockSize > R) 384 BlockSize /= 2; 385 386 Value *Sum = nullptr; 387 for (unsigned K = 0; K < M; ++K) { 388 Value *L = extractVector(Lhs, I, K, BlockSize, Builder); 389 Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K); 390 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 391 Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(), 392 Builder); 393 } 394 Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder)); 395 } 396 } 397 398 MatMul->replaceAllUsesWith(Result.embedInVector(Builder)); 399 } 400 401 /// Lowers llvm.matrix.transpose. 402 void LowerTranspose(CallInst *Inst) { 403 ColumnMatrixTy Result; 404 IRBuilder<> Builder(Inst); 405 Value *InputVal = Inst->getArgOperand(0); 406 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 407 ShapeInfo ArgShape(cast<ConstantInt>(Inst->getArgOperand(1)), 408 cast<ConstantInt>(Inst->getArgOperand(2))); 409 ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 410 411 for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { 412 // Build a single column vector for this row. First initialize it. 413 Value *ResultColumn = UndefValue::get( 414 VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns)); 415 416 // Go through the elements of this row and insert it into the resulting 417 // column vector. 418 for (auto C : enumerate(InputMatrix.columns())) { 419 Value *Elt = Builder.CreateExtractElement(C.value(), Row); 420 // We insert at index Column since that is the row index after the 421 // transpose. 422 ResultColumn = 423 Builder.CreateInsertElement(ResultColumn, Elt, C.index()); 424 } 425 Result.addColumn(ResultColumn); 426 } 427 428 Inst->replaceAllUsesWith(Result.embedInVector(Builder)); 429 } 430 }; 431 } // namespace 432 433 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 434 FunctionAnalysisManager &AM) { 435 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 436 LowerMatrixIntrinsics LMT(F, TTI); 437 if (LMT.Visit()) { 438 PreservedAnalyses PA; 439 PA.preserveSet<CFGAnalyses>(); 440 return PA; 441 } 442 return PreservedAnalyses::all(); 443 } 444 445 namespace { 446 447 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { 448 public: 449 static char ID; 450 451 LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { 452 initializeLowerMatrixIntrinsicsLegacyPassPass( 453 *PassRegistry::getPassRegistry()); 454 } 455 456 bool runOnFunction(Function &F) override { 457 auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 458 LowerMatrixIntrinsics LMT(F, *TTI); 459 bool C = LMT.Visit(); 460 return C; 461 } 462 463 void getAnalysisUsage(AnalysisUsage &AU) const override { 464 AU.addRequired<TargetTransformInfoWrapperPass>(); 465 AU.setPreservesCFG(); 466 } 467 }; 468 } // namespace 469 470 static const char pass_name[] = "Lower the matrix intrinsics"; 471 char LowerMatrixIntrinsicsLegacyPass::ID = 0; 472 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 473 false, false) 474 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, 475 false, false) 476 477 Pass *llvm::createLowerMatrixIntrinsicsPass() { 478 return new LowerMatrixIntrinsicsLegacyPass(); 479 } 480